Source code for neer_match_utilities.model

from pathlib import Path
import pickle
from neer_match.matching_model import DLMatchingModel, NSMatchingModel
from neer_match.similarity_map import SimilarityMap
import tensorflow as tf
import typing
import shutil
import sys


class Model:
    """
    A class for saving and loading matching models.

    Methods
    -------
    save(model, target_directory, name):
        Save the specified model to a target directory.
    load(model_directory):
        Load a model from a given directory.
    """
    @staticmethod
    def save(
        model: typing.Union["DLMatchingModel", "NSMatchingModel"],
        target_directory: Path,
        name: str,
    ) -> None:
        """
        Save the model to a specified directory.

        Parameters
        ----------
        model : DLMatchingModel or NSMatchingModel
            The model to be saved.
        target_directory : Path
            The directory where the model should be saved.
        name : str
            Name of the model directory.
        """
        target_directory = Path(target_directory) / name / "model"

        if target_directory.exists():
            replace = input(
                f"Directory '{target_directory}' already exists. "
                "Replace the old model? (y/n): "
            ).strip().lower()
            if replace == "y":
                shutil.rmtree(target_directory)
                print(f"Old model at '{target_directory}' has been replaced.")
            elif replace == "n":
                print("Execution halted as per user request.")
                sys.exit(0)
            else:
                print("Invalid input. Please type 'y' or 'n'. Aborting operation.")
                return

        target_directory.mkdir(parents=True, exist_ok=True)

        # --- Build composite similarity info ---
        # Use the original instructions stored in the SimilarityMap.
        # model.similarity_map.instructions is assumed to be a dict:
        # {field: [metric1, metric2, ...], ...}
        instructions = model.similarity_map.instructions
        fields = list(instructions.keys())
        association_sizes = model.similarity_map.association_sizes()  # aggregated sizes per field

        composite_similarity_info = {}
        for i, field in enumerate(fields):
            agg_size = association_sizes[i]
            metrics = instructions[field]  # list of metric names as originally provided
            composite_similarity_info[field] = {
                "metrics": metrics,
                "aggregated_size": agg_size,
                "per_metric_size": agg_size // len(metrics),
            }

        # --- Save model initialization parameters from the record pair network ---
        model_params = {
            "initial_feature_width_scales": model.record_pair_network.initial_feature_width_scales,
            "feature_depths": model.record_pair_network.feature_depths,
            "initial_record_width_scale": model.record_pair_network.initial_record_width_scale,
            "record_depth": model.record_pair_network.record_depth,
        }

        # Save a composite dictionary containing both similarity info and model parameters.
        composite_save = {
            "similarity_info": composite_similarity_info,
            "model_params": model_params,
        }
        with open(target_directory / "model_info.pkl", "wb") as f:
            pickle.dump(composite_save, f)
        # --- End composite info saving ---

        if isinstance(model, DLMatchingModel):
            model.save_weights(target_directory / "model.weights.h5")
            if hasattr(model, "optimizer") and model.optimizer:
                optimizer_config = {
                    "class_name": model.optimizer.__class__.__name__,
                    "config": model.optimizer.get_config(),
                }
                with open(target_directory / "optimizer.pkl", "wb") as f:
                    pickle.dump(optimizer_config, f)
        elif isinstance(model, NSMatchingModel):
            model.record_pair_network.save_weights(
                target_directory / "record_pair_network.weights.h5"
            )
            if hasattr(model, "optimizer") and model.optimizer:
                optimizer_config = {
                    "class_name": model.optimizer.__class__.__name__,
                    "config": model.optimizer.get_config(),
                }
                with open(target_directory / "optimizer.pkl", "wb") as f:
                    pickle.dump(optimizer_config, f)
        else:
            raise ValueError("The model must be an instance of DLMatchingModel or NSMatchingModel")

        print(f"Model successfully saved to {target_directory}")
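    # Usage sketch for ``save`` (illustrative, not part of the module API):
    # ``trained_model`` and the directory/name below are assumptions.
    #
    #   from pathlib import Path
    #   Model.save(trained_model, Path("artifacts"), name="company_matcher")
    #
    # This writes artifacts/company_matcher/model/ containing model_info.pkl,
    # the weights file, and optimizer.pkl if an optimizer is attached.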
    @staticmethod
    def load(model_directory: Path) -> typing.Union[DLMatchingModel, NSMatchingModel]:
        """
        Load a model from a specified directory.

        Parameters
        ----------
        model_directory : Path
            The directory containing the saved model.

        Returns
        -------
        DLMatchingModel or NSMatchingModel
            The loaded model.
        """
        model_directory = Path(model_directory) / "model"

        if not model_directory.exists():
            raise FileNotFoundError(f"Model directory '{model_directory}' does not exist.")

        # --- Load composite model info (similarity info and model parameters) ---
        with open(model_directory / "model_info.pkl", "rb") as f:
            composite_save = pickle.load(f)
        composite_similarity_info = composite_save["similarity_info"]
        model_params = composite_save["model_params"]

        # Reconstruct the original instructions (a plain dict mapping each field
        # to its list of metric names) and rebuild a SimilarityMap instance from
        # it, as expected by the matching models.
        original_similarity_map = {
            field: info["metrics"] for field, info in composite_similarity_info.items()
        }
        similarity_map_instance = SimilarityMap(original_similarity_map)

        # Compute aggregated sizes in the order of fields.
        fields = list(composite_similarity_info.keys())
        aggregated_sizes = [
            composite_similarity_info[field]["aggregated_size"] for field in fields
        ]
        # --- End loading composite info ---

        if (model_directory / "model.weights.h5").exists():
            # Initialize the model using the reconstructed SimilarityMap instance
            # and the stored initialization parameters.
            model = DLMatchingModel(
                similarity_map=similarity_map_instance,
                initial_feature_width_scales=model_params["initial_feature_width_scales"],
                feature_depths=model_params["feature_depths"],
                initial_record_width_scale=model_params["initial_record_width_scale"],
                record_depth=model_params["record_depth"],
            )
            input_shapes = [tf.TensorShape([None, s]) for s in aggregated_sizes]
            model.build(input_shapes=input_shapes)

            # Build dummy inputs as a list of tensors (one per field), each of
            # shape (1, aggregated_size), and run a forward pass so that all
            # sublayers are instantiated before the weights are restored.
            dummy_tensors = [
                tf.zeros((1, composite_similarity_info[field]["aggregated_size"]))
                for field in fields
            ]
            _ = model(dummy_tensors)

            model.load_weights(model_directory / "model.weights.h5")

            if (model_directory / "optimizer.pkl").exists():
                with open(model_directory / "optimizer.pkl", "rb") as f:
                    optimizer_config = pickle.load(f)
                optimizer_class = getattr(tf.keras.optimizers, optimizer_config["class_name"])
                model.optimizer = optimizer_class.from_config(optimizer_config["config"])
        elif (model_directory / "record_pair_network.weights.h5").exists():
            # Use the reconstructed SimilarityMap instance here as well, for
            # consistency with the DLMatchingModel branch above.
            model = NSMatchingModel(similarity_map_instance)
            model.compile()
            model.record_pair_network.load_weights(
                model_directory / "record_pair_network.weights.h5"
            )

            if (model_directory / "optimizer.pkl").exists():
                with open(model_directory / "optimizer.pkl", "rb") as f:
                    optimizer_config = pickle.load(f)
                optimizer_class = getattr(tf.keras.optimizers, optimizer_config["class_name"])
                model.optimizer = optimizer_class.from_config(optimizer_config["config"])
        else:
            raise ValueError(
                "Invalid model directory: neither DLMatchingModel nor NSMatchingModel was detected."
            )

        return model
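
# ---------------------------------------------------------------------------
# Round-trip sketch (illustrative; the model and paths are assumptions).
# ``Model.load`` expects the parent of the "model" subdirectory that
# ``Model.save`` creates, i.e. target_directory / name:
#
#   from pathlib import Path
#   from neer_match_utilities.model import Model
#
#   Model.save(trained_model, Path("artifacts"), name="company_matcher")
#   restored = Model.load(Path("artifacts") / "company_matcher")
# ---------------------------------------------------------------------------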