diff --git a/models/lenet5.py b/models/lenet5.py index 9d5b054..56c1ca6 100644 --- a/models/lenet5.py +++ b/models/lenet5.py @@ -20,21 +20,24 @@ from torch import nn from torch.nn import functional as F from torch.autograd import Variable import torchvision.transforms as transforms +import numpy as np from models.basic_model import BasicModel - import utils.plog as plog class LeNet5(nn.Module): - def __init__(self): + def __init__(self, num_classes=10, image_size=(28, 28)): super().__init__() self.conv1 = nn.Conv2d(1, 6, 3, padding=1) self.conv2 = nn.Conv2d(6, 16, 3) self.conv3 = nn.Conv2d(16, 16, 3) - self.fc1 = nn.Linear(256, 120) + feature_map_size = ((np.array(image_size) // 2 - 2) // 2 - 2) + num_features = 16 * feature_map_size[0] * feature_map_size[1] + + self.fc1 = nn.Linear(num_features, 120) self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) + self.fc3 = nn.Linear(84, num_classes) def forward(self, x): '''前向传播函数'''