| @@ -24,10 +24,8 @@ class SimpleBridge(BaseBridge): | |||
| # TODO: add abducer.mapping to the property of SimpleBridge | |||
| def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | |||
| pred_res = self.model.predict(data_samples) | |||
| data_samples.pred_idx = pred_res["label"] | |||
| data_samples.pred_prob = pred_res["prob"] | |||
| return data_samples["pred_idx"], data_samples["pred_prob"] | |||
| self.model.predict(data_samples) | |||
| return data_samples["pred_idx"], data_samples.get("pred_prob", None) | |||
| def abduce_pseudo_label( | |||
| self, | |||
| @@ -1,3 +1,4 @@ | |||
| from .bridge_dataset import BridgeDataset | |||
| from .classification_dataset import ClassificationDataset | |||
| from .regression_dataset import RegressionDataset | |||
| from .prediction_dataset import PredictionDataset | |||
| from .regression_dataset import RegressionDataset | |||
| @@ -0,0 +1,56 @@ | |||
| from typing import Any, Callable, List, Tuple | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
class _IndexRangeError(IndexError, ValueError):
    """Out-of-range index error for ``PredictionDataset``.

    It derives from ``IndexError`` so that Python's legacy sequence iteration
    protocol (``for x in dataset`` / ``list(dataset)``) terminates normally,
    and from ``ValueError`` so that callers which already catch the
    ``ValueError`` historically raised here keep working.
    """


class PredictionDataset(Dataset):
    """Map-style dataset that serves unlabeled samples for prediction."""

    def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
        """
        Initialize the dataset used for prediction.

        Parameters
        ----------
        X : List[Any]
            The input data.
        transform : Callable[..., Any], optional
            A function/transform that takes in an object and returns a
            transformed version. Defaults to None, in which case samples are
            returned unchanged.

        Raises
        ------
        ValueError
            If ``X`` is not a list.
        """
        if not isinstance(X, list):
            raise ValueError("X should be of type list.")

        self.X = X
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Any:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get. Negative indices count from the
            end, as with a list.

        Returns
        -------
        Any
            The sample at ``index``, passed through ``transform`` when one
            was supplied. No label is returned.

        Raises
        ------
        IndexError
            If ``index`` is out of range. The raised exception is also a
            ``ValueError`` for backward compatibility.
        """
        size = len(self)
        # Check both ends so every out-of-range access raises the same error;
        # it must be an IndexError for plain iteration over the dataset to stop.
        if index >= size or index < -size:
            raise _IndexRangeError("index range error")

        x = self.X[index]
        if self.transform is not None:
            x = self.transform(x)
        return x
| @@ -71,11 +71,14 @@ class ABLModel: | |||
| label = prob.argmax(axis=1) | |||
| prob = reform_idx(prob, data_samples["X"]) | |||
| else: | |||
| prob = [None] * len(data_samples) | |||
| prob = None | |||
| label = model.predict(X=data_X) | |||
| label = reform_idx(label, data_samples["X"]) | |||
| data_samples.pred_idx = label | |||
| if prob is not None: | |||
| data_samples.pred_prob = prob | |||
| return {"label": label, "prob": prob} | |||
| def train(self, data_samples: ListData) -> float: | |||
| @@ -11,13 +11,14 @@ | |||
| # ================================================================# | |||
| import os | |||
| import logging | |||
| from typing import Any, Callable, List, Optional, T, Tuple | |||
| import numpy | |||
| import torch | |||
| from torch.utils.data import DataLoader | |||
| from ..dataset import ClassificationDataset | |||
| from ..dataset import ClassificationDataset, PredictionDataset | |||
| from ..utils.logger import print_log | |||
| @@ -197,7 +198,12 @@ class BasicNN: | |||
| return torch.cat(results, axis=0) | |||
| def predict(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: | |||
| def predict( | |||
| self, | |||
| data_loader: DataLoader = None, | |||
| X: List[Any] = None, | |||
| test_transform: Callable[..., Any] = None, | |||
| ) -> numpy.ndarray: | |||
| """ | |||
| Predict the class of the input data. | |||
| @@ -215,12 +221,29 @@ class BasicNN: | |||
| """ | |||
| if data_loader is None: | |||
| if self.transform is not None: | |||
| X = [self.transform(x) for x in X] | |||
| data_loader = DataLoader(X, batch_size=self.batch_size) | |||
| if test_transform is None: | |||
| print_log( | |||
| "Transform used in the training phase will be used in prediction.", | |||
| "current", | |||
| level=logging.WARNING, | |||
| ) | |||
| dataset = PredictionDataset(X, self.transform) | |||
| else: | |||
| dataset = PredictionDataset(X, test_transform) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=self.batch_size, | |||
| num_workers=int(self.num_workers), | |||
| collate_fn=self.collate_fn, | |||
| ) | |||
| return self._predict(data_loader).argmax(axis=1).cpu().numpy() | |||
| def predict_proba(self, data_loader: DataLoader = None, X: List[Any] = None) -> numpy.ndarray: | |||
| def predict_proba( | |||
| self, | |||
| data_loader: DataLoader = None, | |||
| X: List[Any] = None, | |||
| test_transform: Callable[..., Any] = None, | |||
| ) -> numpy.ndarray: | |||
| """ | |||
| Predict the probability of each class for the input data. | |||
| @@ -238,9 +261,21 @@ class BasicNN: | |||
| """ | |||
| if data_loader is None: | |||
| if self.transform is not None: | |||
| X = [self.transform(x) for x in X] | |||
| data_loader = DataLoader(X, batch_size=self.batch_size) | |||
| if test_transform is None: | |||
| print_log( | |||
| "Transform used in the training phase will be used in prediction.", | |||
| "current", | |||
| level=logging.WARNING, | |||
| ) | |||
| dataset = PredictionDataset(X, self.transform) | |||
| else: | |||
| dataset = PredictionDataset(X, test_transform) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| batch_size=self.batch_size, | |||
| num_workers=int(self.num_workers), | |||
| collate_fn=self.collate_fn, | |||
| ) | |||
| return self._predict(data_loader).softmax(axis=1).cpu().numpy() | |||
| def _score(self, data_loader) -> Tuple[float, float]: | |||
| @@ -1,8 +1,7 @@ | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any, Hashable, List | |||
| from abl.structures import ListData | |||
| from ..structures import ListData | |||
| from .base_kb import BaseKB | |||