train.py 12 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. Train CenterNet and get network model files(.ckpt)
  17. """

import os
import argparse
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore import log as logger
from mindspore.common import set_seed
from mindspore.profiler import Profiler
from src.dataset import COCOHP
from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell
from src import CenterNetWithoutLossScaleCell
from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR
from src.config import dataset_config, net_config, train_config

_current_dir = os.path.dirname(os.path.realpath(__file__))

parser = argparse.ArgumentParser(description='CenterNet training')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'],
                    help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                    help="Run distributed training, default is false.")
parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"],
                    help="Enable the profiler to parse runtime info, default is false.")
parser.add_argument("--profiler_path", type=str, default=" ", help="The path to save profiling data")
parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
parser.add_argument("--train_steps", type=int, default=-1, help="Training steps, default is -1, "
                    "i.e. run all steps according to epoch number.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Number of devices to use, default is 1.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
                    help="Enable saving checkpoints, default is true.")
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
                    help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
                    help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Number of checkpoints to keep, default is 1.")
parser.add_argument("--mindrecord_dir", type=str, default="", help="MindRecord dataset files directory")
parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind",
                    help="Prefix of MindRecord dataset filename.")
parser.add_argument("--visual_image", type=str, default="false", help="Visualize the ground truth and predicted image")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")

args_opt = parser.parse_args()
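
# Example invocations (the paths below are illustrative placeholders, not shipped with this repo):
#   single device: python train.py --device_target=Ascend --device_id=0 --mindrecord_dir=/path/to/mindrecord
#   distributed:   launch one process per device with --distribute=true --device_num=8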

def _set_parallel_all_reduce_split():
    """set centernet all_reduce fusion split"""
    if net_config.last_level == 5:
        context.set_auto_parallel_context(all_reduce_fusion_config=[16, 56, 96, 136, 175])
    elif net_config.last_level == 6:
        context.set_auto_parallel_context(all_reduce_fusion_config=[18, 59, 100, 141, 182])
    else:
        raise ValueError("The total number of allreduced grads for last_level = {} is unknown, "
                         "please re-split once the true value is known".format(net_config.last_level))
def _get_params_groups(network, optimizer):
    """
    Get param groups
    """
    params = network.trainable_params()
    decay_params = list(filter(lambda x: not optimizer.decay_filter(x), params))
    other_params = list(filter(optimizer.decay_filter, params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]
    return group_params
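
# decay_steps for the PolyDecay schedule comes from args_opt.train_steps, which train()
# resolves to epoch_size * dataset_size before this function is called when --train_steps
# is left at its default of -1.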
def _get_optimizer(network, dataset_size):
    """get optimizer, only support Adam right now."""
    if train_config.optimizer == 'Adam':
        group_params = _get_params_groups(network, train_config.Adam)
        if train_config.lr_schedule == "PolyDecay":
            lr_schedule = CenterNetPolynomialDecayLR(learning_rate=train_config.PolyDecay.learning_rate,
                                                     end_learning_rate=train_config.PolyDecay.end_learning_rate,
                                                     warmup_steps=train_config.PolyDecay.warmup_steps,
                                                     decay_steps=args_opt.train_steps,
                                                     power=train_config.PolyDecay.power)
            optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.PolyDecay.eps, loss_scale=1.0)
        elif train_config.lr_schedule == "MultiDecay":
            multi_epochs = train_config.MultiDecay.multi_epochs
            if not isinstance(multi_epochs, (list, tuple)):
                raise TypeError("multi_epochs must be list or tuple.")
            if not multi_epochs:
                multi_epochs = [args_opt.epoch_size]
            lr_schedule = CenterNetMultiEpochsDecayLR(learning_rate=train_config.MultiDecay.learning_rate,
                                                      warmup_steps=train_config.MultiDecay.warmup_steps,
                                                      multi_epochs=multi_epochs,
                                                      steps_per_epoch=dataset_size,
                                                      factor=train_config.MultiDecay.factor)
            optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.MultiDecay.eps, loss_scale=1.0)
        else:
            raise ValueError("Don't support lr_schedule {}, only support [PolyDecay, MultiDecay]".
                             format(train_config.lr_schedule))
    else:
        raise ValueError("Don't support optimizer {}, only support [Adam]".
                         format(train_config.optimizer))
    return optimizer
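
# With dataset sink mode enabled, model.train() in train() below treats its first argument
# as the number of sink rounds, each running sink_size (= data_sink_steps) steps on device,
# which is why new_repeat_count is epoch_size * dataset_size // data_sink_steps.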
def train():
    """training CenterNet"""
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(save_graphs=False)

    ckpt_save_dir = args_opt.save_checkpoint_path
    rank = 0
    device_num = 1
    num_workers = 8
    if args_opt.device_target == "Ascend":
        context.set_context(enable_auto_mixed_precision=False)
        context.set_context(device_id=args_opt.device_id)
        if args_opt.distribute == "true":
            D.init()
            device_num = args_opt.device_num
            rank = args_opt.device_id % device_num
            ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                              device_num=device_num)
            _set_parallel_all_reduce_split()
    else:
        args_opt.distribute = "false"
        args_opt.need_profiler = "false"
        args_opt.enable_data_sink = "false"

    # Start creating dataset.
    # MindRecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num.
    logger.info("Begin creating dataset for CenterNet")
    coco = COCOHP(dataset_config, run_mode="train", net_opt=net_config,
                  enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir)
    dataset = coco.create_train_dataset(args_opt.mindrecord_dir, args_opt.mindrecord_prefix,
                                        batch_size=train_config.batch_size, device_num=device_num, rank=rank,
                                        num_parallel_workers=num_workers, do_shuffle=args_opt.do_shuffle == 'true')
    dataset_size = dataset.get_dataset_size()
    logger.info("Create dataset done!")

    net_with_loss = CenterNetMultiPoseLossCell(net_config)

    new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * dataset_size
    logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(net_with_loss, dataset_size)
    enable_static_time = args_opt.device_target == "CPU"
    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size, enable_static_time)]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)
    if args_opt.device_target == "Ascend":
        net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                    sens=train_config.loss_scale_value)
    else:
        net_with_grads = CenterNetWithoutLossScaleCell(net_with_loss, optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)
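
# Profiling note: with --need_profiler=true, a Profiler is created before training to
# collect execution data, and profiler.analyse() afterwards parses the collected results
# under --profiler_path.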

if __name__ == '__main__':
    if args_opt.need_profiler == "true":
        profiler = Profiler(output_path=args_opt.profiler_path)
    set_seed(0)
    train()
    if args_opt.need_profiler == "true":
        profiler.analyse()