You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

pretrain_for_c2net.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
"""
######################## Attention! ########################
智算网络需要在代码里使用mox拷贝数据集并解压,请参考函数C2netMultiObsToEnv
On the intelligent computing network (C2Net), the code itself must use mox to copy
the datasets and decompress them; see the function C2netMultiObsToEnv().
######################## Multi-dataset LeNet training example ########################
This example is a multi-dataset training tutorial. For a single dataset, refer to the
single-dataset training tutorial train.py instead; this example cannot be used with a
single dataset.
"""
"""
######################## Instructions for using the training environment ########################
1. (1) Structure of the datasets uploaded for multi-dataset training in this example:
MNISTData.zip
├── test
└── train
checkpoint_lenet-1_1875.zip
└── checkpoint_lenet-1_1875.ckpt
(2) Resulting structure inside the training image. Each archive <name>.zip is copied
into data_dir (/cache/data) and extracted into data_dir/<name>:
/cache/data
├── MNISTData
│   ├── test
│   └── train
└── checkpoint_lenet-1_1875
    └── checkpoint_lenet-1_1875.ckpt
Note: the checkpoint this script actually loads is the file given by --ckpt_url,
which is copied to /cache/checkpoint.ckpt.
2. Multi-dataset training requires the following predefined functions:
(1) Copy the datasets from OBS into the training image and unzip them:
    C2netMultiObsToEnv(multi_data_url, data_dir)
(2) Copy the output folder to OBS:
    EnvToObs(train_dir, obs_train_url)
(3) Initialize the runtime context and download the inputs (Qizhi):
    DownloadFromQizhi(multi_data_url, data_dir)
(4) Upload the outputs (Qizhi):
    UploadToQizhi(train_dir, obs_train_url)
(5) Copy a single file (the pretrained checkpoint) from OBS:
    ObsUrlToEnv(obs_ckpt_url, ckpt_url)
3. Parameters
--multi_data_url is the set of datasets you selected on the Qizhi platform (passed as a
JSON string).
--multi_data_url, --train_url and --device_target must be defined in a multi-dataset
task, otherwise an error will be reported.
There is no need to add these parameters to the run parameters on the Qizhi platform,
because they are supplied in the background; you only need to define them in your code.
This script additionally defines --ckpt_url (OBS path of the pretrained checkpoint) and
--epoch_size.
4. How the datasets are used
Multi-dataset tasks take multi_data_url as input. Inside the training image a dataset is
accessed at data_dir + "/" + dataset name + "/" + file or folder inside the dataset.
For example, the train folder of the MNISTData dataset in this example is at
data_dir + "/MNISTData" + "/train"
See the sample code below for details.
"""
  47. import os
  48. import argparse
  49. import moxing as mox
  50. from config import mnist_cfg as cfg
  51. from dataset import create_dataset
  52. from dataset_distributed import create_dataset_parallel
  53. from lenet import LeNet5
  54. import json
  55. import mindspore.nn as nn
  56. from mindspore import context
  57. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  58. from mindspore.train import Model
  59. from mindspore.nn.metrics import Accuracy
  60. from mindspore import load_checkpoint, load_param_into_net
  61. from mindspore.context import ParallelMode
  62. from mindspore.communication.management import init, get_rank
  63. import time
  # Copy multiple datasets from OBS into the training image and unzip them.
def C2netMultiObsToEnv(multi_data_url, data_dir):
    """Copy every dataset listed in ``multi_data_url`` into ``data_dir`` and unzip it.

    Args:
        multi_data_url: JSON string. It is parsed with ``json.loads`` and iterated
            as a list of dicts. Each entry is read for the keys ``"dataset_name"``
            (archive file name, e.g. ``MNISTData.zip``) and ``"dataset_url"``
            (source passed to ``mox.file.copy``).
        data_dir: Local directory. ``<name>.zip`` is copied to
            ``data_dir/<name>.zip`` and extracted into ``data_dir/<name>``.

    A failed copy or extraction of one dataset is printed, and the remaining
    datasets are still processed. The marker file ``/cache/download_input.txt``
    is created in a ``finally`` block, i.e. even when parsing or a download
    fails. Other processes poll for it (see DownloadFromQizhi) and would
    otherwise wait forever.

    Raises:
        json.JSONDecodeError: if ``multi_data_url`` is not valid JSON.
        KeyError: if an entry lacks ``dataset_name`` or ``dataset_url``.
    """
    import zipfile  # function-scope import keeps this helper self-contained

    marker_path = "/cache/download_input.txt"
    all_ok = True
    try:
        for dataset in json.loads(multi_data_url):
            dataset_name = dataset["dataset_name"]
            dataset_url = dataset["dataset_url"]
            zipfile_path = data_dir + "/" + dataset_name
            try:
                mox.file.copy(dataset_url, zipfile_path)
                print("Successfully Download {} to {}".format(dataset_url, zipfile_path))
            except Exception as e:
                all_ok = False
                print('moxing download {} to {} failed: '.format(
                    dataset_url, zipfile_path) + str(e))
                continue
            # Extract "<name>.zip" into "data_dir/<name>". zipfile replaces the
            # former `os.system("unzip ...")`:
            #   - names with spaces or shell metacharacters can no longer break or
            #     inject into a shell command;
            #   - extraction errors are reported instead of being silently ignored.
            extract_dir = data_dir + "/" + os.path.splitext(dataset_name)[0]
            try:
                os.makedirs(extract_dir, exist_ok=True)
                with zipfile.ZipFile(zipfile_path) as archive:
                    archive.extractall(extract_dir)
            except Exception as e:
                all_ok = False
                print('unzip {} to {} failed: '.format(zipfile_path, extract_dir) + str(e))
    finally:
        # Cache marker: its existence tells the other cards that the copy step has
        # finished, so the datasets do not need to be copied multiple times.
        with open(marker_path, 'w'):
            pass
    print("download_input succeed" if all_ok else "download_input failed")
  # Copy a single file (e.g. the pretrained checkpoint) from OBS into the training image.
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy one file from ``obs_ckpt_url`` to the local path ``ckpt_url``.

    Uses ``mox.file.copy``, which operates on a single file; folders are copied
    with ``mox.file.copy_parallel`` instead (see EnvToObs). Any exception is
    caught and printed rather than raised; the function returns None.
    """
    src, dst = obs_ckpt_url, ckpt_url
    try:
        mox.file.copy(src, dst)
        print("Successfully Download {} to {}".format(src, dst))
    except Exception as err:
        failure = 'moxing download {} to {} failed: '.format(src, dst)
        print(failure + str(err))
  # Copy the training output folder back to OBS.
def EnvToObs(train_dir, obs_train_url):
    """Copy the local folder ``train_dir`` to ``obs_train_url``.

    Uses ``mox.file.copy_parallel`` (the folder-copy variant). Any exception is
    caught and printed rather than raised; the function returns None.
    """
    src, dst = train_dir, obs_train_url
    try:
        mox.file.copy_parallel(src, dst)
        print("Successfully Upload {} to {}".format(src, dst))
    except Exception as err:
        failure = 'moxing upload {} to {} failed: '.format(src, dst)
        print(failure + str(err))
  # Initialize the MindSpore context and copy the input datasets into the training image.
def DownloadFromQizhi(multi_data_url, data_dir, device_target=None):
    """Set up the MindSpore context for this process and download the datasets.

    Args:
        multi_data_url: JSON string describing the datasets; forwarded to
            C2netMultiObsToEnv.
        data_dir: Local directory the datasets are copied and extracted into.
        device_target: MindSpore device target (e.g. "Ascend" or "CPU").
            When None, falls back to ``args.device_target``, read from the
            module-level ``args`` set by the ``__main__`` block. That was the
            original implicit dependency, kept for backward compatibility.

    Environment variables read: RANK_SIZE (number of devices; treated as 1 when
    unset), and in multi-device mode ASCEND_DEVICE_ID and RANK_ID.
    """
    if device_target is None:
        device_target = args.device_target
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num <= 1:
        C2netMultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
        return

    # Multi-card: bind this process to its device, enable data parallelism and
    # initialize communication.
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target,
                        device_id=int(os.getenv('ASCEND_DEVICE_ID')))
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=device_num,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True,
                                      parameter_broadcast=True)
    init()
    # The OBS data only needs to be copied once, so only ranks with
    # RANK_ID % 8 == 0 copy it.
    # NOTE(review): this presumably assumes 8 cards per node sharing a
    # node-local /cache; confirm for other topologies.
    local_rank = int(os.getenv('RANK_ID'))
    if local_rank % 8 == 0:
        C2netMultiObsToEnv(multi_data_url, data_dir)
    # Until the marker file exists, the copy has not finished; wait for it.
    while not os.path.exists("/cache/download_input.txt"):
        time.sleep(1)
  # Upload the training output to OBS (only from ranks with RANK_ID % 8 == 0 in multi-card mode).
def UploadToQizhi(train_dir, obs_train_url):
    """Copy ``train_dir`` to ``obs_train_url`` via EnvToObs.

    Single-device runs (RANK_SIZE unset or <= 1) always upload. In multi-device
    runs, only processes with RANK_ID % 8 == 0 upload, mirroring the copy step
    in DownloadFromQizhi. RANK_ID is read only in the multi-device branch, so
    single-device runs do not need it set (``int(None)`` would raise TypeError).

    NOTE(review): nothing visible here waits for the other ranks, so their
    checkpoints under train_dir may still be being written during the upload.
    Confirm.
    """
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num > 1 and int(os.getenv('RANK_ID')) % 8 != 0:
        return
    EnvToObs(train_dir, obs_train_url)
  # Command-line interface for this training script.
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
# --multi_data_url, --train_url and --device_target must be defined in a
# multi-dataset task, otherwise an error will be reported. The Qizhi platform
# supplies them in the background, so they do not need to be added to the task's
# run parameters; they only need to be declared here.
parser.add_argument('--multi_data_url', default='/cache/data/',
                    help='path to multi dataset')
parser.add_argument('--ckpt_url', help='pre_train_model path in obs')
parser.add_argument('--train_url', default='/cache/output/',
                    help='model folder to save/load')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'CPU'], default="Ascend",
                    help='device where the code will be implemented (default: Ascend),'
                         'if to use the CPU on the Qizhi platform:device_target=CPU')
parser.add_argument('--epoch_size', type=int, default=5, help='Training epochs.')
  # Script entry point: fetch inputs, train LeNet5, upload outputs.
if __name__ == "__main__":
    # ``args`` must stay a module-level name: DownloadFromQizhi reads args.device_target.
    args, unknown = parser.parse_known_args()
    data_dir = '/cache/data'
    train_dir = '/cache/output'
    ckpt_url = '/cache/checkpoint.ckpt'
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(train_dir, exist_ok=True)

    # Copy the pretrained checkpoint from OBS into the training image.
    # NOTE(review): every process runs this and writes the same local path, so
    # in multi-card runs several processes may write ckpt_url concurrently.
    # Confirm this is safe on the platform.
    ObsUrlToEnv(args.ckpt_url, ckpt_url)
    # Initialize the MindSpore context and copy the datasets into data_dir.
    DownloadFromQizhi(args.multi_data_url, data_dir)

    # Dataset path: data_dir + "/MNISTData" + "/train".
    train_path = os.path.join(data_dir, "MNISTData", "train")
    # An unset RANK_SIZE is treated as a single-device run instead of crashing on int(None).
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num > 1:
        ds_train = create_dataset_parallel(train_path, cfg.batch_size)
    else:
        ds_train = create_dataset(train_path, cfg.batch_size)
    if ds_train.get_dataset_size() == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    # Initialize the network from the pretrained checkpoint copied above.
    load_param_into_net(network, load_checkpoint(ckpt_url))

    metrics = {"accuracy": Accuracy()}
    if args.device_target == "Ascend":
        # amp_level="O2" is only requested on Ascend.
        model = Model(network, net_loss, net_opt, metrics=metrics, amp_level="O2")
    else:
        model = Model(network, net_loss, net_opt, metrics=metrics)

    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    # Checkpoints are saved on every card, so in multi-card mode each rank
    # writes to its own subdirectory named after get_rank().
    if device_num > 1:
        outputDirectory = train_dir + "/" + str(get_rank()) + "/"
    else:
        outputDirectory = train_dir
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory=outputDirectory,
                                 config=config_ck)

    print("============== Starting Training ==============")
    # --epoch_size (default 5) overrides cfg['epoch_size'] whenever it is non-zero.
    epoch_size = cfg['epoch_size']
    if args.epoch_size:
        epoch_size = args.epoch_size
    print('epoch_size is: ', epoch_size)
    model.train(epoch_size, ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    # Copy the trained output from the local environment back to OBS, so it can be
    # downloaded from the corresponding Qizhi training task.
    UploadToQizhi(train_dir, args.train_url)
  214. callbacks=[time_cb, ckpoint_cb,
  215. LossMonitor()])
  216. ###Copy the trained output data from the local running environment back to obs,
  217. ###and download it in the training task corresponding to the Qizhi platform
  218. UploadToQizhi(train_dir,args.train_url)