From 15231012669e0fffe8e684e01aa9bf92d2834e7c Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Mon, 21 Nov 2022 15:23:26 +0800 Subject: [PATCH] update basic model --- models/basic_model.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/models/basic_model.py b/models/basic_model.py index 86fbea5..182e1e5 100644 --- a/models/basic_model.py +++ b/models/basic_model.py @@ -131,21 +131,19 @@ class BasicModel(): model.train() - loss_value = 0 - for _, data in enumerate(data_loader): - X = data[0].to(device) - Y = data[1].to(device) - pred_Y = model(X) - - loss = criterion(pred_Y, Y) + total_loss, total_num = 0.0, 0 + for data, target in data_loader: + data, target = data.to(device), target.to(device) + out = model(data) + loss = criterion(out, target) optimizer.zero_grad() loss.backward() optimizer.step() - loss_value += loss.item() + total_loss += loss.item() * data.size(0) - return loss_value + return total_loss / total_num def _predict(self, data_loader): model = self.model @@ -155,9 +153,9 @@ class BasicModel(): with torch.no_grad(): results = [] - for _, data in enumerate(data_loader): - X = data[0].to(device) - pred_Y = model(X) + for data, _ in data_loader: + data = data.to(device) + pred_Y = model(data) results.append(pred_Y) return torch.cat(results, axis=0)