You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

inference.py 2.6 kB

3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. GPU INFERENCE INSTANCE
  5. If there are Chinese comments in the code,please add at the beginning:
  6. #!/usr/bin/python
  7. #coding=utf-8
  8. Due to the adaptability of a100, please use the recommended image of the
  9. platform with cuda 11.Then adjust the code and submit the image.
  10. The image of this example is: dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191
  11. In the environment, the uploaded dataset will be automatically placed in the /dataset directory.
  12. if MnistDataset_torch.zip is selected,Then the dataset directory is /dataset/test;
  13. The model file selected is in /model directory.
  14. The result download path is under /result . and the Qizhi platform will provide file downloads under the /result directory.
  15. 由于a100的适配性,请使用含cuda 11的平台镜像.
  16. 本例中的镜像是dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191
  17. 选择的数据集被放置在/dataset目录
  18. 选择的模型文件放置在/model目录
  19. 输出结果路径是/result目录
  20. '''
  21. import numpy as np
  22. import torch
  23. from torchvision.datasets import mnist
  24. from torch.utils.data import DataLoader
  25. from torchvision.transforms import ToTensor
  26. import os
  27. import argparse
  28. # Training settings
  29. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  30. #获取模型文件名称
  31. parser.add_argument('--modelname', help='model name')
  32. if __name__ == '__main__':
  33. args, unknown = parser.parse_known_args()
  34. print('cuda is available:{}'.format(torch.cuda.is_available()))
  35. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  36. test_dataset = mnist.MNIST(root='/dataset/test', train=False, transform=ToTensor(),
  37. download=False)
  38. test_loader = DataLoader(test_dataset, batch_size=256)
  39. #如果文件名确定,model_path可以直接写死
  40. model_path = '/model/'+args.modelname
  41. model = torch.load(model_path).to(device)
  42. model.eval()
  43. correct = 0
  44. _sum = 0
  45. for idx, (test_x, test_label) in enumerate(test_loader):
  46. test_x = test_x
  47. test_label = test_label
  48. predict_y = model(test_x.to(device).float()).detach()
  49. predict_ys = np.argmax(predict_y.cpu(), axis=-1)
  50. label_np = test_label.numpy()
  51. _ = predict_ys == test_label
  52. correct += np.sum(_.numpy(), axis=-1)
  53. _sum += _.shape[0]
  54. print('accuracy: {:.2f}'.format(correct / _sum))
  55. #结果写入/result
  56. filename = 'result.txt'
  57. file_path = os.path.join('/result', filename)
  58. with open(file_path, 'w') as file:
  59. file.write('accuracy: {:.2f}'.format(correct / _sum))

No Description