You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 8.2 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. """
  2. ######################## train lenet example ########################
  3. train lenet and get network model files(.ckpt)
  4. """
  5. """
  6. ######################## 训练环境使用说明 ########################
  7. 假设已经使用Ascend NPU调试环境调试完代码,欲将调试环境的代码迁移到训练环境进行训练,需要做以下工作:
  8. 1、调试环境的镜像和训练环境的镜像是两个不同的镜像,所处的运行目录不一致,需要将data_url和train_url的路径进行变换
  9. 在调试环境中:
  10. args.data_url = '/home/ma-user/work/data/' //数据集位置
  11. args.train_url = '/home/ma-user/work/model/' //训练输出的模型位置
  12. 在训练环境变换为:
  13. args.data_url = '/home/work/user-job-dir/data/'
  14. args.train_url = '/home/work/user-job-dir/model/'
  15. 2、在训练环境中,需要将数据集从obs拷贝到训练镜像中,训练完以后,需要将输出的模型拷贝到obs.
  16. 将数据集从obs拷贝到训练镜像中:
  17. obs_data_url = args.data_url
  18. args.data_url = '/home/work/user-job-dir/data/'
  19. if not os.path.exists(args.data_url):
  20. os.mkdir(args.data_url)
  21. try:
  22. mox.file.copy_parallel(obs_data_url, args.data_url)
  23. print("Successfully Download {} to {}".format(obs_data_url,
  24. args.data_url))
  25. except Exception as e:
  26. print('moxing download {} to {} failed: '.format(
  27. obs_data_url, args.data_url) + str(e))
  28. 将输出的模型拷贝到obs:
  29. obs_train_url = args.train_url
  30. args.train_url = '/home/work/user-job-dir/model/'
  31. if not os.path.exists(args.train_url):
  32. os.mkdir(args.train_url)
  33. try:
  34. mox.file.copy_parallel(args.train_url, obs_train_url)
  35. print("Successfully Upload {} to {}".format(args.train_url,
  36. obs_train_url))
  37. except Exception as e:
  38. print('moxing upload {} to {} failed: '.format(args.train_url,
  39. obs_train_url) + str(e))
  40. """
  41. import os
  42. import numpy as np
  43. import argparse
  44. import moxing as mox
  45. from config import mnist_cfg as cfg
  46. from dataset import create_dataset
  47. from lenet import LeNet5
  48. import mindspore.nn as nn
  49. from mindspore import context
  50. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  51. from mindspore.train import Model
  52. from mindspore.nn.metrics import Accuracy
  53. from mindspore.common import set_seed
  54. from mindspore import Tensor, export
  55. #配置默认的工作空间根目录
  56. # environment = 'debug'
  57. environment = 'train'
  58. if environment == 'debug':
  59. workroot = '/home/ma-user/work' #调试任务使用该参数
  60. else:
  61. workroot = '/home/work/user-job-dir' # 训练任务使用该参数
  62. print('current work mode:' + environment + ', workroot:' + workroot)
  63. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  64. # define 2 parameters for running on modelArts
  65. # data_url,train_url是固定用于在modelarts上训练的参数,表示数据集的路径和输出模型的路径
  66. parser.add_argument('--data_url',
  67. help='path to training/inference dataset folder',
  68. default= workroot + '/data/')
  69. parser.add_argument('--train_url',
  70. help='model folder to save/load',
  71. default= workroot + '/model/')
  72. parser.add_argument(
  73. '--device_target',
  74. type=str,
  75. default="Ascend",
  76. choices=['Ascend', 'CPU'],
  77. help='device where the code will be implemented (default: CPU),若要在启智平台上使用NPU,需要在启智平台训练界面上加上运行参数device_target=Ascend')
  78. #modelarts已经默认使用data_url和train_url
  79. parser.add_argument('--epoch_size',
  80. type=int,
  81. default=5,
  82. help='Training epochs.')
  83. set_seed(1)
  84. if __name__ == "__main__":
  85. args = parser.parse_args()
  86. print('args:')
  87. print(args)
  88. data_dir = workroot + '/data' #数据集存放路径
  89. train_dir = workroot + '/model' #模型存放路径
  90. #初始化数据存放目录
  91. if not os.path.exists(data_dir):
  92. os.mkdir(data_dir)
  93. #初始化模型存放目录
  94. obs_train_url = args.train_url
  95. train_dir = workroot + '/model/'
  96. if not os.path.exists(train_dir):
  97. os.mkdir(train_dir)
  98. ######################## 将数据集从obs拷贝到训练镜像中 (固定写法)########################
  99. # 在训练环境中定义data_url和train_url,并把数据从obs拷贝到相应的固定路径,以下写法是将数据拷贝到/home/work/user-job-dir/data/目录下,可修改为其他目录
  100. #创建数据存放的位置
  101. if environment == 'train':
  102. obs_data_url = args.data_url
  103. #将数据拷贝到训练环境
  104. try:
  105. mox.file.copy_parallel(obs_data_url, data_dir)
  106. print("Successfully Download {} to {}".format(obs_data_url,
  107. data_dir))
  108. except Exception as e:
  109. print('moxing download {} to {} failed: '.format(
  110. obs_data_url, data_dir) + str(e))
  111. ######################## 将数据集从obs拷贝到训练镜像中 ########################
  112. #注意:这里很重要,指定了训练所用的设备CPU还是Ascend NPU
  113. context.set_context(mode=context.GRAPH_MODE,
  114. device_target=args.device_target)
  115. #创建数据集
  116. ds_train = create_dataset(os.path.join(data_dir, "train"),
  117. cfg.batch_size)
  118. if ds_train.get_dataset_size() == 0:
  119. raise ValueError(
  120. "Please check dataset size > 0 and batch_size <= dataset size")
  121. #创建网络
  122. network = LeNet5(cfg.num_classes)
  123. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  124. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  125. time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
  126. if args.device_target != "Ascend":
  127. model = Model(network,
  128. net_loss,
  129. net_opt,
  130. metrics={"accuracy": Accuracy()})
  131. else:
  132. model = Model(network,
  133. net_loss,
  134. net_opt,
  135. metrics={"accuracy": Accuracy()},
  136. amp_level="O2")
  137. config_ck = CheckpointConfig(
  138. save_checkpoint_steps=cfg.save_checkpoint_steps,
  139. keep_checkpoint_max=cfg.keep_checkpoint_max)
  140. #定义模型输出路径
  141. ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
  142. directory=train_dir,
  143. config=config_ck)
  144. #开始训练
  145. print("============== Starting Training ==============")
  146. epoch_size = cfg['epoch_size']
  147. if (args.epoch_size):
  148. epoch_size = args.epoch_size
  149. print('epoch_size is: ', epoch_size)
  150. model.train(epoch_size,
  151. ds_train,
  152. callbacks=[time_cb, ckpoint_cb,
  153. LossMonitor()])
  154. input = np.random.uniform(0.0, 1.0, size=[1, 1, 32, 32]).astype(np.float32)
  155. export(network, Tensor(input), file_name=(train_dir +'LeNet5_model'), file_format='MINDIR')
  156. export(network, Tensor(input), file_name=(train_dir +'LeNet5_onnx_model'), file_format='ONNX')
  157. ######################## 将输出的模型拷贝到obs(固定写法) ########################
  158. # 把训练后的模型数据从本地的运行环境拷贝回obs,在启智平台相对应的训练任务中会提供下载
  159. if environment == 'train':
  160. try:
  161. mox.file.copy_parallel(train_dir, obs_train_url)
  162. print("Successfully Upload {} to {}".format(train_dir,
  163. obs_train_url))
  164. except Exception as e:
  165. print('moxing upload {} to {} failed: '.format(train_dir,
  166. obs_train_url) + str(e))
  167. ######################## 将输出的模型拷贝到obs ########################