
train_for_c2net_dataparallel.py

  1. """
  2. ######################## train lenet dataparallel example ########################
  3. train lenet and get network model files(.ckpt)
  4. The training of the intelligent computing network currently supports single dataset training, and does not require
  5. the obs copy process.It only needs to define two parameters and then call it directly:
  6. train_dir = '/cache/output' #The location of the output
  7. data_dir = '/cache/dataset' #The location of the dataset
  8. """
import os
import argparse
from dataset_distributed import create_dataset_parallel
from config import mnist_cfg as cfg
from lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
# Set the device and initialize communication. DEVICE_ID is exported by the
# platform scheduler (one value per card); on Ascend, init() brings up the
# HCCL collective backend used for the data-parallel gradient all-reduce.
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
init()
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will run (default: Ascend); '
         'to use the CPU on the Qizhi platform, pass --device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
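# Example invocation (both flags are optional; the defaults are shown):
#   python train_for_c2net_dataparallel.py --device_target Ascend --epoch_size 5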
set_seed(114514)

if __name__ == "__main__":
    args = parser.parse_args()
    ### define the two paths and use them directly ###
    train_dir = '/cache/output'
    data_dir = '/cache/dataset'
    # DATA_PARALLEL replicates the network on every card; gradients_mean=True
    # averages the gradients across cards after each all-reduce.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)
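    # The imported create_dataset_parallel lives in the local dataset_distributed
    # module. For reference, a minimal sketch of such a helper (hypothetical,
    # assuming a standard MNIST pipeline sharded across cards) could look like:
    #
    #   import mindspore.dataset as ds
    #   import mindspore.dataset.vision.c_transforms as CV
    #   import mindspore.dataset.transforms.c_transforms as C
    #   from mindspore import dtype as mstype
    #
    #   def create_dataset_parallel(data_path, batch_size=32):
    #       # shard the dataset so each card reads a disjoint slice
    #       mnist_ds = ds.MnistDataset(data_path, num_shards=get_group_size(),
    #                                  shard_id=get_rank())
    #       mnist_ds = mnist_ds.map(operations=C.TypeCast(mstype.int32),
    #                               input_columns="label")
    #       mnist_ds = mnist_ds.map(operations=[CV.Resize((32, 32)),
    #                                           CV.Rescale(1.0 / 255.0, 0.0),
    #                                           CV.HWC2CHW()],
    #                               input_columns="image")
    #       return mnist_ds.batch(batch_size, drop_remainder=True)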
    ds_train = create_dataset_parallel(os.path.join(data_dir, "train"),
                                       cfg.batch_size)
    if ds_train.get_dataset_size() == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    if args.device_target != "Ascend":
        model = Model(network,
                      net_loss,
                      net_opt,
                      metrics={"accuracy": Accuracy()})
    else:
        model = Model(network,
                      net_loss,
                      net_opt,
                      metrics={"accuracy": Accuracy()},
                      amp_level="O2")
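    # amp_level="O2" above enables mixed precision on Ascend: the network runs
    # in float16 (batchnorm is kept in float32) with dynamic loss scaling.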
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    # Note that this method saves the model file on each card, so the save path
    # must be distinct per card; here get_rank() is appended to the directory
    # to keep the paths apart.
    ckpoint_cb = ModelCheckpoint(prefix="data_parallel",
                                 directory=os.path.join(train_dir, str(get_rank())),
                                 config=config_ck)
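    # Because gradients are averaged across cards (gradients_mean=True), the
    # weights stay synchronized, so in practice only the rank-0 checkpoint
    # needs to be kept for later evaluation or inference.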
    print("============== Starting Training ==============")
    epoch_size = cfg.epoch_size
    if args.epoch_size:
        epoch_size = args.epoch_size
    print('epoch_size is: ', epoch_size)
    model.train(epoch_size, ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                dataset_sink_mode=False)
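
# ---------------------------------------------------------------------------
# Usage note: on the Qizhi/C2Net platform the scheduler mounts the dataset
# under /cache/dataset, collects /cache/output, and exports DEVICE_ID (plus
# the rank table) for every card, so the script can be submitted as-is.
# For a manual multi-card Ascend run, the usual MindSpore practice (an
# assumption here, not something this repo documents) is to start one
# process per device, e.g. for 8 cards:
#
#   export RANK_TABLE_FILE=/path/to/rank_table_8pcs.json
#   export RANK_SIZE=8
#   for i in $(seq 0 7); do
#       DEVICE_ID=$i RANK_ID=$i python train_for_c2net_dataparallel.py \
#           --epoch_size 5 &
#   done
# ---------------------------------------------------------------------------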