train.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YoloV4 train."""
import os
import time
import argparse
import datetime
import ast

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, amp, context
from mindspore.context import ParallelMode
from mindspore.nn.optim.momentum import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
from mindspore.profiler.profiling import Profiler

from src.yolo import YOLOV4CspDarkNet53, YoloWithLossCell, TrainingWrapper
from src.logger import get_logger
from src.util import AverageMeter, get_param_groups, keep_loss_fp32
from src.lr_scheduler import get_lr
from src.yolo_dataset import create_yolo_dataset
from src.initializer import default_recurisive_init, load_yolov4_params
from src.config import ConfigYOLOV4CspDarkNet53

set_seed(1)
parser = argparse.ArgumentParser('mindspore coco training')

# device related
parser.add_argument('--device_target', type=str, default='Ascend',
                    help='Device where the code will be executed. Default: Ascend.')

# dataset related
parser.add_argument('--data_dir', type=str, help='Train dataset directory.')
parser.add_argument('--per_batch_size', default=8, type=int, help='Batch size for training. Default: 8.')

# network related
parser.add_argument('--pretrained_backbone', default='', type=str,
                    help='The ckpt file of CspDarkNet53. Default: "".')
parser.add_argument('--resume_yolov4', default='', type=str,
                    help='The ckpt file of YOLOv4 used for fine-tuning. Default: "".')
parser.add_argument('--pretrained_checkpoint', default='', type=str,
                    help='The ckpt file of YoloV4CspDarkNet53. Default: "".')
parser.add_argument('--filter_weight', type=ast.literal_eval, default=False,
                    help='Filter the last-layer weight parameters. Default: False.')

# optimizer and lr related
parser.add_argument('--lr_scheduler', default='cosine_annealing', type=str,
                    help='Learning rate scheduler, options: exponential, cosine_annealing. '
                         'Default: cosine_annealing.')
parser.add_argument('--lr', default=0.012, type=float, help='Learning rate. Default: 0.012.')
parser.add_argument('--lr_epochs', type=str, default='220,250',
                    help='Epochs at which lr changes, separated by ",". Default: 220,250.')
parser.add_argument('--lr_gamma', type=float, default=0.1,
                    help='Decay factor of the exponential lr_scheduler. Default: 0.1.')
parser.add_argument('--eta_min', type=float, default=0., help='Eta_min in cosine_annealing scheduler. Default: 0.')
parser.add_argument('--t_max', type=int, default=320, help='T-max in cosine_annealing scheduler. Default: 320.')
parser.add_argument('--max_epoch', type=int, default=320, help='Max number of epochs to train. Default: 320.')
parser.add_argument('--warmup_epochs', default=20, type=float, help='Warmup epochs. Default: 20.')
parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay factor. Default: 0.0005.')
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum. Default: 0.9.')

# loss related
parser.add_argument('--loss_scale', type=int, default=64, help='Static loss scale. Default: 64.')
parser.add_argument('--label_smooth', type=int, default=0, help='Whether to use label smoothing in CE. Default: 0.')
parser.add_argument('--label_smooth_factor', type=float, default=0.1,
                    help='Smoothing strength of the original one-hot labels. Default: 0.1.')

# logging related
parser.add_argument('--log_interval', type=int, default=100, help='Logging interval in steps. Default: 100.')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/.')
parser.add_argument('--ckpt_interval', type=int, default=None, help='Checkpoint save interval. Default: None.')
parser.add_argument('--is_save_on_master', type=int, default=1,
                    help='Save ckpt on master rank or all ranks: 1 for master, 0 for all ranks. Default: 1.')

# distributed related
parser.add_argument('--is_distributed', type=int, default=1,
                    help='Whether to run distributed training: 1 for yes, 0 for no. Default: 1.')
parser.add_argument('--rank', type=int, default=0, help='Local rank for distributed training. Default: 0.')
parser.add_argument('--group_size', type=int, default=1, help='World size of devices. Default: 1.')

# profiler init
parser.add_argument('--need_profiler', type=int, default=0,
                    help='Whether to use the profiler: 0 for no, 1 for yes. Default: 0.')

# reset default config
parser.add_argument('--training_shape', type=str, default="", help='Fixed training shape. Default: "".')
parser.add_argument('--resize_rate', type=int, default=10,
                    help='Resize rate for multi-scale training. Default: 10.')

args, _ = parser.parse_known_args()
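
# Example launch (hypothetical paths; a single-device run, other flags at their defaults):
#   python train.py --data_dir=/path/to/coco --pretrained_backbone=cspdarknet53.ckpt \
#       --is_distributed=0 --per_batch_size=8 --max_epoch=320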

if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
    args.t_max = args.max_epoch

args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
args.data_root = os.path.join(args.data_dir, 'train2017')
args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2017.json')

device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                    device_target=args.device_target, save_graphs=False, device_id=device_id)

# output dir (set before the profiler, which writes into it)
args.outputs_dir = os.path.join(args.ckpt_path,
                                datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

if args.need_profiler:
    profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)

# init distributed
if args.is_distributed:
    if args.device_target == "Ascend":
        init()
    else:
        init("nccl")
    args.rank = get_rank()
    args.group_size = get_group_size()

# select whether the master rank or all ranks save ckpt, compatible with model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
    if args.rank == 0:
        args.rank_save_ckpt_flag = 1
else:
    args.rank_save_ckpt_flag = 1

# logger
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
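

# Helper: a fixed --training_shape such as "416" becomes the square shape [416, 416].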
def convert_training_shape(args_training_shape):
    training_shape = [int(args_training_shape), int(args_training_shape)]
    return training_shape


class BuildTrainNetwork(nn.Cell):
    def __init__(self, network_, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network_
        self.criterion = criterion

    def construct(self, input_data, label):
        output = self.network(input_data)
        loss_ = self.criterion(output, label)
        return loss_
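
# BuildTrainNetwork is a generic forward-plus-loss wrapper. A usage sketch (it is
# not wired up below, where YoloWithLossCell bundles the YOLO loss instead):
#   train_net = BuildTrainNetwork(backbone, criterion)
#   loss = train_net(input_data, label)
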
if __name__ == "__main__":
    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    parallel_mode = ParallelMode.STAND_ALONE
    degree = 1
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)

    network = YOLOV4CspDarkNet53(is_training=True)
    # default is kaiming-normal
    config = ConfigYOLOV4CspDarkNet53()
    args.checkpoint_filter_list = config.checkpoint_filter_list
    default_recurisive_init(network)
    load_yolov4_params(args, network)
    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor
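
    # Override the multi-scale defaults when a fixed shape or resize rate is given.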
    if args.training_shape:
        config.multi_scale = [convert_training_shape(args.training_shape)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
                                        batch_size=args.per_batch_size, max_epoch=args.max_epoch,
                                        device_num=args.group_size, rank=args.rank, config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)
    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    lr = get_lr(args)
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
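
    # Passing loss_scale to Momentum lets the optimizer scale the gradients back
    # down after the scaled backward pass; get_param_groups is expected to split
    # the parameters into weight-decay and no-decay groups.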
    is_gpu = context.get_context("device_target") == "GPU"
    if is_gpu:
        loss_scale_value = 1.0
        loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
        network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
                                          level="O2", keep_batchnorm_fp32=False)
        keep_loss_fp32(network)
    else:
        network = TrainingWrapper(network, opt)

    network.set_train()
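
    # No Model.train() here: the script drives ModelCheckpoint by hand, building
    # the internal callback parameters and a RunContext itself.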
    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = 10
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)
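
    # Manual training loop. The dataset was created with max_epoch repeats, so a
    # single pass over this iterator covers the full training schedule.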
    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)

    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
        images = Tensor.from_numpy(images)
        batch_y_true_0 = Tensor.from_numpy(data['bbox1'])
        batch_y_true_1 = Tensor.from_numpy(data['bbox2'])
        batch_y_true_2 = Tensor.from_numpy(data['bbox3'])
        batch_gt_box0 = Tensor.from_numpy(data['gt_box1'])
        batch_gt_box1 = Tensor.from_numpy(data['gt_box2'])
        batch_gt_box2 = Tensor.from_numpy(data['gt_box3'])
        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
                       batch_gt_box2, input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break

    args.logger.info('==========end training===============')