diff --git a/gpu_mnist_example/inference.py b/gpu_mnist_example/inference.py index 6c78486..add58c2 100644 --- a/gpu_mnist_example/inference.py +++ b/gpu_mnist_example/inference.py @@ -53,6 +53,7 @@ def test(model, test_loader, data_length): test_loss /= (i+1) # 结果写入输出文件夹 + print('accuracy: {:.2f}'.format(correct / data_length)) filename = 'result.txt' file_path = os.path.join('/tmp/output', filename) with open(file_path, 'w') as file: @@ -74,9 +75,9 @@ if __name__ == '__main__': device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") batch_size = args.batch_size epochs = args.epoch_size - test_dataset = mnist.MNIST(root=mnist_example_test2_model_djts_path + "/test", train=False, transform=ToTensor(),download=False) + test_dataset = mnist.MNIST(root=MnistDataset_torch + "/test", train=False, transform=ToTensor(),download=False) test_loader = DataLoader(test_dataset, batch_size=batch_size) model = Model().to(device) - checkpoint = torch.load(mnist_example_test2_model_djts_path + "/mnist_epoch1_0.73.pkl") + checkpoint = torch.load(mnist_example_test2_model_djts_path + "/mnist_epoch1.pkl") model.load_state_dict(checkpoint['model']) test(model,test_loader,len(test_dataset)) \ No newline at end of file