You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train.py 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Face detection train."""
  16. import os
  17. import ast
  18. import time
  19. import datetime
  20. import argparse
  21. import numpy as np
  22. from mindspore import context
  23. from mindspore.train.loss_scale_manager import DynamicLossScaleManager
  24. from mindspore import Tensor
  25. from mindspore.communication.management import init, get_rank, get_group_size
  26. from mindspore.context import ParallelMode
  27. from mindspore.train.callback import ModelCheckpoint, RunContext
  28. from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
  29. from mindspore.common import dtype as mstype
  30. from src.logging import get_logger
  31. from src.data_preprocess import create_dataset
  32. from src.config import config
  33. from src.network_define import define_network
  34. def parse_args():
  35. '''parse_args'''
  36. parser = argparse.ArgumentParser('Yolov3 Face Detection')
  37. parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "CPU"),
  38. help="run platform, support Ascend and CPU.")
  39. parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
  40. parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
  41. parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
  42. parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
  43. parser.add_argument("--use_loss_scale", type=ast.literal_eval, default=True,
  44. help="Whether use dynamic loss scale, default is True.")
  45. args, _ = parser.parse_known_args()
  46. args.batch_size = config.batch_size
  47. args.warmup_lr = config.warmup_lr
  48. args.lr_rates = config.lr_rates
  49. if args.run_platform == "CPU":
  50. args.use_loss_scale = False
  51. args.world_size = 1
  52. args.local_rank = 0
  53. if args.world_size != 8:
  54. args.lr_steps = [i * 8 // args.world_size for i in config.lr_steps]
  55. else:
  56. args.lr_steps = config.lr_steps
  57. args.gamma = config.gamma
  58. args.weight_decay = config.weight_decay if args.world_size != 1 else 0.
  59. args.momentum = config.momentum
  60. args.max_epoch = config.max_epoch
  61. args.log_interval = config.log_interval
  62. args.ckpt_path = config.ckpt_path
  63. args.ckpt_interval = config.ckpt_interval
  64. args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
  65. print('args.outputs_dir', args.outputs_dir)
  66. args.num_classes = config.num_classes
  67. args.anchors = config.anchors
  68. args.anchors_mask = config.anchors_mask
  69. args.num_anchors_list = [len(x) for x in args.anchors_mask]
  70. return args
def train(args):
    '''Run the YOLOv3 face-detection training loop.

    Sets up the device context (and data-parallel communication when
    args.world_size != 1), builds the dataset and network, then iterates the
    dataset once, optionally with a dynamic loss scale, saving checkpoints on
    rank 0 and logging per-interval and per-epoch throughput.

    NOTE(review): relies on attributes not set in this file —
    args.steps_per_epoch (presumably set by create_dataset) and args.lr
    (presumably set by define_network); confirm against src/.
    '''
    print('=============yolov3 start trainging==================')
    # Device id comes from the DEVICE_ID env var on Ascend; CPU always uses 0.
    devid = int(os.getenv('DEVICE_ID', '0')) if args.run_platform != 'CPU' else 0
    context.set_context(mode=context.GRAPH_MODE, device_target=args.run_platform, save_graphs=False, device_id=devid)
    # init distributed
    if args.world_size != 1:
        init()
        # Replace CLI-provided rank/size with the values from the comm layer.
        args.local_rank = get_rank()
        args.world_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, device_num=args.world_size,
                                          gradients_mean=True)
    args.logger = get_logger(args.outputs_dir, args.local_rank)
    # dataloader
    ds = create_dataset(args)
    args.logger.important_info('start create network')
    create_network_start = time.time()
    train_net = define_network(args)
    # checkpoint
    # keep_checkpoint_max is sized so every saved checkpoint is retained.
    ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
    train_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
    ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
    # _InternalCallbackParam is a private MindSpore API used to drive the
    # checkpoint callback manually instead of via Model.train().
    cb_params = _InternalCallbackParam()
    cb_params.train_network = train_net
    cb_params.epoch_num = ckpt_max_num
    cb_params.cur_epoch_num = 1
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)
    train_net.set_train()
    t_end = time.time()    # timer for the log-interval throughput
    t_epoch = time.time()  # timer for the per-epoch throughput
    old_progress = -1      # step index at the previous log point
    i = 0                  # global step counter
    if args.use_loss_scale:
        scale_manager = DynamicLossScaleManager(init_loss_scale=2 ** 10, scale_factor=2, scale_window=2000)
    for data in ds.create_tuple_iterator(output_numpy=True):
        # data[0] = images, data[1] = labels; data[2:26] are the remaining
        # dataset columns fed to the network — presumably per-scale targets,
        # TODO confirm against create_dataset.
        batch_images = data[0]
        batch_labels = data[1]
        input_list = [Tensor(batch_images, mstype.float32)]
        for idx in range(2, 26):
            input_list.append(Tensor(data[idx], mstype.float32))
        if args.use_loss_scale:
            # Feed the current loss scale in, then update the manager based on
            # whether this step overflowed.
            scaling_sens = Tensor(scale_manager.get_loss_scale(), dtype=mstype.float32)
            loss0, overflow, _ = train_net(*input_list, scaling_sens)
            overflow = np.all(overflow.asnumpy())
            if overflow:
                scale_manager.update_loss_scale(overflow)
            else:
                scale_manager.update_loss_scale(False)
            args.logger.info('rank[{}], iter[{}], loss[{}], overflow:{}, loss_scale:{}, lr:{}, batch_images:{}, '
                             'batch_labels:{}'.format(args.local_rank, i, loss0, overflow, scaling_sens, args.lr[i],
                                                      batch_images.shape, batch_labels.shape))
        else:
            loss0 = train_net(*input_list)
            args.logger.info('rank[{}], iter[{}], loss[{}], lr:{}, batch_images:{}, '
                             'batch_labels:{}'.format(args.local_rank, i, loss0, args.lr[i],
                                                      batch_images.shape, batch_labels.shape))
        # save ckpt
        cb_params.cur_step_num = i + 1  # current step number
        cb_params.batch_num = i + 2
        if args.local_rank == 0:
            # Only rank 0 writes checkpoints.
            ckpt_cb.step_end(run_context)
        # save Log
        if i == 0:
            # First step includes graph compilation; report it separately.
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info('Yolov3, graph compile time={:.2f}s'.format(time_for_graph_compile))
        if i % args.steps_per_epoch == 0:
            cb_params.cur_epoch_num += 1
        if i % args.log_interval == 0 and args.local_rank == 0:
            # Throughput since the last log point, aggregated over all ranks.
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.batch_size * (i - old_progress) * args.world_size / time_used
            args.logger.info('epoch[{}], iter[{}], loss:[{}], {:.2f} imgs/sec'.format(epoch, i, loss0, fps))
            t_end = time.time()
            old_progress = i
        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            # Per-epoch throughput summary.
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / args.steps_per_epoch)
            fps = args.batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
            args.logger.info('=================================================')
            args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            args.logger.info('=================================================')
            t_epoch = time.time()
        i = i + 1
    args.logger.info('=============yolov3 training finished==================')
  156. if __name__ == "__main__":
  157. arg = parse_args()
  158. train(arg)