diff --git a/models/basic_model.py b/models/basic_model.py index 86fbea5..669474f 100644 --- a/models/basic_model.py +++ b/models/basic_model.py @@ -20,13 +20,12 @@ import os from multiprocessing import Pool class XYDataset(Dataset): - def __init__(self, X, Y, transform=None, target_transform=None): + def __init__(self, X, Y, transform=None): self.X = X - self.Y = Y + self.Y = torch.LongTensor(Y) self.n_sample = len(X) self.transform = transform - self.target_transform = target_transform def __len__(self): return len(self.X) @@ -39,8 +38,6 @@ class XYDataset(Dataset): img = self.transform(img) label = self.Y[index] - if self.target_transform is not None: - label = self.target_transform(label) return (img, label) @@ -64,7 +61,6 @@ class BasicModel(): save_interval = None, save_dir = None, transform = None, - target_transform = None, collate_fn = None, recorder = None): @@ -78,7 +74,6 @@ class BasicModel(): self.criterion = criterion self.optimizer = optimizer self.transform = transform - self.target_transform = target_transform self.device = device if recorder is None: @@ -114,9 +109,8 @@ class BasicModel(): if data_loader is None: collate_fn = self.collate_fn transform = self.transform - target_transform = self.target_transform - train_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform) + train_dataset = XYDataset(X, y, transform=transform) sampler = None data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, \ shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \ @@ -166,10 +160,9 @@ class BasicModel(): if data_loader is None: collate_fn = self.collate_fn transform = self.transform - target_transform = self.target_transform Y = [0] * len(X) - val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform) + val_dataset = XYDataset(X, Y, transform=transform) sampler = None data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ @@ -184,10 +177,9 @@ class BasicModel(): if data_loader is None: collate_fn = self.collate_fn transform = self.transform - target_transform = self.target_transform Y = [0] * len(X) - val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform) + val_dataset = XYDataset(X, Y, transform=transform) sampler = None data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ @@ -231,9 +223,8 @@ class BasicModel(): if data_loader is None: collate_fn = self.collate_fn transform = self.transform - target_transform = self.target_transform - val_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform) + val_dataset = XYDataset(X, y, transform=transform) sampler = None data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \