You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

train.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YoloV3-Darknet53-Quant train."""
  16. import os
  17. import time
  18. import argparse
  19. import datetime
  20. from mindspore.context import ParallelMode
  21. from mindspore.nn.optim.momentum import Momentum
  22. from mindspore import Tensor
  23. from mindspore import context
  24. from mindspore.communication.management import init, get_rank, get_group_size
  25. from mindspore.train.callback import ModelCheckpoint, RunContext
  26. from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
  27. import mindspore as ms
  28. from mindspore.compression.quant import QuantizationAwareTraining
  29. from mindspore.common import set_seed
  30. from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
  31. from src.logger import get_logger
  32. from src.util import AverageMeter, get_param_groups
  33. from src.lr_scheduler import get_lr
  34. from src.yolo_dataset import create_yolo_dataset
  35. from src.initializer import default_recurisive_init, load_yolov3_quant_params
  36. from src.config import ConfigYOLOV3DarkNet53
  37. from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single
  38. from src.util import ShapeRecord
  39. set_seed(1)
  40. devid = int(os.getenv('DEVICE_ID'))
  41. context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
  42. device_target="Ascend", save_graphs=True, device_id=devid)
  43. def parse_args():
  44. """Parse train arguments."""
  45. parser = argparse.ArgumentParser('mindspore coco training')
  46. # dataset related
  47. parser.add_argument('--data_dir', type=str, default='', help='Train data dir. Default: ""')
  48. parser.add_argument('--per_batch_size', default=16, type=int, help='Batch size for per device. Default: 16')
  49. # network related
  50. parser.add_argument('--resume_yolov3', default='', type=str,\
  51. help='The ckpt file of yolov3-darknet53, which used to yolov3-darknet53 quant. Default: ""')
  52. # optimizer and lr related
  53. parser.add_argument('--lr_scheduler', default='exponential', type=str,\
  54. help='Learning rate scheduler, option type: exponential, '
  55. 'cosine_annealing. Default: exponential')
  56. parser.add_argument('--lr', default=0.012, type=float, help='Learning rate of the training')
  57. parser.add_argument('--lr_epochs', type=str, default='92,105',\
  58. help='Epoch of lr changing. Default: 92,105')
  59. parser.add_argument('--lr_gamma', type=float, default=0.1,\
  60. help='Decrease lr by a factor of exponential lr_scheduler. Default: 0.1')
  61. parser.add_argument('--eta_min', type=float, default=0.,\
  62. help='Eta_min in cosine_annealing scheduler. Default: 0.')
  63. parser.add_argument('--T_max', type=int, default=135,\
  64. help='T-max in cosine_annealing scheduler. Default: 135')
  65. parser.add_argument('--max_epoch', type=int, default=135,\
  66. help='Max epoch num to train the model. Default: 135')
  67. parser.add_argument('--warmup_epochs', type=float, default=0, help='Warmup epochs. Default: 0')
  68. parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay. Default: 0.0005')
  69. parser.add_argument('--momentum', type=float, default=0.9, help='Momentum. Default: 0.9')
  70. # loss related
  71. parser.add_argument('--loss_scale', type=int, default=1024, help='Static loss scale. Default: 1024')
  72. parser.add_argument('--label_smooth', type=int, default=0, help='Whether to use label smooth in CE. Default: 0')
  73. parser.add_argument('--label_smooth_factor', type=float, default=0.1,\
  74. help='Smooth strength of original one-hot. Default: 0.1')
  75. # logging related
  76. parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
  77. parser.add_argument('--ckpt_path', type=str, default='outputs/',\
  78. help='Checkpoint save location. Default: "outputs/"')
  79. parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None')
  80. parser.add_argument('--is_save_on_master', type=int, default=1,\
  81. help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 1')
  82. # distributed related
  83. parser.add_argument('--is_distributed', type=int, default=0,\
  84. help='Distribute train or not, 1 for yes, 0 for no. Default: 0')
  85. parser.add_argument('--rank', type=int, default=0, help='Local rank of distributed, Default: 0')
  86. parser.add_argument('--group_size', type=int, default=1, help='World size of device, Default: 1')
  87. # profiler init
  88. parser.add_argument('--need_profiler', type=int, default=0,\
  89. help='Whether use profiler, 1 for yes, 0 for no, Default: 0')
  90. # reset default config
  91. parser.add_argument('--training_shape', type=str, default="", help='Fix training shape. Default: ""')
  92. parser.add_argument('--resize_rate', type=int, default=None,\
  93. help='Resize rate for multi-scale training. Default: None')
  94. args, _ = parser.parse_known_args()
  95. if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.T_max:
  96. args.T_max = args.max_epoch
  97. args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
  98. args.data_root = os.path.join(args.data_dir, 'train2014')
  99. args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json')
  100. # init distributed
  101. if args.is_distributed:
  102. init()
  103. args.rank = get_rank()
  104. args.group_size = get_group_size()
  105. # select for master rank save ckpt or all rank save, compatiable for model parallel
  106. args.rank_save_ckpt_flag = 0
  107. if args.is_save_on_master:
  108. if args.rank == 0:
  109. args.rank_save_ckpt_flag = 1
  110. else:
  111. args.rank_save_ckpt_flag = 1
  112. # logger
  113. args.outputs_dir = os.path.join(args.ckpt_path,
  114. datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
  115. args.logger = get_logger(args.outputs_dir, args.rank)
  116. return args
  117. def conver_training_shape(args):
  118. training_shape = [int(args.training_shape), int(args.training_shape)]
  119. return training_shape
def train():
    """Train function.

    End-to-end training loop: parses args, builds the (optionally quantized)
    YOLOv3-Darknet53 network, creates the COCO dataset, then runs a manual
    step loop that feeds preprocessed ground-truth tensors to the network and
    drives checkpointing/logging by hand via the internal callback machinery.
    """
    args = parse_args()
    args.logger.save_args(args)
    # Profiler is imported lazily so non-profiled runs avoid the dependency.
    if args.need_profiler:
        from mindspore.profiler.profiling import Profiler
        profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
    loss_meter = AverageMeter('loss')
    # Parallel context: data-parallel with gradient averaging when distributed,
    # stand-alone otherwise.
    context.reset_auto_parallel_context()
    parallel_mode = ParallelMode.STAND_ALONE
    degree = 1
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
    network = YOLOV3DarkNet53(is_training=True)
    # default is kaiming-normal
    default_recurisive_init(network)
    # Load the float pretrained weights to be quantization-aware fine-tuned.
    load_yolov3_quant_params(args, network)
    config = ConfigYOLOV3DarkNet53()
    # convert fusion network to quantization aware network
    if config.quantization_aware:
        quantizer = QuantizationAwareTraining(bn_fold=True,
                                              per_channel=[True, False],
                                              symmetric=[True, False])
        network = quantizer.quantize(network)
    network = YoloWithLossCell(network)
    args.logger.info('finish get network')
    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor
    # Optional overrides: fixed training shape and multi-scale resize rate.
    if args.training_shape:
        config.multi_scale = [conver_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate
    ds, data_size = create_yolo_dataset(image_dir=args.data_root, anno_path=args.annFile, is_training=True,
                                        batch_size=args.per_batch_size, max_epoch=args.max_epoch,
                                        device_num=args.group_size, rank=args.rank, config=config)
    args.logger.info('Finish loading dataset')
    args.steps_per_epoch = int(data_size / args.per_batch_size / args.group_size)
    # Default: checkpoint once per epoch.
    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch
    lr = get_lr(args)
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
    network = TrainingWrapper(network, opt)
    network.set_train()
    if args.rank_save_ckpt_flag:
        # checkpoint save: drive ModelCheckpoint manually through the internal
        # callback parameters, since the step loop below bypasses Model.train().
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)
    old_progress = -1
    t_end = time.time()
    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
    shape_record = ShapeRecord()
    for i, data in enumerate(data_loader):
        images = data["image"]
        # NCHW layout assumed: [2:4] takes the spatial (H, W) dims — TODO confirm.
        input_shape = images.shape[2:4]
        args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
        shape_record.set(input_shape)
        images = Tensor.from_numpy(images)
        annos = data["annotation"]
        # NOTE(review): single-device uses batch_preprocess_true_box while the
        # multi-device path uses the *_single variant — verify this is intended.
        if args.group_size == 1:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box(annos, config, input_shape)
        else:
            batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
                batch_preprocess_true_box_single(annos, config, input_shape)
        batch_y_true_0 = Tensor.from_numpy(batch_y_true_0)
        batch_y_true_1 = Tensor.from_numpy(batch_y_true_1)
        batch_y_true_2 = Tensor.from_numpy(batch_y_true_2)
        batch_gt_box0 = Tensor.from_numpy(batch_gt_box0)
        batch_gt_box1 = Tensor.from_numpy(batch_gt_box1)
        batch_gt_box2 = Tensor.from_numpy(batch_gt_box2)
        # Reverse (H, W) -> (W, H) before handing the shape to the network.
        input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
        loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
                       batch_gt_box2, input_shape)
        loss_meter.update(loss.asnumpy())
        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)
        if i % args.log_interval == 0:
            # Throughput over the window since the last log line.
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i
        if (i + 1) % args.steps_per_epoch == 0 and args.rank_save_ckpt_flag:
            cb_params.cur_epoch_num += 1
        # Profiled runs stop early after a handful of steps.
        if args.need_profiler:
            if i == 10:
                profiler.analyse()
                break
    args.logger.info('==========end training===============')
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    train()