|
|
|
@@ -307,8 +307,9 @@ class BasicNN: |
|
|
|
data_loader = DataLoader( |
|
|
|
dataset, |
|
|
|
batch_size=self.batch_size, |
|
|
|
num_workers=int(self.num_workers), |
|
|
|
num_workers=self.num_workers, |
|
|
|
collate_fn=self.collate_fn, |
|
|
|
pin_memory=torch.cuda.is_available() |
|
|
|
) |
|
|
|
return self._predict(data_loader).argmax(axis=1).cpu().numpy() |
|
|
|
|
|
|
|
@@ -348,8 +349,9 @@ class BasicNN: |
|
|
|
data_loader = DataLoader( |
|
|
|
dataset, |
|
|
|
batch_size=self.batch_size, |
|
|
|
num_workers=int(self.num_workers), |
|
|
|
num_workers=self.num_workers, |
|
|
|
collate_fn=self.collate_fn, |
|
|
|
pin_memory=torch.cuda.is_available() |
|
|
|
) |
|
|
|
return self._predict(data_loader).softmax(axis=1).cpu().numpy() |
|
|
|
|
|
|
|
@@ -381,11 +383,9 @@ class BasicNN: |
|
|
|
model.eval() |
|
|
|
|
|
|
|
total_correct_num, total_num, total_loss = 0, 0, 0.0 |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
for data, target in data_loader: |
|
|
|
data, target = data.to(device), target.to(device) |
|
|
|
|
|
|
|
out = model(data) |
|
|
|
|
|
|
|
if len(out.shape) > 1: |
|
|
|
@@ -482,8 +482,9 @@ class BasicNN: |
|
|
|
dataset, |
|
|
|
batch_size=self.batch_size, |
|
|
|
shuffle=shuffle, |
|
|
|
num_workers=int(self.num_workers), |
|
|
|
num_workers=self.num_workers, |
|
|
|
collate_fn=self.collate_fn, |
|
|
|
pin_memory=torch.cuda.is_available() |
|
|
|
) |
|
|
|
return data_loader |
|
|
|
|
|
|
|
|