| @@ -52,48 +52,32 @@ class LeNet5(nn.Module): | |||
| return x | |||
| def num_flat_features(self, x): | |||
| #x.size()返回值为(256, 16, 5, 5),size的值为(16, 5, 5),256是batch_size | |||
| size = x.size()[1:] #x.size返回的是一个元组,size表示截取元组中第二个开始的数字 | |||
| size = x.size()[1:] | |||
| num_features = 1 | |||
| for s in size: | |||
| num_features *= s | |||
| return num_features | |||
| class Params: | |||
| imgH = 28 | |||
| imgW = 28 | |||
| keep_ratio = True | |||
| saveInterval = 10 | |||
| batchSize = 16 | |||
| num_workers = 16 | |||
| def get_data(): #数据预处理 | |||
| transform = transforms.Compose([transforms.ToTensor(), | |||
| transforms.Normalize((0.5), (0.5))]) | |||
| #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | |||
| #训练集 | |||
| train_set = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True) | |||
| train_loader = torch.utils.data.DataLoader(train_set, batch_size=1024, shuffle=True, num_workers = 16) | |||
| #测试集 | |||
| test_set = torchvision.datasets.MNIST(root='data/', train=False, transform=transform, download=True) | |||
| test_loader = torch.utils.data.DataLoader(test_set, batch_size = 1024, shuffle = False, num_workers = 16) | |||
| classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck') | |||
| return train_loader, test_loader, classes | |||
| if __name__ == "__main__": | |||
| recorder = plog.ResultRecorder() | |||
| cls = LeNet5() | |||
| criterion = nn.CrossEntropyLoss(size_average=True) | |||
| optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| model = BasicModel(cls, criterion, optimizer, None, device, Params(), recorder) | |||
| train_loader, test_loader, classes = get_data() | |||
| #model.val(test_loader, print_prefix = "before training") | |||
| model.fit(train_loader, n_epoch = 100) | |||
| model.val(test_loader, print_prefix = "after trained") | |||
| res = model.predict(test_loader, print_prefix = "predict") | |||
| print(res.argmax(axis=1)[:10]) | |||
| class SymbolNet(nn.Module): | |||
| def __init__(self, num_classes=14): | |||
| super(SymbolNet, self).__init__() | |||
| self.conv1 = nn.Conv2d(1, 32, 3, stride = 1, padding = 1) | |||
| self.conv2 = nn.Conv2d(32, 64, 3, stride = 1, padding = 1) | |||
| self.dropout1 = nn.Dropout2d(0.25) | |||
| self.dropout2 = nn.Dropout2d(0.5) | |||
| self.fc1 = nn.Linear(30976, 128) | |||
| self.fc2 = nn.Linear(128, num_classes) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = F.relu(x) | |||
| x = self.conv2(x) | |||
| x = F.max_pool2d(x, 2) | |||
| x = self.dropout1(x) | |||
| x = torch.flatten(x, 1) | |||
| x = self.fc1(x) | |||
| x = F.relu(x) | |||
| x = self.dropout2(x) | |||
| x = self.fc2(x) | |||
| return x | |||