|
|
|
@@ -83,7 +83,7 @@ if __name__ == '__main__': |
|
|
|
test_dataset = mnist.MNIST(root=MnistDataset_torch_path + "/test", train=False, transform=ToTensor(),download=False) |
|
|
|
test_loader = DataLoader(test_dataset, batch_size=batch_size) |
|
|
|
model = Model().to(device) |
|
|
|
checkpoint = torch.load(Torch_MNIST_Example_Model_path + "/mnist_epoch1_0.73.pkl") |
|
|
|
checkpoint = torch.load(Torch_MNIST_Example_Model_path + "/mnist_epoch1.pkl") |
|
|
|
model.load_state_dict(checkpoint['model']) |
|
|
|
test(model,test_loader,len(test_dataset),output_path) |
|
|
|
upload_output() |