| @@ -45,14 +45,8 @@ class ABLModel: | |||
| """ | |||
| def __init__(self, base_model: Any) -> None: | |||
| if not ( | |||
| hasattr(base_model, "fit") | |||
| and hasattr(base_model, "predict") | |||
| and hasattr(base_model, "score") | |||
| ): | |||
| raise NotImplementedError( | |||
| "base_model should have fit, predict and score methods." | |||
| ) | |||
| if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")): | |||
| raise NotImplementedError("The base_model should implement fit and predict methods.") | |||
| self.base_model = base_model | |||
| @@ -84,27 +78,6 @@ class ABLModel: | |||
| return {"label": label, "prob": prob} | |||
| def valid(self, data_samples: ListData) -> float: | |||
| """ | |||
| Calculate the accuracy for the given data. | |||
| Parameters | |||
| ---------- | |||
| X : List[List[Any]] | |||
| The data to calculate the accuracy on. | |||
| Y : List[Any] | |||
| The true labels for the given data. | |||
| Returns | |||
| ------- | |||
| float | |||
| The accuracy score for the given data. | |||
| """ | |||
| data_X = data_samples.flatten("X") | |||
| data_y = data_samples.flatten("gt_idx") | |||
| score = self.base_model.score(X=data_X, y=data_y) | |||
| return score | |||
| def train(self, data_samples: ListData) -> float: | |||
| """ | |||
| Train the model on the given data. | |||
| @@ -131,19 +104,20 @@ class ABLModel: | |||
| method = getattr(model, operation) | |||
| method(*args, **kwargs) | |||
| else: | |||
| try: | |||
| if not f"{operation}_path" in kwargs.keys(): | |||
| raise ValueError(f"'{operation}_path' should not be None") | |||
| if operation == "save": | |||
| with open(kwargs["save_path"], "wb") as file: | |||
| pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||
| elif operation == "load": | |||
| with open(kwargs["load_path"], "rb") as file: | |||
| self.base_model = pickle.load(file) | |||
| except: | |||
| raise NotImplementedError( | |||
| f"{type(model).__name__} object doesn't have the {operation} method" | |||
| ) | |||
| if not f"{operation}_path" in kwargs.keys(): | |||
| raise ValueError(f"'{operation}_path' should not be None") | |||
| else: | |||
| try: | |||
| if operation == "save": | |||
| with open(kwargs["save_path"], "wb") as file: | |||
| pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) | |||
| elif operation == "load": | |||
| with open(kwargs["load_path"], "rb") as file: | |||
| self.base_model = pickle.load(file) | |||
| except: | |||
| raise NotImplementedError( | |||
| f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed." | |||
| ) | |||
| def save(self, *args, **kwargs) -> None: | |||
| """ | |||