| @@ -20,13 +20,12 @@ import os | |||
| from multiprocessing import Pool | |||
| class XYDataset(Dataset): | |||
| def __init__(self, X, Y, transform=None, target_transform=None): | |||
| def __init__(self, X, Y, transform=None): | |||
| self.X = X | |||
| self.Y = Y | |||
| self.Y = torch.LongTensor(Y) | |||
| self.n_sample = len(X) | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| def __len__(self): | |||
| return len(self.X) | |||
| @@ -39,8 +38,6 @@ class XYDataset(Dataset): | |||
| img = self.transform(img) | |||
| label = self.Y[index] | |||
| if self.target_transform is not None: | |||
| label = self.target_transform(label) | |||
| return (img, label) | |||
| @@ -64,7 +61,6 @@ class BasicModel(): | |||
| save_interval = None, | |||
| save_dir = None, | |||
| transform = None, | |||
| target_transform = None, | |||
| collate_fn = None, | |||
| recorder = None): | |||
| @@ -78,7 +74,6 @@ class BasicModel(): | |||
| self.criterion = criterion | |||
| self.optimizer = optimizer | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| self.device = device | |||
| if recorder is None: | |||
| @@ -114,9 +109,8 @@ class BasicModel(): | |||
| if data_loader is None: | |||
| collate_fn = self.collate_fn | |||
| transform = self.transform | |||
| target_transform = self.target_transform | |||
| train_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform) | |||
| train_dataset = XYDataset(X, y, transform=transform) | |||
| sampler = None | |||
| data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, \ | |||
| shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \ | |||
| @@ -166,10 +160,9 @@ class BasicModel(): | |||
| if data_loader is None: | |||
| collate_fn = self.collate_fn | |||
| transform = self.transform | |||
| target_transform = self.target_transform | |||
| Y = [0] * len(X) | |||
| val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform) | |||
| val_dataset = XYDataset(X, Y, transform=transform) | |||
| sampler = None | |||
| data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ | |||
| shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ | |||
| @@ -184,10 +177,9 @@ class BasicModel(): | |||
| if data_loader is None: | |||
| collate_fn = self.collate_fn | |||
| transform = self.transform | |||
| target_transform = self.target_transform | |||
| Y = [0] * len(X) | |||
| val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform) | |||
| val_dataset = XYDataset(X, Y, transform=transform) | |||
| sampler = None | |||
| data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ | |||
| shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ | |||
| @@ -231,9 +223,8 @@ class BasicModel(): | |||
| if data_loader is None: | |||
| collate_fn = self.collate_fn | |||
| transform = self.transform | |||
| target_transform = self.target_transform | |||
| val_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform) | |||
| val_dataset = XYDataset(X, y, transform=transform) | |||
| sampler = None | |||
| data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ | |||
| shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \ | |||