You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_continue.py 5.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #####################################################################################################
  2. # 继续训练功能:修改训练任务时,若勾选复用上次结果,则可在新训练任务的输出路径中读取到上次结果
  3. #
  4. # 示例用法
  5. # - 增加两个训练参数
  6. # 'ckpt_save_name' 此次任务的输出文件名,用于保存此次训练的模型文件名称(不带后缀)
  7. # 'ckpt_load_name' 上一次任务的输出文件名,用于加载上一次输出的模型文件名称(不带后缀),首次训练默认为空,则不读取任何文件
  8. # - 训练代码中判断 'ckpt_load_name' 是否为空,若不为空,则为继续训练任务
  9. #####################################################################################################
  10. import os
  11. import argparse
  12. from config import mnist_cfg as cfg
  13. from dataset import create_dataset
  14. from dataset_distributed import create_dataset_parallel
  15. from lenet import LeNet5
  16. import mindspore.nn as nn
  17. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
  18. from mindspore import load_checkpoint, load_param_into_net
  19. from mindspore.train import Model
  20. from mindspore.nn.metrics import Accuracy
  21. from mindspore.communication.management import get_rank
  22. #导入openi包
  23. from openi.context import prepare, upload_openi
  24. from openi.context.helper import obs_copy_file, obs_copy_folder
  25. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  26. parser.add_argument(
  27. '--device_target',
  28. type=str,
  29. default="Ascend",
  30. choices=['Ascend', 'CPU'],
  31. help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
  32. parser.add_argument('--epoch_size',
  33. type=int,
  34. default=5,
  35. help='Training epochs.')
  36. ### continue task parameters
  37. parser.add_argument('--ckpt_load_name',
  38. help='model name to save/load',
  39. default= '')
  40. parser.add_argument('--ckpt_save_name',
  41. help='model name to save/load',
  42. default= 'checkpoint')
  43. if __name__ == "__main__":
  44. args, unknown = parser.parse_known_args()
  45. ###Initialize and copy data to training image
  46. openi_context = prepare()
  47. data_dir = openi_context.dataset_path
  48. pretrain_model_dir = openi_context.pretrain_model_path
  49. train_dir = openi_context.output_path
  50. device_num = int(os.getenv('RANK_SIZE'))
  51. ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
  52. if ds_train.get_dataset_size() == 0:
  53. raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
  54. network = LeNet5(cfg.num_classes)
  55. net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  56. net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
  57. time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
  58. ### 继续训练模型加载
  59. if args.ckpt_load_name:
  60. obs_copy_folder(args.train_url, base_path)
  61. load_path = "{}/{}.ckpt".format(base_path,args.ckpt_load_name)
  62. param_dict = load_checkpoint(load_path)
  63. load_param_into_net(network, param_dict)
  64. print("Successfully load ckpt file:{}, saved_net_work:{}".format(load_path,param_dict))
  65. ### 保存已有模型名避免重复回传结果
  66. outputFiles = os.listdir(base_path)
  67. if args.device_target != "Ascend":
  68. model = Model(network,
  69. net_loss,
  70. net_opt,
  71. metrics={"accuracy": Accuracy()})
  72. else:
  73. model = Model(network,
  74. net_loss,
  75. net_opt,
  76. metrics={"accuracy": Accuracy()},
  77. amp_level="O2")
  78. config_ck = CheckpointConfig(
  79. save_checkpoint_steps=cfg.save_checkpoint_steps,
  80. keep_checkpoint_max=cfg.keep_checkpoint_max)
  81. #Note that this method saves the model file on each card. You need to specify the save path on each card.
  82. # In this example, get_rank() is added to distinguish different paths.
  83. if device_num == 1:
  84. save_path = base_path + "/"
  85. if device_num > 1:
  86. save_path = base_path + "/" + str(get_rank()) + "/"
  87. ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_save_name,
  88. directory=save_path,
  89. config=config_ck)
  90. print("============== Starting Training ==============")
  91. epoch_size = cfg['epoch_size']
  92. if (args.epoch_size):
  93. epoch_size = args.epoch_size
  94. print('epoch_size is: ', epoch_size)
  95. model.train(epoch_size,
  96. ds_train,
  97. callbacks=[time_cb, ckpoint_cb,
  98. LossMonitor()])
  99. ### 将训练容器中的新输出模型 回传到启智社区
  100. outputFilesNew = os.listdir(base_path)
  101. new_models = [i for i in outputFilesNew if i not in outputFiles]
  102. for n in new_models:
  103. ckpt_url = base_path + "/" + n
  104. obs_ckpt_url = args.train_url + "/" + n
  105. obs_copy_file(ckpt_url, obs_ckpt_url)

No Description