
train.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. Train centerface and get network model files(.ckpt)
  17. """
import os
import time
import argparse
import datetime
import numpy as np

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.optim.adam import Adam
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.optim.sgd import SGD
from mindspore import Tensor
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.profiler.profiling import Profiler

from src.utils import get_logger
from src.utils import AverageMeter
from src.lr_scheduler import warmup_step_lr
from src.lr_scheduler import warmup_cosine_annealing_lr, \
    warmup_cosine_annealing_lr_v2, warmup_cosine_annealing_lr_sample
from src.lr_scheduler import MultiStepLR
from src.var_init import default_recurisive_init
from src.centerface import CenterfaceMobilev2
from src.utils import load_backbone, get_param_groups
from src.config import ConfigCenterface
from src.centerface import CenterFaceWithLossCell, TrainingWrapper
from src.dataset import GetDataLoader

dev_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
                    device_target="Ascend", save_graphs=False, device_id=dev_id,
                    reserve_class_name_in_scope=False)

parser = argparse.ArgumentParser('mindspore coco training')

# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--annot_path', type=str, default='', help='train data annotation path')
parser.add_argument('--img_dir', type=str, default='', help='train data img dir')
parser.add_argument('--per_batch_size', default=8, type=int, help='batch size per device')

# network related
parser.add_argument('--pretrained_backbone', default='', type=str,
                    help='model_path, local pretrained backbone model to load')
parser.add_argument('--resume', default='', type=str, help='path of pretrained centerface_model')

# optimizer and lr related
parser.add_argument('--lr_scheduler', default='multistep', type=str,
                    help='lr scheduler, options: multistep, exponential, cosine_annealing, '
                         'cosine_annealing_V2, cosine_annealing_sample')
parser.add_argument('--lr', default=4e-3, type=float, help='learning rate of the training')
parser.add_argument('--lr_epochs', type=str, default='90,120', help='epochs at which lr changes')
parser.add_argument('--lr_gamma', type=float, default=0.1,
                    help='decrease lr by this factor in the exponential lr_scheduler')
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
parser.add_argument('--t_max', type=int, default=140, help='T-max in cosine_annealing scheduler')
parser.add_argument('--max_epoch', type=int, default=140, help='max epoch num to train the model')
parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epochs')
parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--optimizer', default='adam', type=str,
                    help='optimizer type, default: adam')

# loss related
parser.add_argument('--loss_scale', type=int, default=1024, help='static loss scale')
parser.add_argument('--label_smooth', type=int, default=0, help='whether to use label smoothing in CE')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='smooth strength of original one-hot')

# logging related
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=None, help='ckpt_interval')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master rank only or on all ranks')

# distributed related
parser.add_argument('--is_distributed', type=int, default=1, help='whether to run on multiple devices')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')

# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')

# profiler init: enable only when debugging; do not enable for normal training,
# since it costs memory and disk space
parser.add_argument('--need_profiler', type=int, default=0, help='whether to use the profiler')

# reset default config
parser.add_argument('--training_shape', type=str, default="", help='fix training shape')
parser.add_argument('--resize_rate', type=int, default=None, help='resize rate for multi-scale training')

args, _ = parser.parse_known_args()

if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
    args.t_max = args.max_epoch

args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
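
# Example launch (illustrative only; the dataset paths and backbone checkpoint name
# below are placeholders, not shipped defaults):
#
#   DEVICE_ID=0 python train.py \
#       --data_dir=/path/to/dataset \
#       --annot_path=/path/to/annotations/train.json \
#       --img_dir=/path/to/images \
#       --pretrained_backbone=/path/to/backbone.ckpt \
#       --is_distributed=0 --per_batch_size=8 --lr=4e-3 --max_epoch=140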


def convert_training_shape(args_):
    """
    Convert training shape
    """
    training_shape = [int(args_.training_shape), int(args_.training_shape)]
    return training_shape


if __name__ == "__main__":
    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    # select whether only the master rank saves ckpt or all ranks do; compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    if args.need_profiler:
        profiler = Profiler(output_path=args.outputs_dir)

    loss_meter = AverageMeter('loss')

    context.reset_auto_parallel_context()
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
        parallel_mode = ParallelMode.STAND_ALONE
        degree = 1

    # context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=degree,
    #                                   parameter_broadcast=True, gradients_mean=True)
    # Notice: parameter_broadcast should be supported, but the current version has bugs, so it is disabled.
    # To make sure the initial weights are the same on all NPUs, a static seed is set in
    # default_recurisive_init during weight initialization.
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)

    network = CenterfaceMobilev2()
    # init: to avoid overflow, the std of some weights should be small enough
    default_recurisive_init(network)

    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Not loading a pre-trained backbone, please be careful')

    if os.path.isfile(args.resume):
        param_dict = load_checkpoint(args.resume)
        param_dict_new = {}
        for key, values in param_dict.items():
            # drop optimizer state and strip the 'centerface_network.' prefix (19 chars)
            # so the weights map onto the bare network
            if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
                continue
            elif key.startswith('centerface_network.'):
                param_dict_new[key[19:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume))
    else:
        args.logger.info('{} not set/exists or not a pre-trained file'.format(args.resume))

    network = CenterFaceWithLossCell(network)
    args.logger.info('finish get network')

    config = ConfigCenterface()
    config.data_dir = args.data_dir
    config.annot_path = args.annot_path
    config.img_dir = args.img_dir
    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    # -------------reset config-----------------
    if args.training_shape:
        config.multi_scale = [convert_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    # data loader
    data_loader, args.steps_per_epoch = GetDataLoader(per_batch_size=args.per_batch_size,
                                                      max_epoch=args.max_epoch,
                                                      rank=args.rank,
                                                      group_size=args.group_size,
                                                      config=config,
                                                      split='train')
    # the loader reports the step count for all epochs combined, so reduce it to steps per epoch
    args.steps_per_epoch = args.steps_per_epoch // args.max_epoch
    args.logger.info('Finish loading dataset')

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch

    # lr scheduler
    if args.lr_scheduler == 'multistep':
        lr_fun = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma, args.steps_per_epoch, args.max_epoch,
                             args.warmup_epochs)
        lr = lr_fun.get_lr()
    elif args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma)
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.t_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_v2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.t_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.t_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    if args.optimizer == "adam":
        opt = Adam(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
        args.logger.info("use adam optimizer")
    elif args.optimizer == "sgd":
        opt = SGD(params=get_param_groups(network),
                  learning_rate=Tensor(lr),
                  momentum=args.momentum,
                  weight_decay=args.weight_decay,
                  loss_scale=args.loss_scale)
    else:
        opt = Momentum(params=get_param_groups(network),
                       learning_rate=Tensor(lr),
                       momentum=args.momentum,
                       weight_decay=args.weight_decay,
                       loss_scale=args.loss_scale)

    network = TrainingWrapper(network, opt, sens=args.loss_scale)
    network.set_train()

    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    args.logger.info('args.steps_per_epoch = {} args.ckpt_interval ={}'.format(args.steps_per_epoch,
                                                                               args.ckpt_interval))
    t_end = time.time()

    for i_all, batch_load in enumerate(data_loader):
        i = i_all % args.steps_per_epoch
        epoch = i_all // args.steps_per_epoch + 1
        images, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks = batch_load

        images = Tensor(images)
        hm = Tensor(hm)
        reg_mask = Tensor(reg_mask)
        ind = Tensor(ind)
        wh = Tensor(wh)
        wight_mask = Tensor(wight_mask)
        hm_offset = Tensor(hm_offset)
        hps_mask = Tensor(hps_mask)
        landmarks = Tensor(landmarks)

        loss, overflow, scaling = network(images, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks)
        # Tensor to numpy
        overflow = np.all(overflow.asnumpy())
        loss = loss.asnumpy()
        loss_meter.update(loss)
        args.logger.info('epoch:{}, iter:{}, avg_loss:{}, loss:{}, overflow:{}, loss_scale:{}'.format(
            epoch, i, loss_meter, loss, overflow, scaling.asnumpy()))

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_epoch_num = epoch
            cb_params.cur_step_num = i + 1 + (epoch - 1) * args.steps_per_epoch
            cb_params.batch_num = i + 2 + (epoch - 1) * args.steps_per_epoch
            ckpt_cb.step_end(run_context)

        if (i_all + 1) % args.steps_per_epoch == 0:
            time_used = time.time() - t_end
            fps = args.per_batch_size * args.steps_per_epoch * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], {}, {:.2f} imgs/sec, lr:{}'
                    .format(epoch, loss_meter, fps, lr[i + (epoch - 1) * args.steps_per_epoch])
                )
            t_end = time.time()
            loss_meter.reset()

    if args.need_profiler:
        profiler.analyse()

    args.logger.info('==========end training===============')