diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 8f04a60..83009b1 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -91,7 +91,7 @@ class ABLModel: score = self.classifier_list[0].score(X=data_X, y=data_Y) return score - def train(self, X: List[List[Any]], Y: List[Any]): + def train(self, X: List[List[Any]], Y: List[Any]) -> float: """ Train the model on the given data. @@ -104,7 +104,7 @@ class ABLModel: """ data_X, _ = self.merge_data(X) data_Y, _ = self.merge_data(Y) - self.classifier_list[0].fit(X=data_X, y=data_Y) + return self.classifier_list[0].fit(X=data_X, y=data_Y) @staticmethod def merge_data(X):