Browse Source

update lenet5.py

pull/3/head
Gao Enhao 3 years ago
parent
commit
a61ee9cd54
1 changed files with 7 additions and 4 deletions
  1. +7
    -4
      models/lenet5.py

+ 7
- 4
models/lenet5.py View File

@@ -20,21 +20,24 @@ from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
import numpy as np

from models.basic_model import BasicModel

import utils.plog as plog

class LeNet5(nn.Module):
def __init__(self):
def __init__(self, num_classes=10, image_size=(28, 28)):
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
self.conv2 = nn.Conv2d(6, 16, 3)
self.conv3 = nn.Conv2d(16, 16, 3)

self.fc1 = nn.Linear(256, 120)
feature_map_size = ((np.array(image_size) // 2 - 2) // 2 - 2)
num_features = 16 * feature_map_size[0] * feature_map_size[1]

self.fc1 = nn.Linear(num_features, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
self.fc3 = nn.Linear(84, num_classes)

def forward(self, x):
'''前向传播函数'''


Loading…
Cancel
Save