You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

inference.py 7.0 kB

3 years ago
Lines 1–156
"""
######################## inference lenet example ########################
Run LeNet inference using a trained model (checkpoint) file.
"""
"""
######################## 推理环境使用说明 ########################
1、在推理环境中，需要将数据集从OBS拷贝到推理镜像中；推理完成后，需要将输出的结果拷贝回OBS。
(1) 将数据集从OBS拷贝到推理镜像中：
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/data/'
    if not os.path.exists(args.data_url):
        os.mkdir(args.data_url)
    try:
        mox.file.copy_parallel(obs_data_url, args.data_url)
        print("Successfully Download {} to {}".format(obs_data_url,
                                                       args.data_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(
            obs_data_url, args.data_url) + str(e))
(2) 将模型文件从OBS拷贝到推理镜像中：
    obs_ckpt_url = args.ckpt_url
    args.ckpt_url = '/home/work/user-job-dir/checkpoint.ckpt'
    try:
        mox.file.copy(obs_ckpt_url, args.ckpt_url)
        print("Successfully Download {} to {}".format(obs_ckpt_url,
                                                       args.ckpt_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(
            obs_ckpt_url, args.ckpt_url) + str(e))
(3) 将输出的结果拷贝回OBS：
    obs_result_url = args.result_url
    args.result_url = '/home/work/user-job-dir/result/'
    if not os.path.exists(args.result_url):
        os.mkdir(args.result_url)
    try:
        mox.file.copy_parallel(args.result_url, obs_result_url)
        print("Successfully Upload {} to {}".format(args.result_url, obs_result_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(args.result_url, obs_result_url) + str(e))
详细代码可参考以下示例代码：
"""
  42. import os
  43. import argparse
  44. import moxing as mox
  45. import mindspore.nn as nn
  46. from mindspore import context
  47. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  48. from mindspore.train import Model
  49. from mindspore.nn.metrics import Accuracy
  50. from mindspore import Tensor
  51. import numpy as np
  52. from glob import glob
  53. from dataset import create_dataset
  54. from config import mnist_cfg as cfg
  55. from lenet import LeNet5
  56. if __name__ == "__main__":
  57. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  58. parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
  59. help='device where the code will be implemented (default: Ascend)')
  60. parser.add_argument('--data_url',
  61. type=str,
  62. default="./Data",
  63. help='path where the dataset is saved')
  64. parser.add_argument('--ckpt_url',
  65. help='model to save/load',
  66. default='./ckpt_url')
  67. parser.add_argument('--result_url',
  68. help='result folder to save/load',
  69. default='./result')
  70. args = parser.parse_args()
  71. #将数据集从obs拷贝到推理镜像中:
  72. obs_data_url = args.data_url
  73. args.data_url = '/home/work/user-job-dir/data/'
  74. if not os.path.exists(args.data_url):
  75. os.mkdir(args.data_url)
  76. try:
  77. mox.file.copy_parallel(obs_data_url, args.data_url)
  78. print("Successfully Download {} to {}".format(obs_data_url,
  79. args.data_url))
  80. except Exception as e:
  81. print('moxing download {} to {} failed: '.format(
  82. obs_data_url, args.data_url) + str(e))
  83. #对文件夹进行操作,请使用mox.file.copy_parallel。如果拷贝一个文件。请使用mox.file.copy对文件操作,本次操作是对文件进行操作
  84. #将模型文件从obs拷贝到推理镜像中:
  85. obs_ckpt_url = args.ckpt_url
  86. args.ckpt_url = '/home/work/user-job-dir/checkpoint.ckpt'
  87. try:
  88. mox.file.copy(obs_ckpt_url, args.ckpt_url)
  89. print("Successfully Download {} to {}".format(obs_ckpt_url,
  90. args.ckpt_url))
  91. except Exception as e:
  92. print('moxing download {} to {} failed: '.format(
  93. obs_ckpt_url, args.ckpt_url) + str(e))
  94. #设置输出路径result_url
  95. obs_result_url = args.result_url
  96. args.result_url = '/home/work/user-job-dir/result/'
  97. if not os.path.exists(args.result_url):
  98. os.mkdir(args.result_url)
  99. args.dataset_path = args.data_url
  100. args.save_checkpoint_path = args.ckpt_url
  101. context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
  102. network = LeNet5(cfg.num_classes)
  103. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  104. repeat_size = cfg.epoch_size
  105. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  106. model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
  107. print("============== Starting Testing ==============")
  108. args.load_ckpt_url = os.path.join(args.save_checkpoint_path)
  109. print("args.load_ckpt_url is:{}", args.load_ckpt_url )
  110. param_dict = load_checkpoint(args.load_ckpt_url )
  111. load_param_into_net(network, param_dict)
  112. # 定义测试数据集,batch_size设置为1,则取出一张图片
  113. ds_test = create_dataset(os.path.join(args.dataset_path, "test"), batch_size=1).create_dict_iterator()
  114. data = next(ds_test)
  115. # images为测试图片,labels为测试图片的实际分类
  116. images = data["image"].asnumpy()
  117. labels = data["label"].asnumpy()
  118. print('Tensor:', Tensor(data['image']))
  119. # 使用函数model.predict预测image对应分类
  120. output = model.predict(Tensor(data['image']))
  121. predicted = np.argmax(output.asnumpy(), axis=1)
  122. pred = np.argmax(output.asnumpy(), axis=1)
  123. print('predicted:', predicted)
  124. print('pred:', pred)
  125. # 输出预测分类与实际分类,并输出到result_url
  126. print(f'Predicted: "{predicted[0]}", Actual: "{labels[0]}"')
  127. filename = 'result.txt'
  128. file_path = os.path.join(args.result_url, filename)
  129. with open(file_path, 'a+') as file:
  130. file.write(" {}: {:.2f} \n".format("Predicted", predicted[0]))
  131. # Upload results to obs
  132. ######################## 将输出的结果拷贝到obs(固定写法) ########################
  133. # 把推理后的结果从本地的运行环境拷贝回obs,在启智平台相对应的推理任务中会提供下载
  134. try:
  135. mox.file.copy_parallel(args.result_url, obs_result_url)
  136. print("Successfully Upload {} to {}".format(args.result_url, obs_result_url))
  137. except Exception as e:
  138. print('moxing upload {} to {} failed: '.format(args.result_url, obs_result_url) + str(e))
  139. ######################## 将输出的模型拷贝到obs ########################