
inference_for_multidataset.py
  1. """
  2. ######################## multi-dataset inference lenet example ########################
  3. This example is a single-dataset inference tutorial.
  4. ######################## Instructions for using the inference environment ########################
  5. 1、Inference task requires predefined functions
  6. (1)Copy multi dataset from obs to inference image.
  7. function MultiObsToEnv(obs_data_url, data_dir)
  8. (2)Copy ckpt file from obs to inference image.
  9. function ObsUrlToEnv(obs_ckpt_url, ckpt_url)
  10. (3)Copy the output result to obs.
  11. function EnvToObs(train_dir, obs_train_url)
  12. 3、5 parameters need to be defined.
  13. --data_url is the first dataset you selected on the Qizhi platform
  14. --multi_data_url is the multi dataset you selected on the Qizhi platform
  15. --ckpt_url is the weight file you choose on the Qizhi platform
  16. --result_url is the output
  17. --data_url,--multi_data_url,--ckpt_url,--result_url,--device_target,These 5 parameters must be defined first in a single dataset,
  18. otherwise an error will be reported.
  19. There is no need to add these parameters to the running parameters of the Qizhi platform,
  20. because they are predefined in the background, you only need to define them in your code.
  21. 4、How the dataset is used
  22. Multi-datasets use multi_data_url as input, data_dir + dataset name + file or folder name in the dataset as the
  23. calling path of the dataset in the inference image.
  24. For example, the calling path of the test folder in the MNIST_Data dataset in this example is
  25. data_dir + "/MNIST_Data" +"/test"
  26. For details, please refer to the following sample code.
  27. """
import os
import json
import argparse

import numpy as np
import moxing as mox
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.metrics import Accuracy

from dataset import create_dataset
from config import mnist_cfg as cfg
from lenet import LeNet5
### Copy multiple datasets from obs to the inference image ###
def MultiObsToEnv(multi_data_url, data_dir):
    # --multi_data_url is JSON data, so it needs to be parsed before use
    multi_data_json = json.loads(multi_data_url)
    for i in range(len(multi_data_json)):
        path = data_dir + "/" + multi_data_json[i]["dataset_name"]
        if not os.path.exists(path):
            os.makedirs(path)
        try:
            mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
            print("Successfully downloaded {} to {}".format(multi_data_json[i]["dataset_url"], path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                multi_data_json[i]["dataset_url"], path) + str(e))
    return
### Copy the ckpt file from obs to the inference image ###
### To copy a folder, use mox.file.copy_parallel; to copy a single file,
### use mox.file.copy, which is what this function does.
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
        print("Successfully downloaded {} to {}".format(obs_ckpt_url, ckpt_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url) + str(e))
    return
### Copy the output result to obs ###
def EnvToObs(train_dir, obs_train_url):
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully uploaded {} to {}".format(train_dir, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
    return
### --data_url, --multi_data_url, --ckpt_url, --result_url and --device_target must be defined first
### in a multi-dataset inference task, otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the Qizhi platform,
### because they are predefined in the background; you only need to define them in your code.
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--data_url',
                    type=str,
                    default='/cache/data1/',
                    help='path where the dataset is saved')
parser.add_argument('--multi_data_url',
                    type=str,
                    default='/cache/data/',
                    help='path where the datasets are saved')
parser.add_argument('--ckpt_url',
                    help='model to save/load',
                    default='/cache/checkpoint.ckpt')
parser.add_argument('--result_url',
                    help='result folder to save/load',
                    default='/cache/result/')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                    help='device where the code will be implemented (default: Ascend)')
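# Illustrative only: when debugging outside the Qizhi platform, the script could be invoked with the
# arguments filled in by hand. The bucket paths below are hypothetical placeholders, not real URLs:
#
#   python inference_for_multidataset.py \
#       --multi_data_url '[{"dataset_name": "MNIST_Data", "dataset_url": "s3://bucket/MNIST_Data/"}]' \
#       --ckpt_url s3://bucket/checkpoint_lenet.ckpt \
#       --result_url s3://bucket/result/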
if __name__ == "__main__":
    args = parser.parse_args()

    ### Initialize the data and result directories in the inference image ###
    data_dir = '/cache/data'
    result_dir = '/cache/result'
    ckpt_url = '/cache/checkpoint.ckpt'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    ### Copy multiple datasets from obs to the inference image
    MultiObsToEnv(args.multi_data_url, data_dir)
    ### Copy the ckpt file from obs to the inference image
    ObsUrlToEnv(args.ckpt_url, ckpt_url)

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Testing ==============")
    param_dict = load_checkpoint(ckpt_url)
    load_param_into_net(network, param_dict)
    ds_test = create_dataset(os.path.join(data_dir + "/MNIST_Data", "test"), batch_size=1).create_dict_iterator()
    data = next(ds_test)
    images = data["image"].asnumpy()
    labels = data["label"].asnumpy()
    output = model.predict(Tensor(data['image']))
    predicted = np.argmax(output.asnumpy(), axis=1)
    print('predicted:', predicted)
    print(f'Predicted: "{predicted[0]}", Actual: "{labels[0]}"')
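    # Optionally, the whole test split could be scored instead of a single sample by reusing
    # MindSpore's Model.eval with the Accuracy metric defined above. A minimal sketch, assuming
    # mnist_cfg defines batch_size (adjust the name if the config differs):
    #
    #     eval_ds = create_dataset(os.path.join(data_dir + "/MNIST_Data", "test"),
    #                              batch_size=cfg.batch_size)
    #     acc = model.eval(eval_ds, dataset_sink_mode=False)
    #     print("Accuracy on the full test set:", acc)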
    filename = 'result.txt'
    file_path = os.path.join(result_dir, filename)
    with open(file_path, 'a+') as file:
        file.write(" {}: {} \n".format("Predicted", predicted[0]))

    ### Copy the result data from the local running environment back to obs,
    ### and download it in the inference task corresponding to the Qizhi platform
    EnvToObs(result_dir, args.result_url)