You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_inference.py 2.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. GPU INFERENCE INSTANCE
  5. If the code contains Chinese comments, add the following at the beginning:
  6. #!/usr/bin/python
  7. #coding=utf-8
  8. For compatibility with the A100, use the platform's recommended image with
  9. CUDA 11, then adjust the code and submit the image.
  10. The image of this example is: dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191
  11. In this environment, the uploaded dataset is automatically placed in the /dataset directory;
  12. if MnistDataset_torch.zip is selected, the dataset directory is /dataset/test.
  13. The selected model file is placed in the /model directory.
  14. Results are written under /result, and the Qizhi platform provides downloads of the files in that directory.
  15. 本例中的镜像是dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191
  16. 选择的数据集被放置在/dataset目录
  17. 选择的模型文件放置在/model目录
  18. 输出结果路径是/result目录
  19. !!!注意:目前推理的资源环境不支持联网,所以镜像无法使用公网镜像,镜像必须先提交到启智平台;推理的数据集也需要先上传到启智平台
  20. '''
  21. import numpy as np
  22. import torch
  23. from torchvision.datasets import mnist
  24. from torch.utils.data import DataLoader
  25. from torchvision.transforms import ToTensor
  26. import os
  27. import argparse
  28. from model import Model
  # Command-line options for this inference script.
parser = argparse.ArgumentParser(
    description='PyTorch MNIST Example',
)
# Name of the model (checkpoint) file to load.
parser.add_argument(
    '--modelname',
    help='model name',
)
  # Evaluation helper and script entry point for MNIST inference.
def _evaluate(model, loader, device):
    """Count correct top-1 predictions of ``model`` over ``loader``.

    Args:
        model: a ``torch.nn.Module`` whose output holds per-class scores
            along the last dimension (the prediction is its argmax there).
        loader: an iterable of ``(inputs, labels)`` batches, e.g. a
            ``DataLoader``.
        device: the ``torch.device`` the model lives on; each batch is moved
            there before the forward pass.

    Returns:
        A ``(correct, total)`` tuple of Python ints. ``total`` is 0 when the
        loader yields no batches.
    """
    correct = 0
    total = 0
    # Inference only: no_grad() avoids building an autograd graph per batch.
    with torch.no_grad():
        for inputs, labels in loader:
            scores = model(inputs.to(device).float())
            predicted = scores.argmax(dim=-1)
            correct += (predicted == labels.to(device)).sum().item()
            total += labels.size(0)
    return correct, total


if __name__ == '__main__':
    # parse_known_args tolerates extra arguments the platform may pass.
    args, _ = parser.parse_known_args()
    if not args.modelname:
        # Without this check, '/model/' + None would fail with a TypeError.
        parser.error('--modelname is required (name of the model file under /model)')
    print('cuda is available:{}'.format(torch.cuda.is_available()))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # download=False: the inference environment has no network access, so the
    # MNIST test split must already be present under /dataset/test.
    test_dataset = mnist.MNIST(root='/dataset/test', train=False, transform=ToTensor(),
                               download=False)
    test_loader = DataLoader(test_dataset, batch_size=256)

    # If the file name is fixed, model_path can be hard-coded instead.
    model_path = os.path.join('/model', args.modelname)
    model = Model().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    # The checkpoint is expected to be a dict holding the state_dict under 'model'.
    model.load_state_dict(checkpoint['model'])
    model.eval()

    correct, total = _evaluate(model, test_loader, device)
    if total == 0:
        raise SystemExit('no test samples found under /dataset/test')
    accuracy = correct / total
    print('accuracy: {:.2f}'.format(accuracy))

    # Write the result to /result so the platform offers it for download.
    os.makedirs('/result', exist_ok=True)
    file_path = os.path.join('/result', 'result.txt')
    with open(file_path, 'w') as file:
        file.write('accuracy: {:.2f}'.format(accuracy))