From 74151b0d83af403b814a288ee4f5ed100d7caf7e Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Fri, 9 Jun 2023 19:14:27 +0800 Subject: [PATCH] [MNT] modify typing in ABLModel --- abl/learning/abl_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index 8f04a60..83009b1 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -91,7 +91,7 @@ class ABLModel: score = self.classifier_list[0].score(X=data_X, y=data_Y) return score - def train(self, X: List[List[Any]], Y: List[Any]): + def train(self, X: List[List[Any]], Y: List[Any]) -> float: """ Train the model on the given data. @@ -104,7 +104,7 @@ class ABLModel: """ data_X, _ = self.merge_data(X) data_Y, _ = self.merge_data(Y) - self.classifier_list[0].fit(X=data_X, y=data_Y) + return self.classifier_list[0].fit(X=data_X, y=data_Y) @staticmethod def merge_data(X):