
c2net_listdata.py 4.9 kB

  1. """
  2. ######################## train lenet example ########################
  3. train lenet and get network model files(.ckpt)
  4. """
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. import os
  8. import argparse
  9. import moxing as mox
  10. from config import mnist_cfg as cfg
  11. from dataset import create_dataset
  12. from dataset_distributed import create_dataset_parallel
  13. from lenet import LeNet5
  14. import json
  15. import mindspore.nn as nn
  16. from mindspore import context
  17. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  18. from mindspore.train import Model
  19. from mindspore.nn.metrics import Accuracy
  20. from mindspore import load_checkpoint, load_param_into_net
  21. from mindspore.context import ParallelMode
  22. from mindspore.communication.management import init, get_rank
  23. import time
### Copy multiple datasets from OBS to the training image ###
def MultiObsToEnv(multi_data_url, data_dir):
    # --multi_data_url is a JSON string, so parse it first
    multi_data_json = json.loads(multi_data_url)
    for i in range(len(multi_data_json)):
        path = data_dir + "/" + multi_data_json[i]["dataset_name"]
        file_path = data_dir + "/" + os.path.splitext(multi_data_json[i]["dataset_name"])[0]
        if not os.path.exists(file_path):
            os.makedirs(file_path)
        try:
            mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
            print("Successfully Download {} to {}".format(multi_data_json[i]["dataset_url"], path))
            # unzip the downloaded archive into its own directory
            os.system("unzip -d %s %s" % (file_path, path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                multi_data_json[i]["dataset_url"], path) + str(e))
    # Write a cache file marking that the data has been downloaded from OBS.
    # If this file exists during multi-card training, the other cards do not copy the dataset again.
    f = open("/cache/download_input.txt", 'w')
    f.close()
    try:
        if os.path.exists("/cache/download_input.txt"):
            print("download_input succeed")
    except Exception as e:
        print("download_input failed")
    return
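# For reference, the platform passes --multi_data_url as a JSON list; judging from the keys read
# above, each entry carries at least "dataset_name" and "dataset_url". The values below are
# illustrative only (hypothetical bucket path and file name), not taken from a real job:
#   [
#     {"dataset_name": "MNISTData.zip",
#      "dataset_url": "s3://open-data/attachment/.../MNISTData.zip"}
#   ]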
def DownloadFromQizhi(multi_data_url, data_dir):
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        # single-card job: this process downloads the data itself
        MultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # set device_id and init for multi-card training
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, parameter_broadcast=True)
        init()
        # Copying OBS data does not need to be executed on every card; only the 0th card of each
        # 8-card node (RANK_ID % 8 == 0) copies the data
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            MultiObsToEnv(multi_data_url, data_dir)
        # If the cache file does not exist, the copy has not finished yet,
        # so wait for the 0th card to finish copying the data
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return
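# For context: RANK_SIZE, RANK_ID and ASCEND_DEVICE_ID are expected to be set by the platform's
# launcher, not by this script. On a hypothetical single-node 8-card Ascend job that would mean
# RANK_SIZE=8, RANK_ID in 0..7 and ASCEND_DEVICE_ID naming the local NPU, so only the RANK_ID=0
# process downloads while the other processes poll /cache/download_input.txt.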
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
### --multi_data_url, --ckpt_url and --device_target must be defined first when using multiple
### datasets, otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the Qizhi platform,
### because they are predefined in the background; you only need to define them in your code.
parser.add_argument('--multi_data_url',
                    help='dataset path in obs')
parser.add_argument('--ckpt_url',
                    help='pre_train_model path in obs')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: Ascend); to use the CPU on the Qizhi platform, set device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
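# An invocation might look like the (hypothetical) line below; on the Qizhi platform the scheduler
# supplies --multi_data_url and --ckpt_url automatically, so they are not passed by hand:
#   python c2net_listdata.py --device_target=Ascend --epoch_size=5 \
#       --multi_data_url='[{"dataset_name": "MNISTData.zip", "dataset_url": "s3://..."}]'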
if __name__ == "__main__":
    args, unknown = parser.parse_known_args()
    data_dir = '/cache/dataset'
    train_dir = '/cache/output'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
    ### Initialize and copy data to the training image
    DownloadFromQizhi(args.multi_data_url, data_dir)
    print("--------start ls:")
    os.system("cd /cache/dataset; ls -al")
    print("--------end ls-----------")