Browse Source

update basic_model.py

pull/3/head
Gao Enhao 3 years ago
parent
commit
06830d5a10
1 changed files with 6 additions and 15 deletions
  1. +6
    -15
      models/basic_model.py

+ 6
- 15
models/basic_model.py View File

@@ -20,13 +20,12 @@ import os
from multiprocessing import Pool

class XYDataset(Dataset):
def __init__(self, X, Y, transform=None, target_transform=None):
def __init__(self, X, Y, transform=None):
self.X = X
self.Y = Y
self.Y = torch.LongTensor(Y)

self.n_sample = len(X)
self.transform = transform
self.target_transform = target_transform

def __len__(self):
return len(self.X)
@@ -39,8 +38,6 @@ class XYDataset(Dataset):
img = self.transform(img)

label = self.Y[index]
if self.target_transform is not None:
label = self.target_transform(label)

return (img, label)

@@ -64,7 +61,6 @@ class BasicModel():
save_interval = None,
save_dir = None,
transform = None,
target_transform = None,
collate_fn = None,
recorder = None):

@@ -78,7 +74,6 @@ class BasicModel():
self.criterion = criterion
self.optimizer = optimizer
self.transform = transform
self.target_transform = target_transform
self.device = device

if recorder is None:
@@ -114,9 +109,8 @@ class BasicModel():
if data_loader is None:
collate_fn = self.collate_fn
transform = self.transform
target_transform = self.target_transform

train_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform)
train_dataset = XYDataset(X, y, transform=transform)
sampler = None
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, \
shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \
@@ -166,10 +160,9 @@ class BasicModel():
if data_loader is None:
collate_fn = self.collate_fn
transform = self.transform
target_transform = self.target_transform

Y = [0] * len(X)
val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
val_dataset = XYDataset(X, Y, transform=transform)
sampler = None
data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \
shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \
@@ -184,10 +177,9 @@ class BasicModel():
if data_loader is None:
collate_fn = self.collate_fn
transform = self.transform
target_transform = self.target_transform

Y = [0] * len(X)
val_dataset = XYDataset(X, Y, transform=transform, target_transform=target_transform)
val_dataset = XYDataset(X, Y, transform=transform)
sampler = None
data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \
shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \
@@ -231,9 +223,8 @@ class BasicModel():
if data_loader is None:
collate_fn = self.collate_fn
transform = self.transform
target_transform = self.target_transform

val_dataset = XYDataset(X, y, transform=transform, target_transform=target_transform)
val_dataset = XYDataset(X, y, transform=transform)
sampler = None
data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, \
shuffle=True, sampler=sampler, num_workers=int(self.num_workers), \


Loading…
Cancel
Save