From 1830daffd4a6a7f0cb885326a2ebab7d00108e8f Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sun, 11 Jun 2023 16:39:01 +0800 Subject: [PATCH] [ENH] add save and load method to ABLModel --- abl/learning/abl_model.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 70d1888..df83238 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -101,6 +101,20 @@ class ABLModel: data_X, _ = self.merge_data(X) data_Y, _ = self.merge_data(Y) return self.classifier_list[0].fit(X=data_X, y=data_Y) + + def save(self, *args, **kwargs) -> None: + _model = self.classifier_list[0] + if hasattr(_model, "save"): + self._model.save(*args, **kwargs) + else: + raise NotImplementedError(f"{type(_model).__name__} object dosen't have the save method") + + def load(self, *args, **kwargs): + _model = self.classifier_list[0] + if hasattr(_model, "load"): + _model.load(*args, **kwargs) + else: + raise NotImplementedError(f"{type(_model).__name__} object dosen't have the load method") @staticmethod def merge_data(X):