Browse Source

remove mapping in ABLModel

pull/3/head
Gao Enhao 2 years ago
parent
commit
dc24a1bc8a
1 changed files with 42 additions and 54 deletions
  1. +42
    -54
      abl/learning/abl_model.py

+ 42
- 54
abl/learning/abl_model.py View File

@@ -10,29 +10,13 @@
#
# ================================================================#
from itertools import chain
from typing import List, Any
from typing import List, Any, Optional


def get_part_data(X, i):
return list(map(lambda x: x[i], X))


def merge_data(X):
ret_mark = list(map(lambda x: len(x), X))
ret_X = list(chain(*X))
return ret_X, ret_mark


def reshape_data(Y, marks):
begin_mark = 0
ret_Y = []
for mark in marks:
end_mark = begin_mark + mark
ret_Y.append(Y[begin_mark:end_mark])
begin_mark = end_mark
return ret_Y


class ABLModel:
"""
Serialize data and provide a unified interface for different machine learning models.
@@ -41,42 +25,29 @@ class ABLModel:
----------
base_model : Machine Learning Model
The base model to use for training and prediction.
pseudo_label_list : List[Any]
A list of pseudo labels to use for training.

Attributes
----------
cls_list : List[Any]
classifier_list : List[Any]
A list of classifiers.
pseudo_label_list : List[Any]
A list of pseudo labels to use for training.
mapping : dict
A dictionary mapping pseudo labels to integers.
remapping : dict
A dictionary mapping integers to pseudo labels.

Methods
-------
predict(X: List[List[Any]]) -> dict
Predict the class labels and probabilities for the given data.
predict(X: List[List[Any]], mapping: Optional[dict]) -> dict
Predict the labels and probabilities for the given data.
valid(X: List[List[Any]], Y: List[Any]) -> float
Calculate the accuracy score for the given data.
train(X: List[List[Any]], Y: List[Any])
Train the model on the given data.
"""
def __init__(self, base_model, pseudo_label_list: List[Any]) -> None:
self.cls_list = []
self.cls_list.append(base_model)

self.pseudo_label_list = pseudo_label_list
self.mapping = dict(zip(pseudo_label_list, list(range(len(pseudo_label_list)))))
self.remapping = dict(
zip(list(range(len(pseudo_label_list))), pseudo_label_list)
)
def __init__(self, base_model) -> None:
self.classifier_list = []
self.classifier_list.append(base_model)

def predict(self, X: List[List[Any]]) -> dict:
def predict(self, X: List[List[Any]], mapping: Optional[dict]) -> dict:
"""
Predict the class labels and probabilities for the given data.
Predict the labels and probabilities for the given data.

Parameters
----------
@@ -86,17 +57,18 @@ class ABLModel:
Returns
-------
dict
A dictionary containing the predicted class labels and probabilities.
A dictionary containing the predicted labels and probabilities.
"""
data_X, marks = merge_data(X)
prob = self.cls_list[0].predict_proba(X=data_X)
_cls = prob.argmax(axis=1)
cls = list(map(lambda x: self.remapping[x], _cls))
data_X, marks = self.merge_data(X)
prob = self.classifier_list[0].predict_proba(X=data_X)
label = prob.argmax(axis=1)
if mapping is not None:
label = [mapping[x] for x in label]

prob = reshape_data(prob, marks)
cls = reshape_data(cls, marks)
prob = self.reshape_data(prob, marks)
label = self.reshape_data(label, marks)

return {"cls": cls, "prob": prob}
return {"label": label, "prob": prob}

def valid(self, X: List[List[Any]], Y: List[Any]) -> float:
"""
@@ -107,17 +79,17 @@ class ABLModel:
X : List[List[Any]]
The data to calculate the accuracy on.
Y : List[Any]
The true class labels for the given data.
The true labels for the given data.

Returns
-------
float
The accuracy score for the given data.
"""
data_X, _ = merge_data(X)
_data_Y, _ = merge_data(Y)
data_X, _ = self.merge_data(X)
_data_Y, _ = self.merge_data(Y)
data_Y = list(map(lambda y: self.mapping[y], _data_Y))
score = self.cls_list[0].score(X=data_X, y=data_Y)
score = self.classifier_list[0].score(X=data_X, y=data_Y)
return score

def train(self, X: List[List[Any]], Y: List[Any]):
@@ -129,9 +101,25 @@ class ABLModel:
X : List[List[Any]]
The data to train on.
Y : List[Any]
The true class labels for the given data.
The true labels for the given data.
"""
data_X, _ = merge_data(X)
_data_Y, _ = merge_data(Y)
data_X, _ = self.merge_data(X)
_data_Y, _ = self.merge_data(Y)
data_Y = list(map(lambda y: self.mapping[y], _data_Y))
self.cls_list[0].fit(X=data_X, y=data_Y)
self.classifier_list[0].fit(X=data_X, y=data_Y)

@staticmethod
def merge_data(X):
ret_mark = list(map(lambda x: len(x), X))
ret_X = list(chain(*X))
return ret_X, ret_mark

@staticmethod
def reshape_data(Y, marks):
begin_mark = 0
ret_Y = []
for mark in marks:
end_mark = begin_mark + mark
ret_Y.append(Y[begin_mark:end_mark])
begin_mark = end_mark
return ret_Y

Loading…
Cancel
Save