|
|
|
@@ -101,6 +101,20 @@ class ABLModel: |
|
|
|
data_X, _ = self.merge_data(X) |
|
|
|
data_Y, _ = self.merge_data(Y) |
|
|
|
return self.classifier_list[0].fit(X=data_X, y=data_Y) |
|
|
|
|
|
|
|
def save(self, *args, **kwargs) -> None: |
|
|
|
_model = self.classifier_list[0] |
|
|
|
if hasattr(_model, "save"): |
|
|
|
self._model.save(*args, **kwargs) |
|
|
|
else: |
|
|
|
raise NotImplementedError(f"{type(_model).__name__} object dosen't have the save method") |
|
|
|
|
|
|
|
def load(self, *args, **kwargs): |
|
|
|
_model = self.classifier_list[0] |
|
|
|
if hasattr(_model, "load"): |
|
|
|
_model.load(*args, **kwargs) |
|
|
|
else: |
|
|
|
raise NotImplementedError(f"{type(_model).__name__} object dosen't have the load method") |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def merge_data(X): |
|
|
|
|