diff --git a/abl/dataset/bridge_dataset.py b/abl/dataset/bridge_dataset.py index 36bbab7..bb0ce98 100644 --- a/abl/dataset/bridge_dataset.py +++ b/abl/dataset/bridge_dataset.py @@ -15,6 +15,11 @@ class BridgeDataset(Dataset): Y : List[Any] A list of objects representing the label. """ + if (not isinstance(X, list)) or (not isinstance(Y, list)): + raise ValueError("X and Y should be of type list.") + if len(X) != len(Y): + raise ValueError("Length of X and Y must be equal.") + self.X = X self.Z = Z self.Y = Y diff --git a/abl/dataset/classification_dataset.py b/abl/dataset/classification_dataset.py index 62c725d..28f9299 100644 --- a/abl/dataset/classification_dataset.py +++ b/abl/dataset/classification_dataset.py @@ -19,6 +19,11 @@ class ClassificationDataset(Dataset): transform : Callable[..., Any], optional A function/transform that takes in an object and returns a transformed version. Defaults to None. """ + if (not isinstance(X, list)) or (not isinstance(Y, list)): + raise ValueError("X and Y should be of type list.") + if len(X) != len(Y): + raise ValueError("Length of X and Y must be equal.") + self.X = X self.Y = torch.LongTensor(Y) self.transform = transform diff --git a/abl/dataset/regression_dataset.py b/abl/dataset/regression_dataset.py index a0fc769..8cf136c 100644 --- a/abl/dataset/regression_dataset.py +++ b/abl/dataset/regression_dataset.py @@ -14,6 +14,11 @@ class RegressionDataset(Dataset): Y : List[Any] A list of objects representing the output data. """ + if (not isinstance(X, list)) or (not isinstance(Y, list)): + raise ValueError("X and Y should be of type list.") + if len(X) != len(Y): + raise ValueError("Length of X and Y must be equal.") + self.X = X self.Y = Y