#
# ================================================================#
from itertools import chain
from typing import List, Any, Optional


def get_part_data(X, i):
    """Return a list holding element ``i`` of every item in ``X``."""
    return [x[i] for x in X]


class ABLModel:
    """
    Serialize data and provide a unified interface for different machine learning models.

    Parameters
    ----------
    base_model : Machine Learning Model
        The base model to use for training and prediction. It is called through
        ``predict_proba(X=...)``, ``score(X=..., y=...)`` and ``fit(X=..., y=...)``.

    Attributes
    ----------
    classifier_list : List[Any]
        A list of classifiers. Only the first entry is used by the methods below.

    Methods
    -------
    predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict
        Predict the labels and probabilities for the given data.
    valid(X: List[List[Any]], Y: List[Any]) -> float
        Calculate the accuracy score for the given data.
    train(X: List[List[Any]], Y: List[Any])
        Train the model on the given data.
    """

    def __init__(self, base_model) -> None:
        self.classifier_list = []
        self.classifier_list.append(base_model)

    def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict:
        """
        Predict the labels and probabilities for the given data.

        Parameters
        ----------
        X : List[List[Any]]
            The data to predict on, grouped into examples.
        mapping : Optional[dict]
            If given, every predicted index is translated through this
            dictionary. Defaults to ``None`` (indices are returned as-is), so
            callers that only pass ``X`` keep working.

        Returns
        -------
        dict
            A dictionary containing the predicted labels (``"label"``) and
            probabilities (``"prob"``), both regrouped to match the grouping of ``X``.
        """
        data_X, marks = self.merge_data(X)
        prob = self.classifier_list[0].predict_proba(X=data_X)
        # The predicted index is the position of the highest probability in each row.
        label = prob.argmax(axis=1)
        if mapping is not None:
            label = [mapping[x] for x in label]

        prob = self.reshape_data(prob, marks)
        label = self.reshape_data(label, marks)

        return {"label": label, "prob": prob}

    def valid(self, X: List[List[Any]], Y: List[Any]) -> float:
        """
        Calculate the accuracy score for the given data.

        Parameters
        ----------
        X : List[List[Any]]
            The data to calculate the accuracy on.
        Y : List[Any]
            The true labels for the given data, grouped like ``X``.

        Returns
        -------
        float
            The value returned by the base model's ``score`` method.
        """
        data_X, _ = self.merge_data(X)
        # The model no longer owns a label mapping (``self.mapping`` is not
        # defined by ``__init__``), so labels are forwarded unchanged.
        data_Y, _ = self.merge_data(Y)
        score = self.classifier_list[0].score(X=data_X, y=data_Y)
        return score

    def train(self, X: List[List[Any]], Y: List[Any]):
        """
        Train the model on the given data.

        Parameters
        ----------
        X : List[List[Any]]
            The data to train on.
        Y : List[Any]
            The true labels for the given data, grouped like ``X``.
        """
        data_X, _ = self.merge_data(X)
        # Labels are forwarded unchanged; see the note in ``valid``.
        data_Y, _ = self.merge_data(Y)
        self.classifier_list[0].fit(X=data_X, y=data_Y)

    @staticmethod
    def merge_data(X):
        """
        Flatten grouped data.

        Returns
        -------
        tuple
            ``(flat, marks)`` where ``flat`` is the concatenation of all groups
            and ``marks`` holds the length of each group, in order.
        """
        ret_mark = [len(x) for x in X]
        ret_X = list(chain.from_iterable(X))
        return ret_X, ret_mark

    @staticmethod
    def reshape_data(Y, marks):
        """
        Split flat data ``Y`` back into consecutive groups of the lengths in ``marks``.

        This is the inverse of ``merge_data``: each group is the slice of ``Y``
        that follows the previous group.
        """
        begin_mark = 0
        ret_Y = []
        for mark in marks:
            end_mark = begin_mark + mark
            ret_Y.append(Y[begin_mark:end_mark])
            begin_mark = end_mark
        return ret_Y