
train.py 14 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YoloV3 train."""
import os
import time
import argparse
import datetime

from mindspore import ParallelMode
from mindspore.nn.optim.momentum import Momentum
from mindspore import Tensor
import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
from src.logger import get_logger
from src.util import AverageMeter, load_backbone, get_param_groups
from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \
    warmup_cosine_annealing_lr_V2, warmup_cosine_annealing_lr_sample
from src.yolo_dataset import create_yolo_dataset
from src.initializer import default_recurisive_init
from src.config import ConfigYOLOV3DarkNet53
from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single
from src.util import ShapeRecord
# DEVICE_ID is normally exported by the Ascend launch scripts; default to '0' so a
# bare single-device run does not crash on int(None).
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                    device_target="Ascend", save_graphs=True, device_id=devid)
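

# BuildTrainNetwork is a small forward-plus-loss wrapper cell. Note that it is kept
# as defined here but is not referenced by train() below, which wraps the model with
# YoloWithLossCell and TrainingWrapper instead.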
class BuildTrainNetwork(nn.Cell):
    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, label):
        output = self.network(input_data)
        loss = self.criterion(output, label)
        return loss


def parse_args():
    """Parse train arguments."""
    parser = argparse.ArgumentParser('mindspore coco training')

    # dataset related
    parser.add_argument('--data_dir', type=str, default='', help='train data dir')
    parser.add_argument('--per_batch_size', default=32, type=int, help='batch size per device')

    # network related
    parser.add_argument('--pretrained_backbone', default='', type=str,
                        help='model_path, local pretrained backbone model to load')
    parser.add_argument('--resume_yolov3', default='', type=str, help='path of a pretrained yolov3 to resume from')

    # optimizer and lr related
    parser.add_argument('--lr_scheduler', default='exponential', type=str,
                        help='lr scheduler, options: exponential, cosine_annealing')
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--lr_epochs', type=str, default='220,250', help='epochs at which the lr changes')
    parser.add_argument('--lr_gamma', type=float, default=0.1,
                        help='decrease lr by this factor in the exponential lr_scheduler')
    parser.add_argument('--eta_min', type=float, default=0., help='eta_min in the cosine_annealing scheduler')
    parser.add_argument('--T_max', type=int, default=320, help='T_max in the cosine_annealing scheduler')
    parser.add_argument('--max_epoch', type=int, default=320, help='max number of epochs to train the model')
    parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epochs')
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # loss related
    parser.add_argument('--loss_scale', type=int, default=1024, help='static loss scale')
    parser.add_argument('--label_smooth', type=int, default=0, help='whether to use label smoothing in CE')
    parser.add_argument('--label_smooth_factor', type=float, default=0.1,
                        help='smoothing strength applied to the original one-hot labels')

    # logging related
    parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
    parser.add_argument('--ckpt_interval', type=int, default=None, help='checkpoint save interval in steps')
    parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master rank only or on all ranks')

    # distributed related
    parser.add_argument('--is_distributed', type=int, default=1, help='whether training is multi-device')
    parser.add_argument('--rank', type=int, default=0, help='local rank for distributed training')
    parser.add_argument('--group_size', type=int, default=1, help='world size for distributed training')

    # roma obs
    parser.add_argument('--train_url', type=str, default="", help='train url')

    # profiler init
    parser.add_argument('--need_profiler', type=int, default=0, help='whether to use the profiler')

    # reset default config
    parser.add_argument('--training_shape', type=str, default="", help='fixed training shape')
    parser.add_argument('--resize_rate', type=int, default=None, help='resize rate for multi-scale training')

    args, _ = parser.parse_known_args()
    if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.T_max:
        args.T_max = args.max_epoch

    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
    args.data_root = os.path.join(args.data_dir, 'train2014')
    args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json')

    return args
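

# A typical single-device invocation might look like the following (the dataset path
# is a placeholder, not part of this repository):
#
#   python train.py --data_dir=/path/to/coco2014 --is_distributed=0 \
#       --lr_scheduler=cosine_annealing --T_max=320 --max_epoch=320 --per_batch_size=32
#
# --data_dir is expected to contain train2014/ and annotations/instances_train2014.json.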


def conver_training_shape(args):
    """Expand the scalar --training_shape argument into an [h, w] pair."""
    training_shape = [int(args.training_shape), int(args.training_shape)]
    return training_shape


def train():
    """Train function."""
    args = parse_args()

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    # select whether only the master rank or every rank saves ckpt; compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    if args.need_profiler:
        from mindinsight.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
        parallel_mode = ParallelMode.STAND_ALONE
        degree = 1
    context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=degree)
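
    # With DATA_PARALLEL, mirror_mean=True makes the all-reduced gradients be averaged
    # across devices rather than summed, per the auto-parallel context semantics of
    # this MindSpore version.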

    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)

    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Not loading a pre-trained backbone, please be careful')
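
    # Checkpoint resume: keys under 'moments.' are Momentum optimizer state and are
    # dropped; the 'yolo_network.' prefix (13 characters), presumably added when the
    # checkpoint was saved from the loss-wrapped network, is stripped so the remaining
    # keys match the bare network's parameter names.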
    if args.resume_yolov3:
        param_dict = load_checkpoint(args.resume_yolov3)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('yolo_network.'):
                param_dict_new[key[13:]] = values  # 13 == len('yolo_network.')
                args.logger.info('in resume {}'.format(key))
            else:
                param_dict_new[key] = values
                args.logger.info('in resume {}'.format(key))
        args.logger.info('resume finished')
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov3))

    network = YoloWithLossCell(network)
    args.logger.info('finish get network')

    config = ConfigYOLOV3DarkNet53()
    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
                                        batch_size=args.per_batch_size, max_epoch=args.max_epoch,
                                        device_num=args.group_size, rank=args.rank, config=config)
    args.logger.info('Finish loading dataset')

    args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch
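
    # steps_per_epoch divides the raw dataset size by the per-device batch size and the
    # device count; when --ckpt_interval is unset, checkpoints default to once per epoch.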

    # lr scheduler
    if args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma)
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_V2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.T_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.T_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
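
    # loss_scale enables static loss scaling on the optimizer side: MindSpore optimizers
    # multiply incoming gradients by 1/loss_scale, which assumes the gradients reach the
    # optimizer pre-scaled by the same factor (TrainingWrapper's sens handling in
    # src/yolo.py is the usual place this happens).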

    network = TrainingWrapper(network, opt)
    network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)
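
    # The loop below is hand-written rather than driven by Model.train(), so the
    # checkpoint callback is stepped manually through the internal callback API
    # (_InternalCallbackParam / RunContext / step_end).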

    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator()

    shape_record = ShapeRecord()
    for i, data in enumerate(data_loader):
        images = data["image"]
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
        shape_record.set(input_shape)

        images = Tensor(images)
        annos = data["annotation"]
        if args.group_size == 1:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box(annos, config, input_shape)
        else:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box_single(annos, config, input_shape)

        batch_y_true_0 = Tensor(batch_y_true_0)
        batch_y_true_1 = Tensor(batch_y_true_1)
        batch_y_true_2 = Tensor(batch_y_true_2)
        batch_gt_box0 = Tensor(batch_gt_box0)
        batch_gt_box1 = Tensor(batch_gt_box1)
        batch_gt_box2 = Tensor(batch_gt_box2)

        # reverse the spatial dims before passing them to the loss as a float tensor
        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
                       batch_gt_box2, input_shape)
        loss_meter.update(loss.asnumpy())

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.log_interval == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1

        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break

    args.logger.info('==========end training===============')


if __name__ == "__main__":
    train()