Browse Source

update basic_model.py

pull/3/head
Gao Enhao 3 years ago
parent
commit
06830d5a10
1 changed files with 6 additions and 15 deletions
  1. +6
    -15
      models/basic_model.py

+ 6
- 15
models/basic_model.py View File

@@ -20,13 +20,12 @@ import os
from multiprocessing import Pool from multiprocessing import Pool


class XYDataset(Dataset): class XYDataset(Dataset):
def __init__(self, X, Y, transform=None, target_transform=None):
def __init__(self, X, Y, transform=None):
self.X = X self.X = X
self.Y = Y
self.Y = torch.LongTensor(Y)


self.n_sample = len(X) self.n_sample = len(X)
self.transform = transform self.transform = transform
self.target_transform = target_transform


def __len__(self): def __len__(self):
return len(self.X) return len(self.X)
@@ -39,8 +38,6 @@ class XYDataset(Dataset):
img = self.transform(img) img = self.transform(img)


label = self.Y[index] label = self.Y[index]
if self.target_transform is not None:
label = self.target_transform(label)


return (img, label) return (img, label)


@@ -64,7 +61,6 @@ class BasicModel():
save_interval = None, save_interval = None,
save_dir = None, save_dir = None,
transform = None, transform = None,
target_transform = None,
collate_fn = None, collate_fn = None,
recorder = None): recorder = None):


@@ -78,7 +74,6 @@ class BasicModel():
self.criterion = criterion self.criterion = criterion
self.optimizer = optimizer self.optimizer = optimizer
self.transform = transform self.transform = transform
self.target_transform = target_transform
self.device = device self.device = device


if recorder is None: if recorder is None:
@@ -114,9 +109,8 @@ class BasicModel():
if data_loader is None: if data_loader is None:
collate_fn = self.collate_fn collate_fn = self.collate_fn
transform = self.transform transform = self.transform
target_transform = self.target_transform


train_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform)
train_dataset = XYDataset(X, y, transform=transform)
sampler = None sampler = None
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, \ data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, \
shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \ shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \
@@ -166,10 +160,9 @@ class BasicModel():
if data_loader is None: if data_loader is None:
collate_fn = self.collate_fn collate_fn = self.collate_fn
transform = self.transform transform = self.transform
target_transform = self.target_transform


Y = [0] * len(X) Y = [0] * len(X)
val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
val_dataset = XYDataset(X, Y, transform=transform)
sampler = None sampler = None
data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \
shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \
@@ -184,10 +177,9 @@ class BasicModel():
if data_loader is None: if data_loader is None:
collate_fn = self.collate_fn collate_fn = self.collate_fn
transform = self.transform transform = self.transform
target_transform = self.target_transform


Y = [0] * len(X) Y = [0] * len(X)
val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
val_dataset = XYDataset(X, Y, transform=transform)
sampler = None sampler = None
data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \
shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \
@@ -231,9 +223,8 @@ class BasicModel():
if data_loader is None: if data_loader is None:
collate_fn = self.collate_fn collate_fn = self.collate_fn
transform = self.transform transform = self.transform
target_transform = self.target_transform


val_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform)
val_dataset = XYDataset(X, y, transform=transform)
sampler = None sampler = None
data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \ data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \
shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \ shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \


Loading…
Cancel
Save