Browse Source

update

liuzx-patch-1
liuzx 2 years ago
parent
commit
72b1c5a62d
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      gpu_mnist_example/inference.py

+ 3
- 2
gpu_mnist_example/inference.py View File

@@ -53,6 +53,7 @@ def test(model, test_loader, data_length):
test_loss /= (i+1)

# 结果写入输出文件夹
print('accuracy: {:.2f}'.format(correct / data_length))
filename = 'result.txt'
file_path = os.path.join('/tmp/output', filename)
with open(file_path, 'w') as file:
@@ -74,9 +75,9 @@ if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = args.batch_size
epochs = args.epoch_size
test_dataset = mnist.MNIST(root=mnist_example_test2_model_djts_path + "/test", train=False, transform=ToTensor(),download=False)
test_dataset = mnist.MNIST(root=MnistDataset_torch + "/test", train=False, transform=ToTensor(),download=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
model = Model().to(device)
checkpoint = torch.load(mnist_example_test2_model_djts_path + "/mnist_epoch1_0.73.pkl")
checkpoint = torch.load(mnist_example_test2_model_djts_path + "/mnist_epoch1.pkl")
model.load_state_dict(checkpoint['model'])
test(model,test_loader,len(test_dataset))

Loading…
Cancel
Save