Browse Source

[ENH] add three kinds of basic dataset to dataset folder

pull/3/head
Gao Enhao 2 years ago
parent
commit
74cc37d6a4
4 changed files with 167 additions and 1 deletions
  1. +3
    -1
      abl/dataset/__init__.py
  2. +55
    -0
      abl/dataset/bridge_dataset.py
  3. +60
    -0
      abl/dataset/classification_dataset.py
  4. +49
    -0
      abl/dataset/regression_dataset.py

+ 3
- 1
abl/dataset/__init__.py View File

@@ -1 +1,3 @@
from .base_dataset import BaseDataset
from .bridge_dataset import BridgeDataset
from .classification_dataset import ClassificationDataset
from .regression_dataset import RegressionDataset

+ 55
- 0
abl/dataset/bridge_dataset.py View File

@@ -0,0 +1,55 @@
from torch.utils.data import Dataset
from typing import List, Any, Tuple


class _IndexRangeError(IndexError, ValueError):
    """Raised when a dataset is indexed past its end.

    It subclasses ``IndexError`` so that Python's ``__getitem__``-based
    iteration (``for item in dataset``) stops cleanly at the end of the data,
    and ``ValueError`` so that callers catching the previously raised
    exception type keep working.
    """


class BridgeDataset(Dataset):
    """Dataset returning ``(X[i], Z[i], Y[i])`` triples for a given index."""

    def __init__(self, X: List[Any], Z: List[Any], Y: List[Any]):
        """Initialize a basic dataset.

        Parameters
        ----------
        X : List[Any]
            A list of objects representing the input data.
        Z : List[Any]
            A list of objects representing the symbol. May be ``None``, in
            which case ``None`` is used as the symbol of every sample.
        Y : List[Any]
            A list of objects representing the label.
        """
        self.X = X
        # A missing symbol list is replaced by one ``None`` per input so that
        # ``__getitem__`` can index X, Z and Y uniformly.
        self.Z = [None] * len(X) if Z is None else Z
        self.Y = Y

    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns
        -------
        int
            The number of entries in ``X``.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """Get an item from the dataset.

        Parameters
        ----------
        index : int
            The index of the item to retrieve.

        Returns
        -------
        Tuple[Any, Any, Any]
            A tuple ``(X[index], Z[index], Y[index])``.

        Raises
        ------
        IndexError
            If ``index`` is greater than or equal to the dataset length. The
            raised exception is also a ``ValueError`` for backward
            compatibility.
        """
        if index >= len(self):
            # Must be an IndexError: the sequence iteration protocol only
            # terminates on IndexError, any other exception propagates.
            raise _IndexRangeError("index range error")

        return (self.X[index], self.Z[index], self.Y[index])

+ 60
- 0
abl/dataset/classification_dataset.py View File

@@ -0,0 +1,60 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any, Tuple, Callable


class _IndexRangeError(IndexError, ValueError):
    """Raised when a dataset is indexed past its end.

    It subclasses ``IndexError`` so that Python's ``__getitem__``-based
    iteration (``for item in dataset``) stops cleanly at the end of the data,
    and ``ValueError`` so that callers catching the previously raised
    exception type keep working.
    """


class ClassificationDataset(Dataset):
    """Dataset pairing inputs with labels stored as a ``torch.LongTensor``."""

    def __init__(
        self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None
    ):
        """
        Initialize the dataset used for classification task.

        Parameters
        ----------
        X : List[Any]
            The input data.
        Y : List[int]
            The target data. It is converted to a ``torch.LongTensor``.
        transform : Callable[..., Any], optional
            A function/transform that takes in an object and returns a
            transformed version. Defaults to None, meaning inputs are
            returned unchanged.
        """
        self.X = X
        self.Y = torch.LongTensor(Y)
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The number of entries in ``X``.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get.

        Returns
        -------
        Tuple[Any, torch.Tensor]
            A tuple containing the (optionally transformed) object and its
            label.

        Raises
        ------
        IndexError
            If ``index`` is greater than or equal to the dataset length. The
            raised exception is also a ``ValueError`` for backward
            compatibility.
        """
        if index >= len(self):
            # Must be an IndexError: the sequence iteration protocol only
            # terminates on IndexError, any other exception propagates.
            raise _IndexRangeError("index range error")

        x = self.X[index]
        if self.transform is not None:
            x = self.transform(x)

        return x, self.Y[index]

+ 49
- 0
abl/dataset/regression_dataset.py View File

@@ -0,0 +1,49 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any, Tuple


class _IndexRangeError(IndexError, ValueError):
    """Raised when a dataset is indexed past its end.

    It subclasses ``IndexError`` so that Python's ``__getitem__``-based
    iteration (``for item in dataset``) stops cleanly at the end of the data,
    and ``ValueError`` so that callers catching the previously raised
    exception type keep working.
    """


class RegressionDataset(Dataset):
    """Dataset returning ``(X[i], Y[i])`` pairs for a given index."""

    def __init__(self, X: List[Any], Y: List[Any]):
        """Initialize a basic dataset.

        Parameters
        ----------
        X : List[Any]
            A list of objects representing the input data.
        Y : List[Any]
            A list of objects representing the output data.
        """
        self.X = X
        self.Y = Y

    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns
        -------
        int
            The number of entries in ``X``.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get an item from the dataset.

        Parameters
        ----------
        index : int
            The index of the item to retrieve.

        Returns
        -------
        Tuple[Any, Any]
            A tuple containing the input and output data at the specified
            index.

        Raises
        ------
        IndexError
            If ``index`` is greater than or equal to the dataset length. The
            raised exception is also a ``ValueError`` for backward
            compatibility.
        """
        if index >= len(self):
            # Must be an IndexError: the sequence iteration protocol only
            # terminates on IndexError, any other exception propagates.
            raise _IndexRangeError("index range error")

        return self.X[index], self.Y[index]

Loading…
Cancel
Save