|
|
|
@@ -20,21 +20,24 @@ from torch import nn |
|
|
|
from torch.nn import functional as F |
|
|
|
from torch.autograd import Variable |
|
|
|
import torchvision.transforms as transforms |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from models.basic_model import BasicModel |
|
|
|
|
|
|
|
import utils.plog as plog |
|
|
|
|
|
|
|
class LeNet5(nn.Module): |
|
|
|
def __init__(self): |
|
|
|
def __init__(self, num_classes=10, image_size=(28, 28)): |
|
|
|
super().__init__() |
|
|
|
self.conv1 = nn.Conv2d(1, 6, 3, padding=1) |
|
|
|
self.conv2 = nn.Conv2d(6, 16, 3) |
|
|
|
self.conv3 = nn.Conv2d(16, 16, 3) |
|
|
|
|
|
|
|
self.fc1 = nn.Linear(256, 120) |
|
|
|
feature_map_size = ((np.array(image_size) // 2 - 2) // 2 - 2) |
|
|
|
num_features = 16 * feature_map_size[0] * feature_map_size[1] |
|
|
|
|
|
|
|
self.fc1 = nn.Linear(num_features, 120) |
|
|
|
self.fc2 = nn.Linear(120, 84) |
|
|
|
self.fc3 = nn.Linear(84, 10) |
|
|
|
self.fc3 = nn.Linear(84, num_classes) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
'''前向传播函数''' |
|
|
|
|