You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

train_for_c2net.py 4.2 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
"""
######################## train lenet dataparallel example ########################
Train LeNet and produce network model files (.ckpt).

Training on the intelligent computing network currently supports single-dataset
training and does not require the OBS copy process. It only needs two parameters
to be defined, which are then used directly:
train_dir = '/cache/output'   # the location of the output
data_dir = '/cache/dataset'   # the location of the dataset
"""
  9. import os
  10. import argparse
  11. from dataset import create_dataset
  12. from dataset_distributed import create_dataset_parallel
  13. import moxing as mox
  14. from config import mnist_cfg as cfg
  15. from lenet import LeNet5
  16. import mindspore.nn as nn
  17. from mindspore import context
  18. from mindspore.common import set_seed
  19. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  20. from mindspore.train import Model
  21. from mindspore.nn.metrics import Accuracy
  22. from mindspore.context import ParallelMode
  23. from mindspore.communication.management import init, get_rank, get_group_size
  24. import mindspore.ops as ops
  25. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  26. parser.add_argument(
  27. '--device_target',
  28. type=str,
  29. default="Ascend",
  30. choices=['Ascend', 'CPU'],
  31. help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
  32. parser.add_argument('--epoch_size',
  33. type=int,
  34. default=5,
  35. help='Training epochs.')
  36. if __name__ == "__main__":
  37. args = parser.parse_args()
  38. ###define two parameters and then call it directly###
  39. data_dir = '/cache/dataset'
  40. train_dir = '/cache/output'
  41. device_num = int(os.getenv('RANK_SIZE'))
  42. if device_num == 1:
  43. context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target)
  44. ds_train = create_dataset(os.path.join(data_dir, "train"), cfg.batch_size)
  45. if device_num > 1:
  46. # set device_id and init for multi-card training
  47. context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID')))
  48. context.reset_auto_parallel_context()
  49. context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True)
  50. init()
  51. ds_train = create_dataset_parallel(os.path.join(data_dir, "train"), cfg.batch_size)
  52. if ds_train.get_dataset_size() == 0:
  53. raise ValueError(
  54. "Please check dataset size > 0 and batch_size <= dataset size")
  55. network = LeNet5(cfg.num_classes)
  56. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  57. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  58. time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
  59. if args.device_target != "Ascend":
  60. model = Model(network,
  61. net_loss,
  62. net_opt,
  63. metrics={"accuracy": Accuracy()})
  64. else:
  65. model = Model(network,
  66. net_loss,
  67. net_opt,
  68. metrics={"accuracy": Accuracy()},
  69. amp_level="O2")
  70. config_ck = CheckpointConfig(
  71. save_checkpoint_steps=cfg.save_checkpoint_steps,
  72. keep_checkpoint_max=cfg.keep_checkpoint_max)
  73. #Note that this method saves the model file on each card. You need to specify the save path on each card.
  74. # In the example, get_rank() is added to distinguish different paths.
  75. if device_num == 1:
  76. outputDirectory = train_dir + "/"
  77. if device_num > 1:
  78. outputDirectory = train_dir + "/" + str(get_rank()) + "/"
  79. ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
  80. directory=outputDirectory,
  81. config=config_ck)
  82. print("============== Starting Training ==============")
  83. epoch_size = cfg['epoch_size']
  84. if (args.epoch_size):
  85. epoch_size = args.epoch_size
  86. print('epoch_size is: ', epoch_size)
  87. model.train(epoch_size,ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], dataset_sink_mode=False)