diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 1a720ec..4c6fbad 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -45,14 +45,8 @@ class ABLModel: """ def __init__(self, base_model: Any) -> None: - if not ( - hasattr(base_model, "fit") - and hasattr(base_model, "predict") - and hasattr(base_model, "score") - ): - raise NotImplementedError( - "base_model should have fit, predict and score methods." - ) + if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")): + raise NotImplementedError("The base_model should implement fit and predict methods.") self.base_model = base_model @@ -84,27 +78,6 @@ class ABLModel: return {"label": label, "prob": prob} - def valid(self, data_samples: ListData) -> float: - """ - Calculate the accuracy for the given data. - - Parameters - ---------- - X : List[List[Any]] - The data to calculate the accuracy on. - Y : List[Any] - The true labels for the given data. - - Returns - ------- - float - The accuracy score for the given data. - """ - data_X = data_samples.flatten("X") - data_y = data_samples.flatten("gt_idx") - score = self.base_model.score(X=data_X, y=data_y) - return score - def train(self, data_samples: ListData) -> float: """ Train the model on the given data. @@ -131,19 +104,20 @@ class ABLModel: method = getattr(model, operation) method(*args, **kwargs) else: - try: - if not f"{operation}_path" in kwargs.keys(): - raise ValueError(f"'{operation}_path' should not be None") - if operation == "save": - with open(kwargs["save_path"], "wb") as file: - pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) - elif operation == "load": - with open(kwargs["load_path"], "rb") as file: - self.base_model = pickle.load(file) - except: - raise NotImplementedError( - f"{type(model).__name__} object doesn't have the {operation} method" - ) + if not f"{operation}_path" in kwargs.keys(): + raise ValueError(f"'{operation}_path' should not be None") + else: + try: + if operation == "save": + with open(kwargs["save_path"], "wb") as file: + pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) + elif operation == "load": + with open(kwargs["load_path"], "rb") as file: + self.base_model = pickle.load(file) + except: + raise NotImplementedError( + f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed." + ) def save(self, *args, **kwargs) -> None: """