train_dataparallel.py
  1. """
  2. ######################## single-dataset train lenet example ########################
  3. This example is a single-dataset training tutorial. If it is a multi-dataset, please refer to the multi-dataset training
  4. tutorial train_for_multidataset.py. This example cannot be used for multi-datasets!
  5. ######################## Instructions for using the training environment ########################
  6. The image of the debugging environment and the image of the training environment are two different images,
  7. and the working local directories are different. In the training task, you need to pay attention to the following points.
  8. 1、(1)The structure of the dataset uploaded for single dataset training in this example
  9. MNISTData.zip
  10. ├── test
  11. │ ├── t10k-images-idx3-ubyte
  12. │ └── t10k-labels-idx1-ubyte
  13. └── train
  14. ├── train-images-idx3-ubyte
  15. └── train-labels-idx1-ubyte
  16. (2)The dataset structure of the single dataset in the training image in this example
  17. workroot
  18. ├── data
  19. | ├── test
  20. | └── train
  21. 2、Single dataset training requires predefined functions
  22. (1)Defines whether the task is a training environment or a debugging environment.
  23. def WorkEnvironment(environment):
  24. if environment == 'train':
  25. workroot = '/home/work/user-job-dir' #The training task uses this parameter to represent the local path of the training image
  26. elif environment == 'debug':
  27. workroot = '/home/ma-user/work' #The debug task uses this parameter to represent the local path of the debug image
  28. print('current work mode:' + environment + ', workroot:' + workroot)
  29. return workroot
  30. (2)Copy single dataset from obs to training image.
  31. def ObsToEnv(obs_data_url, data_dir):
  32. try:
  33. mox.file.copy_parallel(obs_data_url, data_dir)
  34. print("Successfully Download {} to {}".format(obs_data_url, data_dir))
  35. except Exception as e:
  36. print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
  37. return
  38. (3)Copy the output model to obs.
  39. def EnvToObs(train_dir, obs_train_url):
  40. try:
  41. mox.file.copy_parallel(train_dir, obs_train_url)
  42. print("Successfully Upload {} to {}".format(train_dir,obs_train_url))
  43. except Exception as e:
  44. print('moxing upload {} to {} failed: '.format(train_dir,obs_train_url) + str(e))
  45. return
  46. 3、3 parameters need to be defined
  47. --data_url is the dataset you selected on the Qizhi platform
  48. --data_url,--train_url,--device_target,These 3 parameters must be defined first in a single dataset task,
  49. otherwise an error will be reported.
  50. There is no need to add these parameters to the running parameters of the Qizhi platform,
  51. because they are predefined in the background, you only need to define them in your code.
  52. 4、How the dataset is used
  53. A single dataset uses data_url as the input, and data_dir (ie: workroot + '/data') as the calling method
  54. of the dataset in the image.
  55. For details, please refer to the following sample code.
  56. """
import os
import argparse

import moxing as mox
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5

# Set the device_id for this process and initialize distributed training.
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
init()
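# init() starts the collective-communication backend (HCCL on Ascend); after it returns,
# get_rank() and get_group_size() report this card's rank and the total number of cards.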

### Define whether the task runs in the training environment or the debugging environment ###
def WorkEnvironment(environment):
    if environment == 'train':
        workroot = '/home/work/user-job-dir'
    elif environment == 'debug':
        workroot = '/home/ma-user/work'
    print('current work mode: ' + environment + ', workroot: ' + workroot)
    return workroot

### Copy the single dataset from OBS to the training image ###
def ObsToEnv(obs_data_url, data_dir):
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
    return

### Copy the output model to OBS ###
def EnvToObs(train_dir, obs_train_url):
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
    return

### --data_url, --train_url and --device_target must be defined first in a single-dataset task,
### otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the Qizhi platform,
### because they are predefined in the background; you only need to define them in your code.
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default=WorkEnvironment('train') + '/data/')
parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default=WorkEnvironment('train') + '/model/')
parser.add_argument('--device_target',
                    type=str,
                    default="Ascend",
                    choices=['Ascend', 'CPU'],
                    help='device where the code will be run (default: Ascend); '
                         'to use the CPU on the Qizhi platform set device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
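
# A hypothetical local invocation for reference (the paths below are placeholders, not Qizhi defaults):
#   python train_dataparallel.py --data_url=/path/to/MNISTData/ --train_url=/path/to/model/ --epoch_size=5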

set_seed(114514)

if __name__ == "__main__":
    args = parser.parse_args()

    ### Define the training environment ###
    environment = 'train'
    workroot = WorkEnvironment(environment)

    ### Initialize the data and model directories in the training image ###
    data_dir = workroot + '/data'
    train_dir = workroot + '/model'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)

    ### Copy the dataset from OBS to the training image ###
    ObsToEnv(args.data_url, data_dir)

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
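    # In DATA_PARALLEL mode every card trains a full copy of the network on its own slice of each batch;
    # gradients_mean=True averages the gradients across cards during AllReduce instead of summing them.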

    ds_train = create_dataset_parallel(os.path.join(data_dir, "train"), cfg.batch_size)
    if ds_train.get_dataset_size() == 0:
        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")

    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())

    if args.device_target != "Ascend":
        model = Model(network,
                      net_loss,
                      net_opt,
                      metrics={"accuracy": Accuracy()})
    else:
        model = Model(network,
                      net_loss,
                      net_opt,
                      metrics={"accuracy": Accuracy()},
                      amp_level="O2")

    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    # Note that this callback saves a model file on each card, so a separate save path must be
    # specified per card; here get_rank() is appended to the directory to distinguish the paths.
    ckpoint_cb = ModelCheckpoint(prefix="data_parallel",
                                 directory=train_dir + "/" + str(get_rank()) + "/",
                                 config=config_ck)
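    # Because train_dir now holds one sub-directory per rank, the EnvToObs call at the end uploads
    # the checkpoints written by every card.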

    print("============== Starting Training ==============")
    epoch_size = cfg['epoch_size']
    if args.epoch_size:
        epoch_size = args.epoch_size
        print('epoch_size is: ', epoch_size)
    model.train(epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                dataset_sink_mode=True)

    ### Copy the trained model from the local running environment back to OBS, so it can be
    ### downloaded from the corresponding training task on the Qizhi platform ###
    EnvToObs(train_dir, args.train_url)
