
train.py 16 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. Train centerface and get network model files(.ckpt)
  17. """
  18. import os
  19. import time
  20. import argparse
  21. import datetime
  22. import numpy as np
  23. from mindspore import context
  24. from mindspore.context import ParallelMode
  25. from mindspore.nn.optim.adam import Adam
  26. from mindspore.nn.optim.momentum import Momentum
  27. from mindspore.nn.optim.sgd import SGD
  28. from mindspore import Tensor
  29. from mindspore.communication.management import init, get_rank, get_group_size
  30. from mindspore.train.callback import ModelCheckpoint, RunContext
  31. from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
  32. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  33. from mindspore.profiler.profiling import Profiler
  34. from mindspore.common import set_seed
  35. from src.utils import get_logger
  36. from src.utils import AverageMeter
  37. from src.lr_scheduler import warmup_step_lr
  38. from src.lr_scheduler import warmup_cosine_annealing_lr, \
  39. warmup_cosine_annealing_lr_v2, warmup_cosine_annealing_lr_sample
  40. from src.lr_scheduler import MultiStepLR
  41. from src.var_init import default_recurisive_init
  42. from src.centerface import CenterfaceMobilev2
  43. from src.utils import load_backbone, get_param_groups
  44. from src.config import ConfigCenterface
  45. from src.centerface import CenterFaceWithLossCell, TrainingWrapper
  46. from src.dataset import GetDataLoader
  47. set_seed(1)
  48. dev_id = int(os.getenv('DEVICE_ID'))
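# graph mode on Ascend; the target device index comes from the DEVICE_ID environment variable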
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
                    device_target="Ascend", save_graphs=False, device_id=dev_id,
                    reserve_class_name_in_scope=False)
parser = argparse.ArgumentParser('mindspore coco training')

# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--annot_path', type=str, default='', help='train data annotation path')
parser.add_argument('--img_dir', type=str, default='', help='train data img dir')
parser.add_argument('--per_batch_size', default=8, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained_backbone', default='', type=str, help='model_path, local pretrained backbone'
                    ' model to load')
parser.add_argument('--resume', default='', type=str, help='path of pretrained centerface_model')
# optimizer and lr related
parser.add_argument('--lr_scheduler', default='multistep', type=str,
                    help='lr-scheduler, option type: exponential, cosine_annealing')
parser.add_argument('--lr', default=4e-3, type=float, help='learning rate of the training')
parser.add_argument('--lr_epochs', type=str, default='90,120', help='epoch of lr changing')
parser.add_argument('--lr_gamma', type=float, default=0.1,
                    help='decrease lr by a factor of exponential lr_scheduler')
parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
parser.add_argument('--t_max', type=int, default=140, help='T-max in cosine_annealing scheduler')
parser.add_argument('--max_epoch', type=int, default=140, help='max epoch num to train the model')
parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch')
parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--optimizer', default='adam', type=str,
                    help='optimizer type, default: adam')
# loss related
parser.add_argument('--loss_scale', type=int, default=1024, help='static loss scale')
parser.add_argument('--label_smooth', type=int, default=0, help='whether to use label smooth in CE')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='smooth strength of original one-hot')
# logging related
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_interval', type=int, default=None, help='ckpt_interval')
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
# distributed related
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
# profiler init; can be enabled for debugging. Do not enable it for normal training,
# since it costs extra memory and disk space.
parser.add_argument('--need_profiler', type=int, default=0, help='whether use profiler')
# reset default config
parser.add_argument('--training_shape', type=str, default="", help='fix training shape')
parser.add_argument('--resize_rate', type=int, default=None, help='resize rate for multi-scale training')

args, _ = parser.parse_known_args()

if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
    args.t_max = args.max_epoch

args.lr_epochs = list(map(int, args.lr_epochs.split(',')))


def convert_training_shape(args_):
    """
    Convert training shape
    """
    training_shape = [int(args_.training_shape), int(args_.training_shape)]
    return training_shape
if __name__ == "__main__":
    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    # decide whether only the master rank or every rank saves checkpoints; compatible with model parallel
    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    # logger
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    if args.need_profiler:
        profiler = Profiler(output_path=args.outputs_dir)

    loss_meter = AverageMeter('loss')
    context.reset_auto_parallel_context()
    if args.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        degree = get_group_size()
    else:
        parallel_mode = ParallelMode.STAND_ALONE
        degree = 1

    # context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=degree, parameter_broadcast=True, gradients_mean=True)
    # Notice: parameter_broadcast should be supported, but the current version has bugs, so it is disabled here.
    # To make sure the initial weights are identical on all NPUs, a static seed is set in
    # default_recurisive_init during weight initialization.
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)

    network = CenterfaceMobilev2()
    # init; to avoid overflow, the std of some weights should be sufficiently small
    default_recurisive_init(network)
    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Not load pre-trained backbone, please be careful')

    if os.path.isfile(args.resume):
        param_dict = load_checkpoint(args.resume)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
                continue
            elif key.startswith('centerface_network.'):
                param_dict_new[key[19:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume))
    else:
        args.logger.info('{} not set/exists or not a pre-trained file'.format(args.resume))

    network = CenterFaceWithLossCell(network)
    args.logger.info('finish get network')
    config = ConfigCenterface()
    config.data_dir = args.data_dir
    config.annot_path = args.annot_path
    config.img_dir = args.img_dir
    config.label_smooth = args.label_smooth
    config.label_smooth_factor = args.label_smooth_factor

    # -------------reset config-----------------
    if args.training_shape:
        config.multi_scale = [convert_training_shape(args)]
    if args.resize_rate:
        config.resize_rate = args.resize_rate

    # data loader
    data_loader, args.steps_per_epoch = GetDataLoader(per_batch_size=args.per_batch_size,
                                                      max_epoch=args.max_epoch,
                                                      rank=args.rank,
                                                      group_size=args.group_size,
                                                      config=config,
                                                      split='train')
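    # GetDataLoader appears to report the total step count over all epochs,
    # so dividing by max_epoch below yields the per-epoch step count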
    args.steps_per_epoch = args.steps_per_epoch // args.max_epoch
    args.logger.info('Finish loading dataset')

    if not args.ckpt_interval:
        args.ckpt_interval = args.steps_per_epoch
    # lr scheduler
    if args.lr_scheduler == 'multistep':
        lr_fun = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma, args.steps_per_epoch, args.max_epoch,
                             args.warmup_epochs)
        lr = lr_fun.get_lr()
    elif args.lr_scheduler == 'exponential':
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma)
    elif args.lr_scheduler == 'cosine_annealing':
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.t_max,
                                        args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_V2':
        lr = warmup_cosine_annealing_lr_v2(args.lr,
                                           args.steps_per_epoch,
                                           args.warmup_epochs,
                                           args.max_epoch,
                                           args.t_max,
                                           args.eta_min)
    elif args.lr_scheduler == 'cosine_annealing_sample':
        lr = warmup_cosine_annealing_lr_sample(args.lr,
                                               args.steps_per_epoch,
                                               args.warmup_epochs,
                                               args.max_epoch,
                                               args.t_max,
                                               args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)
    if args.optimizer == "adam":
        opt = Adam(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
        args.logger.info("use adam optimizer")
    elif args.optimizer == "sgd":
        opt = SGD(params=get_param_groups(network),
                  learning_rate=Tensor(lr),
                  momentum=args.momentum,
                  weight_decay=args.weight_decay,
                  loss_scale=args.loss_scale)
    else:
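        # any other --optimizer value falls back to SGD with momentum (Momentum)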
        opt = Momentum(params=get_param_groups(network),
                       learning_rate=Tensor(lr),
                       momentum=args.momentum,
                       weight_decay=args.weight_decay,
                       loss_scale=args.loss_scale)

    network = TrainingWrapper(network, opt, sens=args.loss_scale)
    network.set_train()
    if args.rank_save_ckpt_flag:
        # checkpoint save
        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                       keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=args.outputs_dir,
                                  prefix='{}'.format(args.rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    args.logger.info('args.steps_per_epoch = {} args.ckpt_interval ={}'.format(args.steps_per_epoch,
                                                                               args.ckpt_interval))
    t_end = time.time()
    for i_all, batch_load in enumerate(data_loader):
        i = i_all % args.steps_per_epoch
        epoch = i_all // args.steps_per_epoch + 1
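        # each batch carries the images plus the CenterFace training targets
        # (center heatmap, regression mask, indices, width/height targets, weight mask,
        #  center offsets, keypoint mask and facial landmarks)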
        images, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks = batch_load
        images = Tensor(images)
        hm = Tensor(hm)
        reg_mask = Tensor(reg_mask)
        ind = Tensor(ind)
        wh = Tensor(wh)
        wight_mask = Tensor(wight_mask)
        hm_offset = Tensor(hm_offset)
        hps_mask = Tensor(hps_mask)
        landmarks = Tensor(landmarks)
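        # one training step through TrainingWrapper: returns the loss value,
        # the overflow flag and the current loss scale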
        loss, overflow, scaling = network(images, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks)
        # Tensor to numpy
        overflow = np.all(overflow.asnumpy())
        loss = loss.asnumpy()
        loss_meter.update(loss)
        args.logger.info('epoch:{}, iter:{}, avg_loss:{}, loss:{}, overflow:{}, loss_scale:{}'.format(
            epoch, i, loss_meter, loss, overflow, scaling.asnumpy()))

        if args.rank_save_ckpt_flag:
            # ckpt progress
            cb_params.cur_epoch_num = epoch
            cb_params.cur_step_num = i + 1 + (epoch - 1) * args.steps_per_epoch
            cb_params.batch_num = i + 2 + (epoch - 1) * args.steps_per_epoch
            ckpt_cb.step_end(run_context)

        if (i_all + 1) % args.steps_per_epoch == 0:
            time_used = time.time() - t_end
            fps = args.per_batch_size * args.steps_per_epoch * args.group_size / time_used
            if args.rank == 0:
                args.logger.info(
                    'epoch[{}], {}, {:.2f} imgs/sec, lr:{}'
                    .format(epoch, loss_meter, fps, lr[i + (epoch - 1) * args.steps_per_epoch])
                )
            t_end = time.time()
            loss_meter.reset()

    if args.need_profiler:
        profiler.analyse()

    args.logger.info('==========end training===============')