diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 70d1888..df83238 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -101,6 +101,20 @@ class ABLModel: data_X, _ = self.merge_data(X) data_Y, _ = self.merge_data(Y) return self.classifier_list[0].fit(X=data_X, y=data_Y) + + def save(self, *args, **kwargs) -> None: + _model = self.classifier_list[0] + if hasattr(_model, "save"): + self._model.save(*args, **kwargs) + else: + raise NotImplementedError(f"{type(_model).__name__} object dosen't have the save method") + + def load(self, *args, **kwargs): + _model = self.classifier_list[0] + if hasattr(_model, "load"): + _model.load(*args, **kwargs) + else: + raise NotImplementedError(f"{type(_model).__name__} object dosen't have the load method") @staticmethod def merge_data(X):