diff --git a/abl/models/lenet5.py b/abl/models/lenet5.py deleted file mode 100644 index 1016d2c..0000000 --- a/abl/models/lenet5.py +++ /dev/null @@ -1,83 +0,0 @@ -# coding: utf-8 -#================================================================# -# Copyright (C) 2021 Freecss All rights reserved. -# -# File Name :lenet5.py -# Author :freecss -# Email :karlfreecss@gmail.com -# Created Date :2021/03/03 -# Description : -# -#================================================================# - -import sys -sys.path.append("..") - -import torchvision - -import torch -from torch import nn -from torch.nn import functional as F -from torch.autograd import Variable -import torchvision.transforms as transforms -import numpy as np - -from models.basic_model import BasicModel -import utils.plog as plog - -class LeNet5(nn.Module): - def __init__(self, num_classes=10, image_size=(28, 28)): - super().__init__() - self.conv1 = nn.Conv2d(1, 6, 3, padding=1) - self.conv2 = nn.Conv2d(6, 16, 3) - self.conv3 = nn.Conv2d(16, 16, 3) - - feature_map_size = ((np.array(image_size) // 2 - 2) // 2 - 2) - num_features = 16 * feature_map_size[0] * feature_map_size[1] - - self.fc1 = nn.Linear(num_features, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, num_classes) - - def forward(self, x): - '''前向传播函数''' - x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) - x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2)) - x = F.relu(self.conv3(x)) - x = x.view(-1, self.num_flat_features(x)) - #print(x.size()) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - def num_flat_features(self, x): - size = x.size()[1:] - num_features = 1 - for s in size: - num_features *= s - return num_features - - -class SymbolNet(nn.Module): - def __init__(self, num_classes=14): - super(SymbolNet, self).__init__() - self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) - self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) - self.dropout1 = nn.Dropout2d(0.25) - self.dropout2 = nn.Dropout2d(0.5) - self.fc1 = nn.Linear(30976, 128) - self.fc2 = nn.Linear(128, num_classes) - - def forward(self, x): - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.max_pool2d(x, 2) - x = self.dropout1(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = F.relu(x) - x = self.dropout2(x) - x = self.fc2(x) - return x