You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

c2net_npu_pretrain.py 10 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
"""
######################## single-dataset train lenet example ########################
This example is a single-dataset training tutorial. For multiple datasets, please refer to the
multi-dataset training tutorial train_for_multidataset.py. This example cannot be used for multi-datasets!

######################## Instructions for using the training environment ########################
The image of the debugging environment and the image of the training environment are two different
images, and their local working directories differ. In a training task, pay attention to the
following points.
1. (1) The structure of the dataset uploaded for single-dataset training in this example:
       MNISTData.zip
       ├── test
       └── train
2. Single-dataset training requires these predefined functions:
   (1) Copy a single dataset from OBS to the training image:
       ObsToEnv(obs_data_url, data_dir)
   (2) Copy the output to OBS:
       EnvToObs(train_dir, obs_train_url)
   (3) Download the input from Qizhi and initialize:
       DownloadFromQizhi(obs_data_url, data_dir)
   (4) Upload the output to Qizhi:
       UploadToQizhi(train_dir, obs_train_url)
   (5) Copy a ckpt file from OBS to the training image:
       ObsUrlToEnv(obs_ckpt_url, ckpt_url)
3. Three parameters need to be defined:
   --data_url is the dataset you selected on the Qizhi platform.
   --data_url, --train_url and --device_target must be defined first in a single-dataset task,
   otherwise an error will be reported.
   There is no need to add these parameters to the running parameters of the Qizhi platform,
   because they are predefined in the background; you only need to define them in your code.
4. How the dataset is used:
   A single dataset uses data_url as the input, and data_dir (i.e. '/cache/data') as the way the
   dataset is referenced inside the image. For details, please refer to the sample code below.
5. How to load the checkpoint file:
   The checkpoint file is loaded via the ckpt_url parameter.
"""
  36. import os
  37. import argparse
  38. import moxing as mox
  39. from config import mnist_cfg as cfg
  40. from dataset import create_dataset
  41. from dataset_distributed import create_dataset_parallel
  42. from lenet import LeNet5
  43. import mindspore.nn as nn
  44. from mindspore import context
  45. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  46. from mindspore import load_checkpoint, load_param_into_net
  47. from mindspore.train import Model
  48. from mindspore.nn.metrics import Accuracy
  49. from mindspore.context import ParallelMode
  50. from mindspore.communication.management import init, get_rank
  51. import mindspore.ops as ops
  52. import time
  53. import json
  54. #from upload import UploadOutput
  55. ### Copy single dataset from obs to training image###
  56. def ObsToEnv(obs_data_url, data_dir):
  57. try:
  58. mox.file.copy_parallel(obs_data_url, data_dir)
  59. print("Successfully Download {} to {}".format(obs_data_url, data_dir))
  60. except Exception as e:
  61. print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
  62. #Set a cache file to determine whether the data has been copied to obs.
  63. #If this file exists during multi-card training, there is no need to copy the dataset multiple times.
  64. f = open("/cache/download_input.txt", 'w')
  65. f.close()
  66. try:
  67. if os.path.exists("/cache/download_input.txt"):
  68. print("download_input succeed")
  69. except Exception as e:
  70. print("download_input failed")
  71. return
  72. ### Copy ckpt file from obs to training image###
  73. ### To operate on folders, use mox.file.copy_parallel. If copying a file.
  74. ### Please use mox.file.copy to operate the file, this operation is to operate the file
  75. def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
  76. try:
  77. mox.file.copy(obs_ckpt_url, ckpt_url)
  78. print("Successfully Download {} to {}".format(obs_ckpt_url,ckpt_url))
  79. except Exception as e:
  80. print('moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url) + str(e))
  81. return
  82. ### Copy multiple datasets from obs to training image ###
  83. def MultiObsToEnv(multi_data_url, data_dir):
  84. #--multi_data_url is json data, need to do json parsing for multi_data_url
  85. multi_data_json = json.loads(multi_data_url)
  86. for i in range(len(multi_data_json)):
  87. path = data_dir + "/" + multi_data_json[i]["dataset_name"]
  88. file_path = data_dir + "/" + os.path.splitext(multi_data_json[i]["dataset_name"])[0]
  89. if not os.path.exists(file_path):
  90. os.makedirs(file_path)
  91. try:
  92. mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
  93. print("Successfully Download {} to {}".format(multi_data_json[i]["dataset_url"],path))
  94. #unzip dataset
  95. os.system("unzip -d %s %s" % (file_path, path))
  96. except Exception as e:
  97. print('moxing download {} to {} failed: '.format(
  98. multi_data_json[i]["dataset_url"], path) + str(e))
  99. #Set a cache file to determine whether the data has been copied to obs.
  100. #If this file exists during multi-card training, there is no need to copy the dataset multiple times.
  101. f = open("/cache/download_input.txt", 'w')
  102. f.close()
  103. try:
  104. if os.path.exists("/cache/download_input.txt"):
  105. print("download_input succeed")
  106. except Exception as e:
  107. print("download_input failed")
  108. return
  109. def DownloadFromQizhi(multi_data_url, data_dir):
  110. device_num = int(os.getenv('RANK_SIZE'))
  111. if device_num == 1:
  112. MultiObsToEnv(multi_data_url,data_dir)
  113. context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target)
  114. if device_num > 1:
  115. # set device_id and init for multi-card training
  116. context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID')))
  117. context.reset_auto_parallel_context()
  118. context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True)
  119. init()
  120. #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data
  121. local_rank=int(os.getenv('RANK_ID'))
  122. if local_rank%8==0:
  123. MultiObsToEnv(multi_data_url,data_dir)
  124. #If the cache file does not exist, it means that the copy data has not been completed,
  125. #and Wait for 0th card to finish copying data
  126. while not os.path.exists("/cache/download_input.txt"):
  127. time.sleep(1)
  128. return
  129. ### --data_url,--train_url,--device_target,These 3 parameters must be defined first in a single dataset,
  130. ### otherwise an error will be reported.
  131. ###There is no need to add these parameters to the running parameters of the Qizhi platform,
  132. ###because they are predefined in the background, you only need to define them in your code.
  133. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  134. parser.add_argument('--multi_data_url',
  135. help='dataset path in obs')
  136. parser.add_argument('--ckpt_url',
  137. help='pre_train_model path in obs')
  138. parser.add_argument(
  139. '--device_target',
  140. type=str,
  141. default="Ascend",
  142. choices=['Ascend', 'CPU'],
  143. help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
  144. parser.add_argument('--epoch_size',
  145. type=int,
  146. default=5,
  147. help='Training epochs.')
  148. if __name__ == "__main__":
  149. args, unknown = parser.parse_known_args()
  150. data_dir = '/cache/dataset'
  151. train_dir = '/cache/output'
  152. ckpt_url = '/cache/checkpoint.ckpt'
  153. if not os.path.exists(data_dir):
  154. os.makedirs(data_dir)
  155. if not os.path.exists(train_dir):
  156. os.makedirs(train_dir)
  157. ###Initialize and copy data to training image
  158. ###Copy ckpt file from obs to training image
  159. ObsUrlToEnv(args.ckpt_url, ckpt_url)
  160. ###Copy data from obs to training image
  161. DownloadFromQizhi(args.multi_data_url, data_dir)
  162. ###The dataset path is used here:data_dir +"/train"
  163. device_num = int(os.getenv('RANK_SIZE'))
  164. if device_num == 1:
  165. ds_train = create_dataset(os.path.join(data_dir+ "/MNISTData", "train"), cfg.batch_size)
  166. if device_num > 1:
  167. ds_train = create_dataset_parallel(os.path.join(data_dir+ "/MNISTData", "train"), cfg.batch_size)
  168. if ds_train.get_dataset_size() == 0:
  169. raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
  170. network = LeNet5(cfg.num_classes)
  171. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  172. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  173. time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
  174. ###The ckpt path is used here:ckpt_url
  175. print('-------ckpt_url is:', args.ckpt_url)
  176. load_param_into_net(network, load_checkpoint(ckpt_url))
  177. if args.device_target != "Ascend":
  178. model = Model(network,
  179. net_loss,
  180. net_opt,
  181. metrics={"accuracy": Accuracy()})
  182. else:
  183. model = Model(network,
  184. net_loss,
  185. net_opt,
  186. metrics={"accuracy": Accuracy()},
  187. amp_level="O2")
  188. config_ck = CheckpointConfig(
  189. save_checkpoint_steps=cfg.save_checkpoint_steps,
  190. keep_checkpoint_max=cfg.keep_checkpoint_max)
  191. #Note that this method saves the model file on each card. You need to specify the save path on each card.
  192. # In this example, get_rank() is added to distinguish different paths.
  193. if device_num == 1:
  194. outputDirectory = train_dir + "/"
  195. if device_num > 1:
  196. outputDirectory = train_dir + "/" + str(get_rank()) + "/"
  197. ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
  198. directory=outputDirectory,
  199. config=config_ck)
  200. print("============== Starting Training ==============")
  201. epoch_size = cfg['epoch_size']
  202. if (args.epoch_size):
  203. epoch_size = args.epoch_size
  204. print('epoch_size is: ', epoch_size)
  205. model.train(epoch_size,
  206. ds_train,
  207. callbacks=[time_cb, ckpoint_cb,
  208. LossMonitor()])