You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

train_npu.py 5.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. """
  2. 示例选用的数据集是MNISTData.zip
  3. 数据集结构是:
  4. MNISTData.zip
  5. ├── test
  6. │ ├── t10k-images-idx3-ubyte
  7. │ └── t10k-labels-idx1-ubyte
  8. └── train
  9. ├── train-images-idx3-ubyte
  10. └── train-labels-idx1-ubyte
  11. 使用注意事项:
  12. 1、在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题
  13. 2、用户需要调用openi的python sdk包
  14. """
  15. import os
  16. import argparse
  17. from config import mnist_cfg as cfg
  18. from dataset import create_dataset
  19. from dataset_distributed import create_dataset_parallel
  20. from lenet import LeNet5
  21. import mindspore.nn as nn
  22. from mindspore import context
  23. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  24. from mindspore.train import Model
  25. from mindspore.context import ParallelMode
  26. from mindspore.communication.management import init, get_rank
  27. import time
  28. #导入openi包
  29. from openi.context import prepare, upload_openi
  30. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  31. parser.add_argument(
  32. '--device_target',
  33. type=str,
  34. default="Ascend",
  35. choices=['Ascend', 'CPU'],
  36. help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
  37. parser.add_argument('--epoch_size',
  38. type=int,
  39. default=5,
  40. help='Training epochs.')
  41. if __name__ == "__main__":
  42. ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题
  43. args, unknown = parser.parse_known_args()
  44. data_dir = ''
  45. pretrain_dir = ''
  46. train_dir = ''
  47. #回传结果到openi
  48. upload_openi()
  49. device_num = int(os.getenv('RANK_SIZE'))
  50. #使用单卡时
  51. if device_num == 1:
  52. context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target)
  53. #初始化导入数据集和预训练模型到容器内
  54. openi_context = prepare()
  55. data_dir = openi_context.dataset_path
  56. pretrain_dir = openi_context.pretrain_model_path
  57. train_dir = openi_context.output_path
  58. #使用数据集的方式
  59. ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
  60. #使用多卡时
  61. if device_num > 1:
  62. # set device_id and init for multi-card training
  63. context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID')))
  64. context.reset_auto_parallel_context()
  65. context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True)
  66. init()
  67. #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data
  68. local_rank=int(os.getenv('RANK_ID'))
  69. if local_rank%8==0:
  70. ###初始化导入数据集和预训练模型到容器内
  71. openi_context = prepare()
  72. #初始化导入数据集和预训练模型到容器内
  73. openi_context = prepare()
  74. data_dir = openi_context.dataset_path
  75. pretrain_dir = openi_context.pretrain_model_path
  76. train_dir = openi_context.output_path
  77. #Set a cache file to determine whether the data has been copied to obs.
  78. #If this file exists during multi-card training, there is no need to copy the dataset multiple times.
  79. f = open("/cache/download_input.txt", 'w')
  80. f.close()
  81. try:
  82. if os.path.exists("/cache/download_input.txt"):
  83. print("download_input succeed")
  84. except Exception as e:
  85. print("download_input failed")
  86. while not os.path.exists("/cache/download_input.txt"):
  87. time.sleep(1)
  88. ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
  89. network = LeNet5(cfg.num_classes)
  90. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  91. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  92. time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
  93. if args.device_target != "Ascend":
  94. model = Model(network,
  95. net_loss,
  96. net_opt,
  97. metrics={"accuracy"})
  98. else:
  99. model = Model(network,
  100. net_loss,
  101. net_opt,
  102. metrics={"accuracy"},
  103. amp_level="O2")
  104. config_ck = CheckpointConfig(
  105. save_checkpoint_steps=cfg.save_checkpoint_steps,
  106. keep_checkpoint_max=cfg.keep_checkpoint_max)
  107. #Note that this method saves the model file on each card. You need to specify the save path on each card.
  108. # In this example, get_rank() is added to distinguish different paths.
  109. if device_num == 1:
  110. outputDirectory = train_dir + "/"
  111. if device_num > 1:
  112. outputDirectory = train_dir + "/" + str(get_rank()) + "/"
  113. ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
  114. directory=outputDirectory,
  115. config=config_ck)
  116. print("============== Starting Training ==============")
  117. epoch_size = cfg['epoch_size']
  118. if (args.epoch_size):
  119. epoch_size = args.epoch_size
  120. print('epoch_size is: ', epoch_size)
  121. model.train(epoch_size, ds_train,callbacks=[time_cb, ckpoint_cb,LossMonitor()])
  122. ###上传训练结果到启智平台,注意必须将要输出的模型存储在openi_context.output_path
  123. upload_openi()

No Description