| @@ -20,13 +20,12 @@ import os | |||||
| from multiprocessing import Pool | from multiprocessing import Pool | ||||
| class XYDataset(Dataset): | class XYDataset(Dataset): | ||||
| def __init__(self, X, Y, transform=None, target_transform=None): | |||||
| def __init__(self, X, Y, transform=None): | |||||
| self.X = X | self.X = X | ||||
| self.Y = Y | |||||
| self.Y = torch.LongTensor(Y) | |||||
| self.n_sample = len(X) | self.n_sample = len(X) | ||||
| self.transform = transform | self.transform = transform | ||||
| self.target_transform = target_transform | |||||
| def __len__(self): | def __len__(self): | ||||
| return len(self.X) | return len(self.X) | ||||
| @@ -39,8 +38,6 @@ class XYDataset(Dataset): | |||||
| img = self.transform(img) | img = self.transform(img) | ||||
| label = self.Y[index] | label = self.Y[index] | ||||
| if self.target_transform is not None: | |||||
| label = self.target_transform(label) | |||||
| return (img, label) | return (img, label) | ||||
| @@ -64,7 +61,6 @@ class BasicModel(): | |||||
| save_interval = None, | save_interval = None, | ||||
| save_dir = None, | save_dir = None, | ||||
| transform = None, | transform = None, | ||||
| target_transform = None, | |||||
| collate_fn = None, | collate_fn = None, | ||||
| recorder = None): | recorder = None): | ||||
| @@ -78,7 +74,6 @@ class BasicModel(): | |||||
| self.criterion = criterion | self.criterion = criterion | ||||
| self.optimizer = optimizer | self.optimizer = optimizer | ||||
| self.transform = transform | self.transform = transform | ||||
| self.target_transform = target_transform | |||||
| self.device = device | self.device = device | ||||
| if recorder is None: | if recorder is None: | ||||
| @@ -114,9 +109,8 @@ class BasicModel(): | |||||
| if data_loader is None: | if data_loader is None: | ||||
| collate_fn = self.collate_fn | collate_fn = self.collate_fn | ||||
| transform = self.transform | transform = self.transform | ||||
| target_transform = self.target_transform | |||||
| train_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform) | |||||
| train_dataset = XYDataset(X, y, transform=transform) | |||||
| sampler = None | sampler = None | ||||
| data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, \ | data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, \ | ||||
| shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \ | shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \ | ||||
| @@ -166,10 +160,9 @@ class BasicModel(): | |||||
| if data_loader is None: | if data_loader is None: | ||||
| collate_fn = self.collate_fn | collate_fn = self.collate_fn | ||||
| transform = self.transform | transform = self.transform | ||||
| target_transform = self.target_transform | |||||
| Y = [0] * len(X) | Y = [0] * len(X) | ||||
| val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform) | |||||
| val_dataset = XYDataset(X, Y, transform=transform) | |||||
| sampler = None | sampler = None | ||||
| data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ | data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ | ||||
| shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ | shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ | ||||
| @@ -184,10 +177,9 @@ class BasicModel(): | |||||
| if data_loader is None: | if data_loader is None: | ||||
| collate_fn = self.collate_fn | collate_fn = self.collate_fn | ||||
| transform = self.transform | transform = self.transform | ||||
| target_transform = self.target_transform | |||||
| Y = [0] * len(X) | Y = [0] * len(X) | ||||
| val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform) | |||||
| val_dataset = XYDataset(X, Y, transform=transform) | |||||
| sampler = None | sampler = None | ||||
| data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ | data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ | ||||
| shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ | shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ | ||||
| @@ -231,9 +223,8 @@ class BasicModel(): | |||||
| if data_loader is None: | if data_loader is None: | ||||
| collate_fn = self.collate_fn | collate_fn = self.collate_fn | ||||
| transform = self.transform | transform = self.transform | ||||
| target_transform = self.target_transform | |||||
| val_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform) | |||||
| val_dataset = XYDataset(X, y, transform=transform) | |||||
| sampler = None | sampler = None | ||||
| data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ | data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ | ||||
| shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \ | shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \ | ||||