Browse Source

[MNT] remove abl/models/lenet5.py

pull/3/head
Gao Enhao 3 years ago
parent
commit
85c39242a4
1 changed files with 0 additions and 83 deletions
  1. +0
    -83
      abl/models/lenet5.py

+ 0
- 83
abl/models/lenet5.py View File

@@ -1,83 +0,0 @@
# coding: utf-8
#================================================================#
# Copyright (C) 2021 Freecss All rights reserved.
#
# File Name :lenet5.py
# Author :freecss
# Email :karlfreecss@gmail.com
# Created Date :2021/03/03
# Description :
#
#================================================================#

import sys
sys.path.append("..")

import torchvision

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
import numpy as np

from models.basic_model import BasicModel
import utils.plog as plog

class LeNet5(nn.Module):
def __init__(self, num_classes=10, image_size=(28, 28)):
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
self.conv2 = nn.Conv2d(6, 16, 3)
self.conv3 = nn.Conv2d(16, 16, 3)

feature_map_size = ((np.array(image_size) // 2 - 2) // 2 - 2)
num_features = 16 * feature_map_size[0] * feature_map_size[1]

self.fc1 = nn.Linear(num_features, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, num_classes)

def forward(self, x):
'''前向传播函数'''
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
x = F.relu(self.conv3(x))
x = x.view(-1, self.num_flat_features(x))
#print(x.size())
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

def num_flat_features(self, x):
size = x.size()[1:]
num_features = 1
for s in size:
num_features *= s
return num_features


class SymbolNet(nn.Module):
def __init__(self, num_classes=14):
super(SymbolNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1)
self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(30976, 128)
self.fc2 = nn.Linear(128, num_classes)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
return x

Loading…
Cancel
Save