| @@ -24,10 +24,8 @@ class SimpleBridge(BaseBridge): | |||||
| # TODO: add abducer.mapping to the property of SimpleBridge | # TODO: add abducer.mapping to the property of SimpleBridge | ||||
| def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | ||||
| pred_res = self.model.predict(data_samples) | |||||
| data_samples.pred_idx = pred_res["label"] | |||||
| data_samples.pred_prob = pred_res["prob"] | |||||
| return data_samples["pred_idx"], data_samples["pred_prob"] | |||||
| self.model.predict(data_samples) | |||||
| return data_samples["pred_idx"], data_samples.get("pred_prob", None) | |||||
| def abduce_pseudo_label( | def abduce_pseudo_label( | ||||
| self, | self, | ||||
| @@ -1,3 +1,4 @@ | |||||
| from .bridge_dataset import BridgeDataset | from .bridge_dataset import BridgeDataset | ||||
| from .classification_dataset import ClassificationDataset | from .classification_dataset import ClassificationDataset | ||||
| from .regression_dataset import RegressionDataset | |||||
| from .prediction_dataset import PredictionDataset | |||||
| from .regression_dataset import RegressionDataset | |||||
| @@ -0,0 +1,56 @@ | |||||
| from typing import Any, Callable, List, Tuple | |||||
| import torch | |||||
| from torch.utils.data import Dataset | |||||
| class PredictionDataset(Dataset): | |||||
| def __init__(self, X: List[Any], transform: Callable[..., Any] = None): | |||||
| """ | |||||
| Initialize the dataset used for classification task. | |||||
| Parameters | |||||
| ---------- | |||||
| X : List[Any] | |||||
| The input data. | |||||
| transform : Callable[..., Any], optional | |||||
| A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||||
| """ | |||||
| if not isinstance(X, list): | |||||
| raise ValueError("X should be of type list.") | |||||
| self.X = X | |||||
| self.transform = transform | |||||
| def __len__(self) -> int: | |||||
| """ | |||||
| Return the length of the dataset. | |||||
| Returns | |||||
| ------- | |||||
| int | |||||
| The length of the dataset. | |||||
| """ | |||||
| return len(self.X) | |||||
| def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]: | |||||
| """ | |||||
| Get the item at the given index. | |||||
| Parameters | |||||
| ---------- | |||||
| index : int | |||||
| The index of the item to get. | |||||
| Returns | |||||
| ------- | |||||
| Tuple[Any, torch.Tensor] | |||||
| A tuple containing the object and its label. | |||||
| """ | |||||
| if index >= len(self): | |||||
| raise ValueError("index range error") | |||||
| x = self.X[index] | |||||
| if self.transform is not None: | |||||
| x = self.transform(x) | |||||
| return x | |||||
| @@ -71,11 +71,14 @@ class ABLModel: | |||||
| label = prob.argmax(axis=1) | label = prob.argmax(axis=1) | ||||
| prob = reform_idx(prob, data_samples["X"]) | prob = reform_idx(prob, data_samples["X"]) | ||||
| else: | else: | ||||
| prob = [None] * len(data_samples) | |||||
| prob = None | |||||
| label = model.predict(X=data_X) | label = model.predict(X=data_X) | ||||
| label = reform_idx(label, data_samples["X"]) | label = reform_idx(label, data_samples["X"]) | ||||
| data_samples.pred_idx = label | |||||
| if prob is not None: | |||||
| data_samples.pred_prob = prob | |||||
| return {"label": label, "prob": prob} | return {"label": label, "prob": prob} | ||||
| def train(self, data_samples: ListData) -> float: | def train(self, data_samples: ListData) -> float: | ||||
| @@ -11,13 +11,14 @@ | |||||
| # ================================================================# | # ================================================================# | ||||
| import os | import os | ||||
| import logging | |||||
| from typing import Any, Callable, List, Optional, T, Tuple | from typing import Any, Callable, List, Optional, T, Tuple | ||||
| import numpy | import numpy | ||||
| import torch | import torch | ||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
| from ..dataset import ClassificationDataset | |||||
| from ..dataset import ClassificationDataset, PredictionDataset | |||||
| from ..utils.logger import print_log | from ..utils.logger import print_log | ||||
| @@ -197,7 +198,12 @@ class BasicNN: | |||||
| return torch.cat(results, axis=0) | return torch.cat(results, axis=0) | ||||
| def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: | |||||
| def predict( | |||||
| self, | |||||
| data_loader: DataLoader = None, | |||||
| X: List[Any] = None, | |||||
| test_transform: Callable[..., Any] = None, | |||||
| ) -> numpy.ndarray: | |||||
| """ | """ | ||||
| Predict the class of the input data. | Predict the class of the input data. | ||||
| @@ -215,12 +221,29 @@ class BasicNN: | |||||
| """ | """ | ||||
| if data_loader is None: | if data_loader is None: | ||||
| if self.transform is not None: | |||||
| X = [self.transform(x) for x in X] | |||||
| data_loader = DataLoader(X, batch_size=self.batch_size) | |||||
| if test_transform is None: | |||||
| print_log( | |||||
| "Transform used in the training phase will be used in prediction.", | |||||
| "current", | |||||
| level=logging.WARNING, | |||||
| ) | |||||
| dataset = PredictionDataset(X, self.transform) | |||||
| else: | |||||
| dataset = PredictionDataset(X, test_transform) | |||||
| data_loader = DataLoader( | |||||
| dataset, | |||||
| batch_size=self.batch_size, | |||||
| num_workers=int(self.num_workers), | |||||
| collate_fn=self.collate_fn, | |||||
| ) | |||||
| return self._predict(data_loader).argmax(axis=1).cpu().numpy() | return self._predict(data_loader).argmax(axis=1).cpu().numpy() | ||||
| def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: | |||||
| def predict_proba( | |||||
| self, | |||||
| data_loader: DataLoader = None, | |||||
| X: List[Any] = None, | |||||
| test_transform: Callable[..., Any] = None, | |||||
| ) -> numpy.ndarray: | |||||
| """ | """ | ||||
| Predict the probability of each class for the input data. | Predict the probability of each class for the input data. | ||||
| @@ -238,9 +261,21 @@ class BasicNN: | |||||
| """ | """ | ||||
| if data_loader is None: | if data_loader is None: | ||||
| if self.transform is not None: | |||||
| X = [self.transform(x) for x in X] | |||||
| data_loader = DataLoader(X, batch_size=self.batch_size) | |||||
| if test_transform is None: | |||||
| print_log( | |||||
| "Transform used in the training phase will be used in prediction.", | |||||
| "current", | |||||
| level=logging.WARNING, | |||||
| ) | |||||
| dataset = PredictionDataset(X, self.transform) | |||||
| else: | |||||
| dataset = PredictionDataset(X, test_transform) | |||||
| data_loader = DataLoader( | |||||
| dataset, | |||||
| batch_size=self.batch_size, | |||||
| num_workers=int(self.num_workers), | |||||
| collate_fn=self.collate_fn, | |||||
| ) | |||||
| return self._predict(data_loader).softmax(axis=1).cpu().numpy() | return self._predict(data_loader).softmax(axis=1).cpu().numpy() | ||||
| def _score(self, data_loader) -> Tuple[float, float]: | def _score(self, data_loader) -> Tuple[float, float]: | ||||
| @@ -1,8 +1,7 @@ | |||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Any, Hashable, List | from typing import Any, Hashable, List | ||||
| from abl.structures import ListData | |||||
| from ..structures import ListData | |||||
| from .base_kb import BaseKB | from .base_kb import BaseKB | ||||