inference.py
  1. """
  2. ######################## single-dataset inference lenet example ########################
  3. This example is a single-dataset inference tutorial.
  4. ######################## Instructions for using the inference environment ########################
  5. The image of the debugging environment and the image of the inference environment are two different images,
  6. and the working local directories are different. In the inference task, you need to pay attention to the following points.
  7. 1、(1)The structure of the dataset uploaded for single dataset inference in this example
  8. MNISTData.zip
  9. ├── test
  10. │ ├── t10k-images-idx3-ubyte
  11. │ └── t10k-labels-idx1-ubyte
  12. └── train
  13. ├── train-images-idx3-ubyte
  14. └── train-labels-idx1-ubyte
  15. (2)The dataset structure of the single dataset in the inference image in this example
  16. workroot
  17. ├── data
  18. | ├── test
  19. | └── train
  20. 2、Inference task requires predefined functions
  21. (1)Defines whether the task is a inference environment or a debugging environment.
  22. def WorkEnvironment(environment):
  23. if environment == 'train':
  24. workroot = '/home/work/user-job-dir' #The inference task uses this parameter to represent the local path of the inference image
  25. elif environment == 'debug':
  26. workroot = '/home/ma-user/work' #The debug task uses this parameter to represent the local path of the debug image
  27. print('current work mode:' + environment + ', workroot:' + workroot)
  28. return workroot
  29. (2)Copy single dataset from obs to inference image.
  30. def ObsToEnv(obs_data_url, data_dir):
  31. try:
  32. mox.file.copy_parallel(obs_data_url, data_dir)
  33. print("Successfully Download {} to {}".format(obs_data_url, data_dir))
  34. except Exception as e:
  35. print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
  36. return
  37. (3)Copy ckpt file from obs to inference image.
  38. def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
  39. try:
  40. mox.file.copy(obs_ckpt_url, ckpt_url)
  41. print("Successfully Download {} to {}".format(obs_ckpt_url,
  42. ckpt_url))
  43. except Exception as e:
  44. print('moxing download {} to {} failed: '.format(
  45. obs_ckpt_url, ckpt_url) + str(e))
  46. return
  47. (4)Copy the output result to obs.
  48. def EnvToObs(train_dir, obs_train_url):
  49. try:
  50. mox.file.copy_parallel(train_dir, obs_train_url)
  51. print("Successfully Upload {} to {}".format(train_dir,obs_train_url))
  52. except Exception as e:
  53. print('moxing upload {} to {} failed: '.format(train_dir,obs_train_url) + str(e))
  54. return
  55. 3、4 parameters need to be defined.
  56. --data_url is the dataset you selected on the Qizhi platform
  57. --ckpt_url is the weight file you choose on the Qizhi platform
  58. --data_url,--ckpt_url,--result_url,--device_target,These 4 parameters must be defined first in a single dataset,
  59. otherwise an error will be reported.
  60. There is no need to add these parameters to the running parameters of the Qizhi platform,
  61. because they are predefined in the background, you only need to define them in your code.
  62. 4、How the dataset is used
  63. Inference task uses data_url as the input, and data_dir (ie: workroot + '/data') as the calling method
  64. of the dataset in the image.
  65. For details, please refer to the following sample code.
  66. """
import os
import argparse
import moxing as mox
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore import Tensor
import numpy as np
from glob import glob
from dataset import create_dataset
from config import mnist_cfg as cfg
from lenet import LeNet5
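
### Note: `moxing` (mox) is the platform's OBS transfer library and is only available
### inside the platform images; this script is not meant to run on an ordinary local machine.
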
### Decide whether the task runs in the inference environment or the debugging environment ###
def WorkEnvironment(environment):
    if environment == 'train':
        workroot = '/home/work/user-job-dir'
    elif environment == 'debug':
        workroot = '/home/work'
    print('current work mode:' + environment + ', workroot:' + workroot)
    return workroot

### Copy the single dataset from OBS to the inference image ###
def ObsToEnv(obs_data_url, data_dir):
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
    return

### Copy the ckpt file from OBS to the inference image ###
### Use mox.file.copy_parallel for directories; a checkpoint is a single file,
### so mox.file.copy is used here.
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
        print("Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url) + str(e))
    return

### Copy the output result back to OBS ###
def EnvToObs(train_dir, obs_train_url):
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
    return

### --data_url, --ckpt_url, --result_url and --device_target must be defined first in an
### inference task, otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the Qizhi platform,
### because they are predefined in the background; you only need to define them in your code.
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--data_url',
                    type=str,
                    default=WorkEnvironment('train') + '/data/',
                    help='path where the dataset is saved')
parser.add_argument('--ckpt_url',
                    help='model to save/load',
                    default=WorkEnvironment('train') + '/checkpoint.ckpt')
parser.add_argument('--result_url',
                    help='result folder to save/load',
                    default=WorkEnvironment('train') + '/result/')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                    help='device where the code will be implemented (default: Ascend)')
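
### Example (hypothetical, not part of the original tutorial): when running in the
### debugging image instead of an inference task, the same parser can be pointed at
### the debug paths, e.g.:
###     debug_root = WorkEnvironment('debug')
###     parser.set_defaults(data_url=debug_root + '/data/',
###                         ckpt_url=debug_root + '/checkpoint.ckpt',
###                         result_url=debug_root + '/result/')
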
if __name__ == "__main__":
    args, unknown = parser.parse_known_args()

    ### Define the work environment (an inference task runs in 'train' mode)
    environment = 'train'
    workroot = WorkEnvironment(environment)

    ### Initialize the data and result directories in the inference image ###
    data_dir = workroot + '/data'
    result_dir = workroot + '/result'
    ckpt_url = workroot + '/checkpoint.ckpt'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    ### Copy the dataset from OBS to the inference image
    obs_data_url = args.data_url
    ObsToEnv(obs_data_url, data_dir)

    ### Copy the ckpt file from OBS to the inference image
    obs_ckpt_url = args.ckpt_url
    ObsUrlToEnv(obs_ckpt_url, ckpt_url)

    ### Set the output path result_url
    obs_result_url = args.result_url

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    ### Build the network. The loss and optimizer are only needed to construct the
    ### Model wrapper; no training is performed in this script.
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    repeat_size = cfg.epoch_size
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Testing ==============")
    param_dict = load_checkpoint(ckpt_url)
    load_param_into_net(network, param_dict)

    ### Run prediction on a single sample from the test split
    ds_test = create_dataset(os.path.join(data_dir, "test"), batch_size=1).create_dict_iterator()
    data = next(ds_test)
    images = data["image"].asnumpy()
    labels = data["label"].asnumpy()
    output = model.predict(Tensor(data['image']))
    predicted = np.argmax(output.asnumpy(), axis=1)
    print(f'Predicted: "{predicted[0]}", Actual: "{labels[0]}"')

    ### Write the prediction to the result directory
    filename = 'result.txt'
    file_path = os.path.join(result_dir, filename)
    with open(file_path, 'a+') as file:
        file.write(" {}: {} \n".format("Predicted", predicted[0]))
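
    ### Optional extension (a minimal sketch, not part of the original tutorial):
    ### evaluate accuracy over the whole test split with Model.eval, reusing the
    ### Accuracy metric configured above.
    ds_eval = create_dataset(os.path.join(data_dir, "test"), batch_size=1)
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy: {} ==============".format(acc))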

    ### Copy the result data from the local running environment back to OBS,
    ### so it can be downloaded from the corresponding inference task on the Qizhi platform.
    EnvToObs(result_dir, obs_result_url)