# coding: utf-8
# ================================================================#
#   Copyright (C) 2020 Freecss All rights reserved.
#
#   File Name     :models.py
#   Author        :freecss
#   Email         :karlfreecss@gmail.com
#   Created Date  :2020/04/02
#   Description   :
#
# ================================================================#
import pickle
from typing import Any, Dict

from ..structures import ListData
from ..utils import reform_list


class ABLModel:
    """
    Serialize data and provide a unified interface for different machine learning models.

    Parameters
    ----------
    base_model : Machine Learning Model
        The base model to use for training and prediction. It must expose
        ``fit`` and ``predict`` methods; ``predict_proba``, ``save`` and
        ``load`` are used when present.

    Raises
    ------
    NotImplementedError
        If ``base_model`` lacks a ``fit`` or a ``predict`` attribute.
    """

    def __init__(self, base_model: Any) -> None:
        if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")):
            raise NotImplementedError("The base_model should implement fit and predict methods.")

        self.base_model = base_model

    def predict(self, data_samples: ListData) -> Dict:
        """
        Predict the labels and probabilities for the given data.

        The results are also written back onto ``data_samples`` as the
        attributes ``pred_idx`` (labels) and ``pred_prob`` (probabilities).

        Parameters
        ----------
        data_samples : ListData
            A batch of data to predict on.

        Returns
        -------
        dict
            A dictionary with keys ``"label"`` and ``"prob"``. ``"prob"`` is
            ``None`` when the base model has no ``predict_proba`` method.
        """
        model = self.base_model
        data_X = data_samples.flatten("X")
        if hasattr(model, "predict_proba"):
            prob = model.predict_proba(X=data_X)
            # The label is the index of the highest-probability column.
            label = prob.argmax(axis=1)
            prob = reform_list(prob, data_samples.X)
        else:
            prob = None
            label = model.predict(X=data_X)
        # NOTE(review): reform_list presumably regroups the flat predictions to
        # mirror the nesting of data_samples.X -- confirm against utils.
        label = reform_list(label, data_samples.X)

        data_samples.pred_idx = label
        data_samples.pred_prob = prob

        return {"label": label, "prob": prob}

    def train(self, data_samples: ListData) -> float:
        """
        Train the model on the given data.

        Parameters
        ----------
        data_samples : ListData
            A batch of data to train on, which typically contains the data, `X`,
            and the corresponding labels, `abduced_idx`.

        Returns
        -------
        float
            Whatever ``base_model.fit`` returns, passed through unchanged. The
            base model is expected to return the loss value of the trained
            model; this method does not enforce that.
        """
        data_X = data_samples.flatten("X")
        data_y = data_samples.flatten("abduced_idx")
        return self.base_model.fit(X=data_X, y=data_y)

    def _model_operation(self, operation: str, *args, **kwargs) -> None:
        """
        Run ``operation`` ("save" or "load") on the base model.

        If the base model has an attribute named ``operation`` it is called with
        ``*args`` and ``**kwargs`` and its return value is discarded. Otherwise
        a pickle-based fallback is used, which needs the keyword argument
        ``save_path`` or ``load_path`` respectively.

        Parameters
        ----------
        operation : str
            Name of the operation; the fallback only handles "save" and "load".
        *args, **kwargs
            Forwarded to the base model's own method when it exists.

        Raises
        ------
        ValueError
            If the fallback is needed but ``f"{operation}_path"`` is not among
            the keyword arguments.
        NotImplementedError
            If the pickle-based fallback fails; the underlying error is chained
            as ``__cause__``.
        """
        model = self.base_model
        if hasattr(model, operation):
            method = getattr(model, operation)
            method(*args, **kwargs)
            return

        if f"{operation}_path" not in kwargs:
            raise ValueError(f"'{operation}_path' should not be None")

        try:
            if operation == "save":
                with open(kwargs["save_path"], "wb") as file:
                    pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
            elif operation == "load":
                # NOTE(security): unpickling can execute arbitrary code; only
                # load files that come from a trusted source.
                with open(kwargs["load_path"], "rb") as file:
                    self.base_model = pickle.load(file)
        except Exception as err:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit still propagate, and chain the cause so the real
            # failure (missing file, unpicklable model, ...) stays visible.
            raise NotImplementedError(
                f"{type(model).__name__} object doesn't have the {operation} method "
                f"and the default pickle-based {operation} method failed."
            ) from err

    def save(self, *args, **kwargs) -> None:
        """
        Save the model to a file.

        Delegates to the base model's ``save`` method when it has one;
        otherwise the base model is pickled to the file given by the keyword
        argument ``save_path``.
        """
        self._model_operation("save", *args, **kwargs)

    def load(self, *args, **kwargs) -> None:
        """
        Load the model from a file.

        Delegates to the base model's ``load`` method when it has one;
        otherwise the base model is replaced by the object unpickled from the
        file given by the keyword argument ``load_path``.
        """
        self._model_operation("load", *args, **kwargs)