From a61ee9cd54d55f2ee3b74e27f7cd21655234f8c9 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 17 Nov 2022 21:42:14 +0800 Subject: [PATCH] update lenet5.py --- models/lenet5.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/models/lenet5.py b/models/lenet5.py index 9d5b054..56c1ca6 100644 --- a/models/lenet5.py +++ b/models/lenet5.py @@ -20,21 +20,24 @@ from torch import nn from torch.nn import functional as F from torch.autograd import Variable import torchvision.transforms as transforms +import numpy as np from models.basic_model import BasicModel - import utils.plog as plog class LeNet5(nn.Module): - def __init__(self): + def __init__(self, num_classes=10, image_size=(28, 28)): super().__init__() self.conv1 = nn.Conv2d(1, 6, 3, padding=1) self.conv2 = nn.Conv2d(6, 16, 3) self.conv3 = nn.Conv2d(16, 16, 3) - self.fc1 = nn.Linear(256, 120) + feature_map_size = ((np.array(image_size) // 2 - 2) // 2 - 2) + num_features = 16 * feature_map_size[0] * feature_map_size[1] + + self.fc1 = nn.Linear(num_features, 120) self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) + self.fc3 = nn.Linear(84, num_classes) def forward(self, x): '''前向传播函数'''