diff --git a/models/basic_model.py b/models/basic_model.py index 86fbea5..182e1e5 100644 --- a/models/basic_model.py +++ b/models/basic_model.py @@ -131,21 +131,19 @@ class BasicModel(): model.train() - loss_value = 0 - for _, data in enumerate(data_loader): - X = data[0].to(device) - Y = data[1].to(device) - pred_Y = model(X) - - loss = criterion(pred_Y, Y) + total_loss, total_num = 0.0, 0 + for data, target in data_loader: + data, target = data.to(device), target.to(device) + out = model(data) + loss = criterion(out, target) optimizer.zero_grad() loss.backward() optimizer.step() - loss_value += loss.item() + total_loss += loss.item() * data.size(0) - return loss_value + return total_loss / total_num def _predict(self, data_loader): model = self.model @@ -155,9 +153,9 @@ class BasicModel(): with torch.no_grad(): results = [] - for _, data in enumerate(data_loader): - X = data[0].to(device) - pred_Y = model(X) + for data, _ in data_loader: + data = data.to(device) + pred_Y = model(data) results.append(pred_Y) return torch.cat(results, axis=0)