train.py 12 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition train."""
import os
import time

import mindspore
from mindspore.nn import Cell
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.my_logging import get_logger
from src.init_network import init_net
from src.dataset_factory import get_de_dataset
from src.backbone.resnet import get_backbone
from src.metric_factory import get_metric_fc
from src.loss_factory import get_loss
from src.lrsche_factory import warmup_step_list, list_to_gen
from src.callback_factory import ProgressMonitor

from utils.moxing_adapter import moxing_wrapper
from utils.config import config
from utils.device_adapter import get_device_id, get_device_num, get_rank_id

mindspore.common.seed.set_seed(1)

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                    device_id=get_device_id(), reserve_class_name_in_scope=False,
                    enable_auto_mixed_precision=False)


class DistributedHelper(Cell):
    '''Run the backbone and, when present, the margin FC head on top of the embeddings.'''

    def __init__(self, backbone, margin_fc):
        super(DistributedHelper, self).__init__()
        self.backbone = backbone
        self.margin_fc = margin_fc
        if margin_fc is not None:
            self.has_margin_fc = 1
        else:
            self.has_margin_fc = 0

    def construct(self, x, label):
        embeddings = self.backbone(x)
        if self.has_margin_fc == 1:
            return embeddings, self.margin_fc(embeddings, label)
        return embeddings


class BuildTrainNetwork(Cell):
    '''Combine the network with the criterion so the forward pass returns the training loss.'''

    def __init__(self, network, criterion, args_1):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion
        self.args = args_1
        if int(args_1.model_parallel) == 0:
            self.is_model_parallel = 0
        else:
            self.is_model_parallel = 1

    def construct(self, input_data, label):
        if self.is_model_parallel == 0:
            _, output = self.network(input_data, label)
            loss = self.criterion(output, label)
        else:
            _ = self.network(input_data, label)
            loss = self.criterion(None, label)
        return loss


def load_pretrain(cfg, net):
    '''Load pretrained parameters into the network, filtering keys according to the train stage.'''
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
        param_dict_new = {}
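        # Checkpoints written by this training script prefix network weights with 'network.'
        # and optimizer state with 'moments.'; key[8:] below strips the 8-character
        # 'network.' prefix so the keys match the bare backbone/head parameter names.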
        if cfg.train_stage.lower() == 'base':
            for key, value in param_dict.items():
                if key.startswith('moments.'):
                    continue
                elif key.startswith('network.'):
                    param_dict_new[key[8:]] = value
        else:
            for key, value in param_dict.items():
                if key.startswith('moments.'):
                    continue
                elif key.startswith('network.'):
                    if 'layers.' in key and 'bn1' in key:
                        continue
                    elif 'se' in key:
                        continue
                    elif 'head' in key:
                        continue
                    elif 'margin_fc.weight' in key:
                        continue
                    else:
                        param_dict_new[key[8:]] = value
        load_param_into_net(net, param_dict_new)
        cfg.logger.info('load model {} success'.format(cfg.pretrained))
    else:
        if cfg.train_stage.lower() == 'beta':
            raise ValueError("Train stage 'beta' failed to load the pretrained model from: {}".format(cfg.pretrained))
        init_net(cfg, net)
        cfg.logger.info('init model success')
    return net


def modelarts_pre_process():
    '''ModelArts pre-process: unzip the dataset once per server and redirect the checkpoint path.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, "face_recognition_dataset")):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                i = 0
                # Report progress roughly every 1% of files; max() guards against archives
                # with fewer than 100 entries, where data_num // 100 would be zero.
                percent_step = max(data_num // 100, 1)
                for file in fz.namelist():
                    if i % percent_step == 0:
                        print("unzip percent: {}%".format(i // percent_step), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not a zip file.")
        else:
            print("Zip file has already been extracted.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, "face_recognition_dataset.zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains at most 8 devices; only the first device on a server unzips.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        # The remaining devices wait until the lock file appears.
        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))

    config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.ckpt_path)


# moxing_wrapper (from the repo's utils) runs modelarts_pre_process before training starts.
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
    '''Run the full training pipeline.'''
    config.local_rank = get_rank_id()
    config.world_size = get_device_num()

    log_path = os.path.join(config.ckpt_path, 'logs')
    config.logger = get_logger(log_path, config.local_rank)

    support_train_stage = ['base', 'beta']
    if config.train_stage.lower() not in support_train_stage:
        config.logger.info('train stage is not supported.')
        raise ValueError('train stage is not supported.')
    if not os.path.exists(config.data_dir):
        config.logger.info('ERROR, data_dir does not exist, please set data_dir in config.py')
        raise ValueError('ERROR, data_dir does not exist, please set data_dir in config.py')

    parallel_mode = ParallelMode.HYBRID_PARALLEL if config.is_distributed else ParallelMode.STAND_ALONE
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=config.world_size, gradients_mean=True)
    if config.is_distributed:
        init()

    if config.local_rank % 8 == 0:
        if not os.path.exists(config.ckpt_path):
            os.makedirs(config.ckpt_path)

    de_dataset, steps_per_epoch, num_classes = get_de_dataset(config)
    config.logger.info('de_dataset: %d', de_dataset.get_dataset_size())

    config.steps_per_epoch = steps_per_epoch
    config.num_classes = num_classes
    config.lr_epochs = list(map(int, config.lr_epochs.split(',')))
    config.logger.info('config.num_classes: %d', config.num_classes)
    config.logger.info('config.world_size: %d', config.world_size)
    config.logger.info('config.local_rank: %d', config.local_rank)
    config.logger.info('config.lr: %f', config.lr)
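
    # The block below rounds config.num_classes up, presumably to keep the class dimension
    # 16-aligned: a multiple of 16 under data parallel, or of world_size * 16 under model
    # parallel so every device holds an equal slice. For example (illustrative numbers only),
    # 10572 classes on 8 devices with model parallel become ceil(10572 / 128) * 128 = 10624.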
    if config.nc_16 == 1:
        if config.model_parallel == 0:
            if config.num_classes % 16 == 0:
                config.logger.info('data parallel already aligned to 16, nums: %d', config.num_classes)
            else:
                config.num_classes = (config.num_classes // 16 + 1) * 16
        else:
            if config.num_classes % (config.world_size * 16) == 0:
                config.logger.info('model parallel already aligned to 16, nums: %d', config.num_classes)
            else:
                config.num_classes = (config.num_classes // (config.world_size * 16) + 1) * config.world_size * 16

    config.logger.info('for D, loaded, class nums: %d', config.num_classes)
    config.logger.info('steps_per_epoch: %d', config.steps_per_epoch)
    config.logger.info('img_total_num: %d', config.steps_per_epoch * config.per_batch_size)

    config.logger.info('get_backbone----in----')
    _backbone = get_backbone(config)
    config.logger.info('get_backbone----out----')
    config.logger.info('get_metric_fc----in----')
    margin_fc_1 = get_metric_fc(config)
    config.logger.info('get_metric_fc----out----')
    config.logger.info('DistributedHelper----in----')
    network_1 = DistributedHelper(_backbone, margin_fc_1)
    config.logger.info('DistributedHelper----out----')
    config.logger.info('network fp16----in----')
    if config.fp16 == 1:
        network_1.add_flags_recursive(fp16=True)
    config.logger.info('network fp16----out----')

    criterion_1 = get_loss(config)
    if config.fp16 == 1 and config.model_parallel == 0:
        criterion_1.add_flags_recursive(fp32=True)

    network_1 = load_pretrain(config, network_1)
    train_net = BuildTrainNetwork(network_1, criterion_1, config)

    # warmup_step_list must be called after config.steps_per_epoch has been set.
    config.lrs = warmup_step_list(config, gamma=0.1)
    lrs_gen = list_to_gen(config.lrs)
    opt = Momentum(params=train_net.trainable_params(), learning_rate=lrs_gen, momentum=config.momentum,
                   weight_decay=config.weight_decay)
    scale_manager = DynamicLossScaleManager(init_loss_scale=config.dynamic_init_loss_scale, scale_factor=2,
                                            scale_window=2000)
    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=scale_manager)

    save_checkpoint_steps = config.ckpt_steps
    config.logger.info('save_checkpoint_steps: %d', save_checkpoint_steps)
    if config.max_ckpts == -1:
        keep_checkpoint_max = int(config.steps_per_epoch * config.max_epoch / save_checkpoint_steps) + 5
    else:
        keep_checkpoint_max = config.max_ckpts
    config.logger.info('keep_checkpoint_max: %d', keep_checkpoint_max)
    ckpt_config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                   keep_checkpoint_max=keep_checkpoint_max)
    config.logger.info('max_epoch_train: %d', config.max_epoch)
    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=config.ckpt_path,
                              prefix='{}'.format(config.local_rank))
    config.epoch_cnt = 0
    progress_cb = ProgressMonitor(config)
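
    # Each "epoch" passed to model.train below corresponds to sink_size (= log_interval) steps
    # under dataset sinking, so the count is max_epoch * steps_per_epoch // log_interval sink
    # cycles rather than real dataset epochs.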
    new_epoch_train = config.max_epoch * steps_per_epoch // config.log_interval
    model.train(new_epoch_train, de_dataset, callbacks=[progress_cb, ckpt_cb], sink_size=config.log_interval)


if __name__ == "__main__":
    run_train()