You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_multi_card.py 4.7 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. """
  2. 示例选用的数据集是MnistDataset_mindspore.zip
  3. 数据集结构是:
  4. MnistDataset_mindspore.zip
  5. ├── test
  6. │ ├── t10k-images-idx3-ubyte
  7. │ └── t10k-labels-idx1-ubyte
  8. └── train
  9. ├── train-images-idx3-ubyte
  10. └── train-labels-idx1-ubyte
  11. 使用注意事项:
  12. 1、在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题
  13. 2、用户需要调用c2net的python sdk包
  14. """
  15. import os
  16. import argparse
  17. from config import mnist_cfg as cfg
  18. from dataset_distributed import create_dataset_parallel
  19. from lenet import LeNet5
  20. import mindspore.nn as nn
  21. from mindspore import context
  22. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  23. from mindspore.train import Model
  24. from mindspore.context import ParallelMode
  25. from mindspore.communication.management import init, get_rank
  26. import time
  27. from c2net.context import prepare, upload_output
  28. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  29. parser.add_argument(
  30. '--device_target',
  31. type=str,
  32. default="Ascend",
  33. choices=['Ascend', 'CPU'],
  34. help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
  35. parser.add_argument('--epoch_size',
  36. type=int,
  37. default=5,
  38. help='Training epochs.')
  39. if __name__ == "__main__":
  40. ###请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题
  41. args, unknown = parser.parse_known_args()
  42. device_num = int(os.getenv('RANK_SIZE'))
  43. #使用多卡时
  44. # set device_id and init for multi-card training
  45. context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID')))
  46. context.reset_auto_parallel_context()
  47. context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True)
  48. init()
  49. #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data
  50. local_rank=int(os.getenv('RANK_ID'))
  51. #初始化导入数据集和预训练模型到容器内,并行任务先让0卡拷贝数据,并用一个缓存文件标记0卡已prepare完成
  52. if local_rank == 0:
  53. c2net_context = prepare()
  54. f = open("/cache/prepare_completed.txt", 'w')
  55. f.close()
  56. try:
  57. if os.path.exists("/cache/prepare_completed.txt"):
  58. print("prepare completed!")
  59. except Exception as e:
  60. print("prepare failed")
  61. while not os.path.exists("/cache/prepare_completed.txt"):
  62. time.sleep(1)
  63. c2net_context = prepare()
  64. #获取数据集路径
  65. MnistDataset_mindspore_path = c2net_context.dataset_path+"/"+"MnistDataset_mindspore"
  66. output_path = c2net_context.output_path
  67. ds_train = create_dataset_parallel(os.path.join(MnistDataset_mindspore_path, "train"), cfg.batch_size)
  68. network = LeNet5(cfg.num_classes)
  69. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  70. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  71. time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
  72. if args.device_target != "Ascend":
  73. model = Model(network,
  74. net_loss,
  75. net_opt,
  76. metrics={"accuracy"})
  77. else:
  78. model = Model(network,
  79. net_loss,
  80. net_opt,
  81. metrics={"accuracy"},
  82. amp_level="O2")
  83. config_ck = CheckpointConfig(
  84. save_checkpoint_steps=cfg.save_checkpoint_steps,
  85. keep_checkpoint_max=cfg.keep_checkpoint_max)
  86. #Note that this method saves the model file on each card. You need to specify the save path on each card.
  87. # In this example, get_rank() is added to distinguish different paths.
  88. outputDirectory = output_path + "/" + str(get_rank()) + "/"
  89. ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
  90. directory=outputDirectory,
  91. config=config_ck)
  92. print("============== Starting Training ==============")
  93. epoch_size = cfg['epoch_size']
  94. if (args.epoch_size):
  95. epoch_size = args.epoch_size
  96. print('epoch_size is: ', epoch_size)
  97. model.train(epoch_size, ds_train,callbacks=[time_cb, ckpoint_cb,LossMonitor()])
  98. ###上传训练结果到启智平台,注意必须将要输出的模型存储在c2net_context.output_path
  99. upload_output()

No Description