Browse Source

[MNT] add type and length checking to dataset

pull/3/head
Gao Enhao 2 years ago
parent
commit
ea1f0d004a
3 changed files with 15 additions and 0 deletions
  1. +5
    -0
      abl/dataset/bridge_dataset.py
  2. +5
    -0
      abl/dataset/classification_dataset.py
  3. +5
    -0
      abl/dataset/regression_dataset.py

+ 5
- 0
abl/dataset/bridge_dataset.py View File

@@ -15,6 +15,11 @@ class BridgeDataset(Dataset):
Y : List[Any]
A list of objects representing the label.
"""
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")
if len(X) != len(Y):
raise ValueError("Length of X and Y must be equal.")
self.X = X
self.Z = Z
self.Y = Y


+ 5
- 0
abl/dataset/classification_dataset.py View File

@@ -19,6 +19,11 @@ class ClassificationDataset(Dataset):
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
"""
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")
if len(X) != len(Y):
raise ValueError("Length of X and Y must be equal.")
self.X = X
self.Y = torch.LongTensor(Y)
self.transform = transform


+ 5
- 0
abl/dataset/regression_dataset.py View File

@@ -14,6 +14,11 @@ class RegressionDataset(Dataset):
Y : List[Any]
A list of objects representing the output data.
"""
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")
if len(X) != len(Y):
raise ValueError("Length of X and Y must be equal.")
self.X = X
self.Y = Y



Loading…
Cancel
Save