Browse Source

Create lenet5.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
f78c69ae10
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 22 additions and 38 deletions
  1. +22
    -38
      models/lenet5.py

+ 22
- 38
models/lenet5.py View File

@@ -52,48 +52,32 @@ class LeNet5(nn.Module):
return x

def num_flat_features(self, x):
#x.size()返回值为(256, 16, 5, 5),size的值为(16, 5, 5),256是batch_size
size = x.size()[1:] #x.size返回的是一个元组,size表示截取元组中第二个开始的数字
size = x.size()[1:]
num_features = 1
for s in size:
num_features *= s
return num_features

class Params:
imgH = 28
imgW = 28
keep_ratio = True
saveInterval = 10
batchSize = 16
num_workers = 16

def get_data(): #数据预处理
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5), (0.5))])
#transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#训练集
train_set = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1024, shuffle=True, num_workers = 16)
#测试集
test_set = torchvision.datasets.MNIST(root='data/', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size = 1024, shuffle = False, num_workers = 16)
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

return train_loader, test_loader, classes

if __name__ == "__main__":
recorder = plog.ResultRecorder()
cls = LeNet5()
criterion = nn.CrossEntropyLoss(size_average=True)
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = BasicModel(cls, criterion, optimizer, None, device, Params(), recorder)
train_loader, test_loader, classes = get_data()

#model.val(test_loader, print_prefix = "before training")
model.fit(train_loader, n_epoch = 100)
model.val(test_loader, print_prefix = "after trained")
res = model.predict(test_loader, print_prefix = "predict")
print(res.argmax(axis=1)[:10])
class SymbolNet(nn.Module):
def __init__(self, num_classes=14):
super(SymbolNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1)
self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(30976, 128)
self.fc2 = nn.Linear(128, num_classes)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
return x

Loading…
Cancel
Save