| @@ -1 +1,3 @@ | |||
| from .base_dataset import BaseDataset | |||
| from .bridge_dataset import BridgeDataset | |||
| from .classification_dataset import ClassificationDataset | |||
| from .regression_dataset import RegressionDataset | |||
| @@ -0,0 +1,55 @@ | |||
| from torch.utils.data import Dataset | |||
| from typing import List, Any, Tuple | |||
| class BridgeDataset(Dataset): | |||
| def __init__(self, X: List[Any], Z: List[Any], Y: List[Any]): | |||
| """Initialize a basic dataset. | |||
| Parameters | |||
| ---------- | |||
| X : List[Any] | |||
| A list of objects representing the input data. | |||
| Z : List[Any] | |||
| A list of objects representing the symbol. | |||
| Y : List[Any] | |||
| A list of objects representing the label. | |||
| """ | |||
| self.X = X | |||
| self.Z = Z | |||
| self.Y = Y | |||
| if self.Z is None: | |||
| self.Z = [None] * len(self.X) | |||
| def __len__(self): | |||
| """Return the length of the dataset. | |||
| Returns | |||
| ------- | |||
| int | |||
| The length of the dataset. | |||
| """ | |||
| return len(self.X) | |||
| def __getitem__(self, index: int) -> Tuple[Any, Any]: | |||
| """Get an item from the dataset. | |||
| Parameters | |||
| ---------- | |||
| index : int | |||
| The index of the item to retrieve. | |||
| Returns | |||
| ------- | |||
| Tuple[Any, Any] | |||
| A tuple containing the input and output data at the specified index. | |||
| """ | |||
| if index >= len(self): | |||
| raise ValueError("index range error") | |||
| X = self.X[index] | |||
| Z = self.Z[index] | |||
| Y = self.Y[index] | |||
| return (X, Z, Y) | |||
| @@ -0,0 +1,60 @@ | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
| from typing import List, Any, Tuple, Callable | |||
| class ClassificationDataset(Dataset): | |||
| def __init__( | |||
| self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None | |||
| ): | |||
| """ | |||
| Initialize the dataset used for classification task. | |||
| Parameters | |||
| ---------- | |||
| X : List[Any] | |||
| The input data. | |||
| Y : List[int] | |||
| The target data. | |||
| transform : Callable[..., Any], optional | |||
| A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||
| """ | |||
| self.X = X | |||
| self.Y = torch.LongTensor(Y) | |||
| self.transform = transform | |||
| def __len__(self) -> int: | |||
| """ | |||
| Return the length of the dataset. | |||
| Returns | |||
| ------- | |||
| int | |||
| The length of the dataset. | |||
| """ | |||
| return len(self.X) | |||
| def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]: | |||
| """ | |||
| Get the item at the given index. | |||
| Parameters | |||
| ---------- | |||
| index : int | |||
| The index of the item to get. | |||
| Returns | |||
| ------- | |||
| Tuple[Any, torch.Tensor] | |||
| A tuple containing the object and its label. | |||
| """ | |||
| if index >= len(self): | |||
| raise ValueError("index range error") | |||
| x = self.X[index] | |||
| if self.transform is not None: | |||
| x = self.transform(x) | |||
| y = self.Y[index] | |||
| return x, y | |||
| @@ -0,0 +1,49 @@ | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
| from typing import List, Any, Tuple | |||
| class RegressionDataset(Dataset): | |||
| def __init__(self, X: List[Any], Y: List[Any]): | |||
| """Initialize a basic dataset. | |||
| Parameters | |||
| ---------- | |||
| X : List[Any] | |||
| A list of objects representing the input data. | |||
| Y : List[Any] | |||
| A list of objects representing the output data. | |||
| """ | |||
| self.X = X | |||
| self.Y = Y | |||
| def __len__(self): | |||
| """Return the length of the dataset. | |||
| Returns | |||
| ------- | |||
| int | |||
| The length of the dataset. | |||
| """ | |||
| return len(self.X) | |||
| def __getitem__(self, index: int) -> Tuple[Any, Any]: | |||
| """Get an item from the dataset. | |||
| Parameters | |||
| ---------- | |||
| index : int | |||
| The index of the item to retrieve. | |||
| Returns | |||
| ------- | |||
| Tuple[Any, Any] | |||
| A tuple containing the input and output data at the specified index. | |||
| """ | |||
| if index >= len(self): | |||
| raise ValueError("index range error") | |||
| x = self.X[index] | |||
| y = self.Y[index] | |||
| return x, y | |||