diff --git a/abl/dataset/__init__.py b/abl/dataset/__init__.py index 267408a..6be0df1 100644 --- a/abl/dataset/__init__.py +++ b/abl/dataset/__init__.py @@ -1 +1,3 @@ -from .base_dataset import BaseDataset \ No newline at end of file +from .bridge_dataset import BridgeDataset +from .classification_dataset import ClassificationDataset +from .regression_dataset import RegressionDataset \ No newline at end of file diff --git a/abl/dataset/bridge_dataset.py b/abl/dataset/bridge_dataset.py new file mode 100644 index 0000000..36bbab7 --- /dev/null +++ b/abl/dataset/bridge_dataset.py @@ -0,0 +1,55 @@ +from torch.utils.data import Dataset +from typing import List, Any, Tuple + + +class BridgeDataset(Dataset): + def __init__(self, X: List[Any], Z: List[Any], Y: List[Any]): + """Initialize a basic dataset. + + Parameters + ---------- + X : List[Any] + A list of objects representing the input data. + Z : List[Any] + A list of objects representing the symbol. + Y : List[Any] + A list of objects representing the label. + """ + self.X = X + self.Z = Z + self.Y = Y + + if self.Z is None: + self.Z = [None] * len(self.X) + + def __len__(self): + """Return the length of the dataset. + + Returns + ------- + int + The length of the dataset. + """ + return len(self.X) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """Get an item from the dataset. + + Parameters + ---------- + index : int + The index of the item to retrieve. + + Returns + ------- + Tuple[Any, Any] + A tuple containing the input and output data at the specified index. + """ + if index >= len(self): + raise ValueError("index range error") + + X = self.X[index] + Z = self.Z[index] + Y = self.Y[index] + + return (X, Z, Y) \ No newline at end of file diff --git a/abl/dataset/classification_dataset.py b/abl/dataset/classification_dataset.py new file mode 100644 index 0000000..62c725d --- /dev/null +++ b/abl/dataset/classification_dataset.py @@ -0,0 +1,60 @@ +import torch +from torch.utils.data import Dataset +from typing import List, Any, Tuple, Callable + + +class ClassificationDataset(Dataset): + def __init__( + self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None + ): + """ + Initialize the dataset used for classification task. + + Parameters + ---------- + X : List[Any] + The input data. + Y : List[int] + The target data. + transform : Callable[..., Any], optional + A function/transform that takes in an object and returns a transformed version. Defaults to None. + """ + self.X = X + self.Y = torch.LongTensor(Y) + self.transform = transform + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + The length of the dataset. + """ + return len(self.X) + + def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]: + """ + Get the item at the given index. + + Parameters + ---------- + index : int + The index of the item to get. + + Returns + ------- + Tuple[Any, torch.Tensor] + A tuple containing the object and its label. + """ + if index >= len(self): + raise ValueError("index range error") + + x = self.X[index] + if self.transform is not None: + x = self.transform(x) + + y = self.Y[index] + + return x, y diff --git a/abl/dataset/regression_dataset.py b/abl/dataset/regression_dataset.py new file mode 100644 index 0000000..a0fc769 --- /dev/null +++ b/abl/dataset/regression_dataset.py @@ -0,0 +1,49 @@ +import torch +from torch.utils.data import Dataset +from typing import List, Any, Tuple + + +class RegressionDataset(Dataset): + def __init__(self, X: List[Any], Y: List[Any]): + """Initialize a basic dataset. + + Parameters + ---------- + X : List[Any] + A list of objects representing the input data. + Y : List[Any] + A list of objects representing the output data. + """ + self.X = X + self.Y = Y + + def __len__(self): + """Return the length of the dataset. + + Returns + ------- + int + The length of the dataset. + """ + return len(self.X) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """Get an item from the dataset. + + Parameters + ---------- + index : int + The index of the item to retrieve. + + Returns + ------- + Tuple[Any, Any] + A tuple containing the input and output data at the specified index. + """ + if index >= len(self): + raise ValueError("index range error") + + x = self.X[index] + y = self.Y[index] + + return x, y