
train_for_multidataset.py
  1. """
  2. ######################## multi-dataset train lenet example ########################
  3. This example is a multi-dataset training tutorial. If it is a single dataset, please refer to the single dataset
  4. training tutorial train.py. This example cannot be used for a single dataset!
  5. """
  6. """
  7. ######################## Instructions for using the training environment ########################
  8. 1、(1)The structure of the dataset uploaded for multi-dataset training in this example
  9. MNISTData.zip
  10. ├── test
  11. └── train
  12. checkpoint_lenet-1_1875.zip
  13. ├── checkpoint_lenet-1_1875.ckpt
  14. (2)The dataset structure in the training image for multiple datasets in this example
  15. workroot
  16. ├── MNISTData
  17. | ├── test
  18. | └── train
  19. └── checkpoint_lenet-1_1875
  20. ├── checkpoint_lenet-1_1875.ckpt
  21. 2、Multi-dataset training requires predefined functions
  22. (1)Copy multi-dataset from obs to training image
  23. function MultiObsToEnv(multi_data_url, data_dir)
  24. (2)Copy the output to obs
  25. function EnvToObs(train_dir, obs_train_url)
  26. (2)Download the input from Qizhi And Init
  27. function DownloadFromQizhi(multi_data_url, data_dir)
  28. (2)Upload the output to Qizhi
  29. function UploadToQizhi(train_dir, obs_train_url)
  30. 3、4 parameters need to be defined
  31. --data_url is the first dataset you selected on the Qizhi platform
  32. --multi_data_url is the multi-dataset you selected on the Qizhi platform
  33. --data_url,--multi_data_url,--train_url,--device_target,These 4 parameters must be defined first in a multi-dataset task,
  34. otherwise an error will be reported.
  35. There is no need to add these parameters to the running parameters of the Qizhi platform,
  36. because they are predefined in the background, you only need to define them in your code
  37. 4、How the dataset is used
  38. Multi-datasets use multi_data_url as input, data_dir + dataset name + file or folder name in the dataset as the
  39. calling path of the dataset in the training image.
  40. For example, the calling path of the train folder in the MNIST_Data dataset in this example is
  41. data_dir + "/MNIST_Data" +"/train"
  42. For details, please refer to the following sample code.
  43. """
import os
import argparse
import json
import time

import moxing as mox
import mindspore.nn as nn
from mindspore import context
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank

from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
### Copy multiple datasets from OBS to the training image ###
def MultiObsToEnv(multi_data_url, data_dir):
    # --multi_data_url is JSON data, so it has to be parsed first
    multi_data_json = json.loads(multi_data_url)
    for i in range(len(multi_data_json)):
        path = data_dir + "/" + multi_data_json[i]["dataset_name"]
        if not os.path.exists(path):
            os.makedirs(path)
        try:
            mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
            print("Successfully Download {} to {}".format(multi_data_json[i]["dataset_url"], path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                multi_data_json[i]["dataset_url"], path) + str(e))
    # Create a cache file to mark that the data has been copied from OBS.
    # If this file exists during multi-card training, the dataset does not need to be copied again.
    f = open("/cache/download_input.txt", 'w')
    f.close()
    try:
        if os.path.exists("/cache/download_input.txt"):
            print("download_input succeed")
    except Exception as e:
        print("download_input failed")
    return
### Copy the output model to OBS ###
def EnvToObs(train_dir, obs_train_url):
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir,
                                                    obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir,
                                                       obs_train_url) + str(e))
    return
def DownloadFromQizhi(multi_data_url, data_dir):
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        MultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # Set device_id and call init() for multi-card training
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, parameter_broadcast=True)
        init()
        # Copying the OBS data does not need to be executed on every card; let card 0 on each node copy it
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            MultiObsToEnv(multi_data_url, data_dir)
        # If the cache file does not exist, the data copy has not finished yet,
        # so wait for card 0 to finish copying the data
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return
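
# The environment variables read above (RANK_SIZE, RANK_ID, ASCEND_DEVICE_ID) are assumed to be provided by the
# platform's distributed launcher before this script starts; they are not set anywhere in this file. As an
# illustrative example, on a single 8-card Ascend node one would expect RANK_SIZE=8 while RANK_ID and
# ASCEND_DEVICE_ID range from 0 to 7 across the worker processes (assumed values, not guaranteed by this script).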
def UploadToQizhi(train_dir, obs_train_url):
    device_num = int(os.getenv('RANK_SIZE'))
    local_rank = int(os.getenv('RANK_ID'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    if device_num > 1:
        if local_rank % 8 == 0:
            EnvToObs(train_dir, obs_train_url)
    return
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
### --data_url, --multi_data_url, --train_url and --device_target must be defined first in a multi-dataset task,
### otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the Qizhi platform,
### because they are predefined in the background; you only need to define them in your code.
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default='/cache/data1/')
parser.add_argument('--multi_data_url',
                    help='path to multi dataset',
                    default='/cache/data/')
parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default='/cache/output/')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: Ascend); to use the CPU on the Qizhi platform, set device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
if __name__ == "__main__":
    args = parser.parse_args()
    data_dir = '/cache/data'
    train_dir = '/cache/output'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
    ### Initialize and copy the data to the training image
    DownloadFromQizhi(args.multi_data_url, data_dir)
    ### The dataset path used here is data_dir + "/MNISTData" + "/train"
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
    if device_num > 1:
        ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
    if ds_train.get_dataset_size() == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    ### The checkpoint path used here is data_dir + "/checkpoint_lenet-1_1875" + "/checkpoint_lenet-1_1875.ckpt"
    load_param_into_net(network, load_checkpoint(os.path.join(data_dir + "/checkpoint_lenet-1_1875",
                                                              "checkpoint_lenet-1_1875.ckpt")))
    if args.device_target != "Ascend":
        model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()})
    else:
        model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()}, amp_level="O2")
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    # Note that this saves a model file on every card, so each card needs its own save path;
    # in this example, get_rank() is appended to the path to distinguish the cards.
    if device_num == 1:
        outputDirectory = train_dir + "/"
    if device_num > 1:
        outputDirectory = train_dir + "/" + str(get_rank()) + "/"
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory=outputDirectory,
                                 config=config_ck)
    print("============== Starting Training ==============")
    epoch_size = cfg['epoch_size']
    if (args.epoch_size):
        epoch_size = args.epoch_size
    print('epoch_size is: ', epoch_size)
    model.train(epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb,
                           LossMonitor()])
    ### Copy the trained output from the local running environment back to OBS,
    ### so it can be downloaded from the corresponding training task on the Qizhi platform
    UploadToQizhi(train_dir, args.train_url)
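
### Illustrative invocation only (an assumption, for local debugging): on the Qizhi platform --data_url,
### --multi_data_url and --train_url are filled in by the platform in the background, so a manual launch
### would typically only override the remaining arguments, e.g.
###     python train_for_multidataset.py --device_target=Ascend --epoch_size=5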