|
|
|
@@ -19,6 +19,11 @@ class ClassificationDataset(Dataset): |
|
|
|
transform : Callable[..., Any], optional |
|
|
|
A function/transform that takes in an object and returns a transformed version. Defaults to None. |
|
|
|
""" |
|
|
|
if (not isinstance(X, list)) or (not isinstance(Y, list)): |
|
|
|
raise ValueError("X and Y should be of type list.") |
|
|
|
if len(X) != len(Y): |
|
|
|
raise ValueError("Length of X and Y must be equal.") |
|
|
|
|
|
|
|
self.X = X |
|
|
|
self.Y = torch.LongTensor(Y) |
|
|
|
self.transform = transform |
|
|
|
|