diff --git a/gpgpu_mnist_example/inference.py b/gpgpu_mnist_example/inference.py index 2790d38..8aef446 100644 --- a/gpgpu_mnist_example/inference.py +++ b/gpgpu_mnist_example/inference.py @@ -83,7 +83,7 @@ if __name__ == '__main__': test_dataset = mnist.MNIST(root=MnistDataset_torch_path + "/test", train=False, transform=ToTensor(),download=False) test_loader = DataLoader(test_dataset, batch_size=batch_size) model = Model().to(device) - checkpoint = torch.load(Torch_MNIST_Example_Model_path + "/mnist_epoch1_0.73.pkl") + checkpoint = torch.load(Torch_MNIST_Example_Model_path + "/mnist_epoch1.pkl") model.load_state_dict(checkpoint['model']) test(model,test_loader,len(test_dataset),output_path) upload_output() \ No newline at end of file