
train_for_multidataset_dataparallel.py
  1. """
  2. ######################## multi-dataset train lenet example ########################
  3. This example is a multi-dataset training tutorial. If it is a single dataset, please refer to the single dataset
  4. training tutorial train.py. This example cannot be used for a single dataset!
  5. """
  6. """
  7. ######################## Instructions for using the training environment ########################
  8. 1、(1)The structure of the dataset uploaded for multi-dataset training in this example
  9. MNISTData.zip
  10. ├── test
  11. │ ├── t10k-images-idx3-ubyte
  12. │ └── t10k-labels-idx1-ubyte
  13. └── train
  14. ├── train-images-idx3-ubyte
  15. └── train-labels-idx1-ubyte
  16. checkpoint_lenet-1_1875.zip
  17. ├── checkpoint_lenet-1_1875.ckpt
  18. (2)The dataset structure in the training image for multiple datasets in this example
  19. workroot
  20. ├── MNISTData
  21. | ├── test
  22. | └── train
  23. └── checkpoint_lenet-1_1875
  24. ├── checkpoint_lenet-1_1875.ckpt
2. Multi-dataset training requires the following predefined functions:
(1) Define whether the task runs in the training environment or the debugging environment:
def WorkEnvironment(environment):
    if environment == 'train':
        workroot = '/home/work/user-job-dir'  # local path of the training image for training tasks
    elif environment == 'debug':
        workroot = '/home/ma-user/work'  # local path of the debugging image for debug tasks
    print('current work mode:' + environment + ', workroot:' + workroot)
    return workroot
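For example (an illustrative call; the same function appears in the code below):
    workroot = WorkEnvironment('train')  # -> '/home/work/user-job-dir'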
(2) Copy multiple datasets from OBS to the training image:
def MultiObsToEnv(multi_data_url, workroot):
    multi_data_json = json.loads(multi_data_url)  # parse multi_data_url
    for i in range(len(multi_data_json)):
        path = workroot + "/" + multi_data_json[i]["dataset_name"]
        if not os.path.exists(path):
            os.makedirs(path)
        try:
            mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
            print("Successfully Download {} to {}".format(multi_data_json[i]["dataset_url"], path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(multi_data_json[i]["dataset_url"], path) + str(e))
    return
*** The input and output of the MultiObsToEnv function in this example:
Input for multi_data_url:
[
    {
        "dataset_url": "s3://test-opendata/attachment/e/a/eae3a316-42d6-4a43-a484-1fa573eab388eae3a316-42d6-4a43-a484-1fa573eab388/",  # OBS path of the dataset
        "dataset_name": "MNIST_Data"  # the name of the dataset
    },
    {
        "dataset_url": "s3://test-opendata/attachment/2/c/2c59be66-64ec-41ca-b311-f51a486eabf82c59be66-64ec-41ca-b311-f51a486eabf8/",
        "dataset_name": "checkpoint_lenet-1_1875"
    }
]
Purpose of multi_data_url:
The MultiObsToEnv function copies multiple datasets from OBS to the training image
and builds the dataset paths in the training image.
For example, the path of the MNIST_Data dataset in this example is /home/work/user-job-dir/MNISTData,
and the path of the checkpoint_lenet-1_1875 dataset is /home/work/user-job-dir/checkpoint_lenet-1_1875.
(3) Copy the output model to OBS:
def EnvToObs(train_dir, obs_train_url):
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
    return
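For example (an illustrative call; the same upload is done at the end of this script):
    EnvToObs(train_dir, args.train_url)  # copy the local model directory back to OBS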
3. Four parameters need to be defined:
--data_url is the first dataset you selected on the Qizhi platform
--multi_data_url is the multi-dataset you selected on the Qizhi platform
--data_url, --multi_data_url, --train_url and --device_target must be defined first in a multi-dataset task,
otherwise an error will be reported.
There is no need to add these parameters to the running parameters of the Qizhi platform;
they are predefined in the background, so you only need to define them in your code. An illustrative
invocation is sketched below.
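For illustration only (the values here are hypothetical; the real ones are injected by the platform):
    python train_for_multidataset_dataparallel.py \
        --multi_data_url='[{"dataset_url": "s3://...", "dataset_name": "MNISTData"}]' \
        --train_url='s3://.../model/' \
        --device_target=Ascend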
4. How the dataset is used:
A multi-dataset task takes multi_data_url as input and uses workroot + dataset name + file or folder name inside
the dataset as the calling path of the dataset in the training image.
For example, the calling path of the train folder in the MNIST_Data dataset in this example is
workroot + "/MNIST_Data" + "/train"
For details, please refer to the sample code below.
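Equivalently, with os.path.join (note that the code below uses "MNISTData" as the directory name):
    train_path = os.path.join(workroot, "MNISTData", "train")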
  91. """
import os
import argparse
import json

import moxing as mox
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context
from mindspore import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
# Set device_id and initialize communication for distributed training.
# ASCEND_DEVICE_ID is typically set per card by the platform.
device_id = int(os.getenv('ASCEND_DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
init()
### Defines whether the task is a training environment or a debugging environment ###
def WorkEnvironment(environment):
    if environment == 'train':
        workroot = '/home/work/user-job-dir'
    elif environment == 'debug':
        workroot = '/home/ma-user/work'
    print('current work mode:' + environment + ', workroot:' + workroot)
    return workroot
### Copy multiple datasets from OBS to the training image ###
def MultiObsToEnv(multi_data_url, workroot):
    multi_data_json = json.loads(multi_data_url)
    for i in range(len(multi_data_json)):
        path = workroot + "/" + multi_data_json[i]["dataset_name"]
        if not os.path.exists(path):
            os.makedirs(path)
        try:
            mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
            print("Successfully Download {} to {}".format(multi_data_json[i]["dataset_url"], path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(multi_data_json[i]["dataset_url"], path) + str(e))
    return
### Copy the output model to OBS ###
def EnvToObs(train_dir, obs_train_url):
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
    return
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
### --data_url, --multi_data_url, --train_url and --device_target must be defined first in a multi-dataset task,
### otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the Qizhi platform;
### they are predefined in the background, you only need to define them in your code.
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default=WorkEnvironment('train') + '/data/')
parser.add_argument('--multi_data_url',
                    help='path to multi dataset',
                    default=WorkEnvironment('train'))
parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default=WorkEnvironment('train') + '/model/')
parser.add_argument('--device_target',
                    type=str,
                    default="Ascend",
                    choices=['Ascend', 'CPU'],
                    help='device where the code will be implemented (default: Ascend); '
                         'to use the CPU on the Qizhi platform, set device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')

set_seed(114514)
if __name__ == "__main__":
    args = parser.parse_args()
    # After defining the training environment, first execute WorkEnvironment and MultiObsToEnv to
    # copy multiple datasets from OBS to the training image
    environment = 'train'
    workroot = WorkEnvironment(environment)
    MultiObsToEnv(args.multi_data_url, workroot)
    ### Define the output path in the training image
    train_dir = workroot + '/model'
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
    ### Set data-parallel mode (gradients are averaged across cards) and create the dataset ###
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
    ds_train = create_dataset_parallel(os.path.join(workroot + "/MNISTData", "train"),
                                       cfg.batch_size)
    if ds_train.get_dataset_size() == 0:
        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    ### Load the pretrained model: workroot + "/checkpoint_lenet-1_1875" + "/checkpoint_lenet-1_1875.ckpt"
    load_param_into_net(network, load_checkpoint(os.path.join(workroot + "/checkpoint_lenet-1_1875",
                                                              "checkpoint_lenet-1_1875.ckpt")))
    if args.device_target != "Ascend":
        model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()})
    else:
        model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()}, amp_level="O2")
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    # Note that this saves a model file on every card, so each card needs its own save path;
    # here get_rank() is added to the path to distinguish the cards.
    ckpoint_cb = ModelCheckpoint(prefix="data_parallel",
                                 directory=train_dir + "/" + str(get_rank()) + "/",
                                 config=config_ck)
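    # Illustrative resulting layout (assuming two cards; the rank count depends on the job):
    #   <train_dir>/0/data_parallel-*.ckpt  (rank 0)
    #   <train_dir>/1/data_parallel-*.ckpt  (rank 1)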
    print("============== Starting Training ==============")
    epoch_size = cfg['epoch_size']
    if args.epoch_size:
        epoch_size = args.epoch_size
        print('epoch_size is: ', epoch_size)
    model.train(epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()])
    ### Copy the trained model from the local running environment back to OBS,
    ### so it can be downloaded from the corresponding training task on the Qizhi platform
    EnvToObs(train_dir, args.train_url)