
train.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Face Recognition train."""
  16. import os
  17. import argparse
  18. import mindspore
  19. from mindspore.nn import Cell
  20. from mindspore import context
  21. from mindspore.context import ParallelMode
  22. from mindspore.communication.management import get_group_size, init, get_rank
  23. from mindspore.nn.optim import Momentum
  24. from mindspore.train.model import Model
  25. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
  26. from mindspore.train.loss_scale_manager import DynamicLossScaleManager
  27. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  28. from src.config import config_base, config_beta
  29. from src.my_logging import get_logger
  30. from src.init_network import init_net
  31. from src.dataset_factory import get_de_dataset
  32. from src.backbone.resnet import get_backbone
  33. from src.metric_factory import get_metric_fc
  34. from src.loss_factory import get_loss
  35. from src.lrsche_factory import warmup_step_list, list_to_gen
  36. from src.callback_factory import ProgressMonitor
  37. mindspore.common.seed.set_seed(1)
  38. devid = int(os.getenv('DEVICE_ID'))
  39. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
  40. device_id=devid, reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)


class DistributedHelper(Cell):
    '''Wrap the backbone with an optional margin FC head.'''
    def __init__(self, backbone, margin_fc):
        super(DistributedHelper, self).__init__()
        self.backbone = backbone
        self.margin_fc = margin_fc
        if margin_fc is not None:
            self.has_margin_fc = 1
        else:
            self.has_margin_fc = 0

    def construct(self, x, label):
        embeddings = self.backbone(x)
        if self.has_margin_fc == 1:
            return embeddings, self.margin_fc(embeddings, label)
        return embeddings


class BuildTrainNetwork(Cell):
    '''Combine the wrapped network with its loss criterion.'''
    def __init__(self, network, criterion, args_1):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion
        self.args = args_1
        if int(args_1.model_parallel) == 0:
            self.is_model_parallel = 0
        else:
            self.is_model_parallel = 1

    def construct(self, input_data, label):
        if self.is_model_parallel == 0:
            _, output = self.network(input_data, label)
            loss = self.criterion(output, label)
        else:
            _ = self.network(input_data, label)
            loss = self.criterion(None, label)
        return loss


def parse_args():
    parser = argparse.ArgumentParser('MindSpore Face Recognition')
    parser.add_argument('--train_stage', type=str, default='base', help='train stage, base or beta')
    parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
    args_opt_1, _ = parser.parse_known_args()
    return args_opt_1


if __name__ == "__main__":
    args_opt = parse_args()
    support_train_stage = ['base', 'beta']
    if args_opt.train_stage.lower() not in support_train_stage:
        # The config (and its logger) has not been selected yet, so raise directly.
        raise ValueError('unsupported train stage: {}, supported stages are: {}'.format(
            args_opt.train_stage, support_train_stage))
    args = config_base if args_opt.train_stage.lower() == 'base' else config_beta
    args.is_distributed = args_opt.is_distributed

    if args_opt.is_distributed:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()
        parallel_mode = ParallelMode.HYBRID_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=args.world_size, gradients_mean=True)
    if not os.path.exists(args.data_dir):
        # The logger is not created yet at this point, so raise directly.
        raise ValueError('ERROR, data_dir does not exist, please set data_dir in config.py')
    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
    log_path = os.path.join(args.ckpt_path, 'logs')
    args.logger = get_logger(log_path, args.local_rank)

    if args.local_rank % 8 == 0:
        if not os.path.exists(args.ckpt_path):
            os.makedirs(args.ckpt_path)

    args.logger.info('args.world_size:{}'.format(args.world_size))
    args.logger.info('args.local_rank:{}'.format(args.local_rank))
    args.logger.info('args.lr:{}'.format(args.lr))

    momentum = args.momentum
    weight_decay = args.weight_decay

    de_dataset, steps_per_epoch, num_classes = get_de_dataset(args)
    args.logger.info('de_dataset:{}'.format(de_dataset.get_dataset_size()))
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes
    args.logger.info('loaded, nums: {}'.format(args.num_classes))
    # Optionally pad the class count up to a multiple of 16 (or of world_size * 16 under model parallelism).
    if args.nc_16 == 1:
        if args.model_parallel == 0:
            if args.num_classes % 16 == 0:
                args.logger.info('data parallel already 16, nums: {}'.format(args.num_classes))
            else:
                args.num_classes = (args.num_classes // 16 + 1) * 16
        else:
            if args.num_classes % (args.world_size * 16) == 0:
                args.logger.info('model parallel already 16, nums: {}'.format(args.num_classes))
            else:
                args.num_classes = (args.num_classes // (args.world_size * 16) + 1) * args.world_size * 16
    args.logger.info('for D, loaded, class nums: {}'.format(args.num_classes))
    args.logger.info('steps_per_epoch:{}'.format(args.steps_per_epoch))
    args.logger.info('img_total_num:{}'.format(args.steps_per_epoch * args.per_batch_size))
    args.logger.info('get_backbone----in----')
    _backbone = get_backbone(args)
    args.logger.info('get_backbone----out----')
    args.logger.info('get_metric_fc----in----')
    margin_fc_1 = get_metric_fc(args)
    args.logger.info('get_metric_fc----out----')
    args.logger.info('DistributedHelper----in----')
    network_1 = DistributedHelper(_backbone, margin_fc_1)
    args.logger.info('DistributedHelper----out----')

    args.logger.info('network fp16----in----')
    if args.fp16 == 1:
        network_1.add_flags_recursive(fp16=True)
    args.logger.info('network fp16----out----')

    criterion_1 = get_loss(args)
    if args.fp16 == 1 and args.model_parallel == 0:
        criterion_1.add_flags_recursive(fp32=True)
    # Load a pretrained checkpoint if one is given; otherwise initialize the network from scratch.
    # Optimizer moments are dropped and the 'network.' prefix is stripped from parameter names;
    # for the beta stage, some layer-specific parameters are also excluded.
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        if args_opt.train_stage.lower() == 'base':
            for key, value in param_dict.items():
                if key.startswith('moments.'):
                    continue
                elif key.startswith('network.'):
                    param_dict_new[key[8:]] = value
        else:
            for key, value in param_dict.items():
                if key.startswith('moments.'):
                    continue
                elif key.startswith('network.'):
                    if 'layers.' in key and 'bn1' in key:
                        continue
                    elif 'se' in key:
                        continue
                    elif 'head' in key:
                        continue
                    elif 'margin_fc.weight' in key:
                        continue
                    else:
                        param_dict_new[key[8:]] = value
        load_param_into_net(network_1, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))
    else:
        init_net(args, network_1)
    train_net = BuildTrainNetwork(network_1, criterion_1, args)
    args.logger.info('args:{}'.format(args))

    # warmup_step_list must be called after args.steps_per_epoch has been set.
    args.lrs = warmup_step_list(args, gamma=0.1)
    lrs_gen = list_to_gen(args.lrs)
    opt = Momentum(params=train_net.trainable_params(), learning_rate=lrs_gen, momentum=momentum,
                   weight_decay=weight_decay)
    scale_manager = DynamicLossScaleManager(init_loss_scale=args.dynamic_init_loss_scale, scale_factor=2,
                                            scale_window=2000)
    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=scale_manager)
    save_checkpoint_steps = args.ckpt_steps
    args.logger.info('save_checkpoint_steps:{}'.format(save_checkpoint_steps))
    if args.max_ckpts == -1:
        keep_checkpoint_max = int(args.steps_per_epoch * args.max_epoch / save_checkpoint_steps) + 5  # keep a few extra checkpoints
    else:
        keep_checkpoint_max = args.max_ckpts
    args.logger.info('keep_checkpoint_max:{}'.format(keep_checkpoint_max))
    ckpt_config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps, keep_checkpoint_max=keep_checkpoint_max)
    max_epoch_train = args.max_epoch
    args.logger.info('max_epoch_train:{}'.format(max_epoch_train))
    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.ckpt_path, prefix='{}'.format(args.local_rank))
    args.epoch_cnt = 0
    progress_cb = ProgressMonitor(args)

    # With data sinking, each call-level "epoch" covers log_interval steps, so rescale the epoch count.
    new_epoch_train = max_epoch_train * steps_per_epoch // args.log_interval
    model.train(new_epoch_train, de_dataset, callbacks=[progress_cb, ckpt_cb], sink_size=args.log_interval)