Browse Source

update model_name

pull/12/head
liuzx 1 year ago
parent
commit
fb8709cb74
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      gpgpu_mnist_example/inference.py

+ 1
- 1
gpgpu_mnist_example/inference.py View File

@@ -83,7 +83,7 @@ if __name__ == '__main__':
test_dataset = mnist.MNIST(root=MnistDataset_torch_path + "/test", train=False, transform=ToTensor(),download=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
model = Model().to(device)
checkpoint = torch.load(Torch_MNIST_Example_Model_path + "/mnist_epoch1_0.73.pkl")
checkpoint = torch.load(Torch_MNIST_Example_Model_path + "/mnist_epoch1.pkl")
model.load_state_dict(checkpoint['model'])
test(model,test_loader,len(test_dataset),output_path)
upload_output()

Loading…
Cancel
Save