
train.py 7.7 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Quality Assessment train."""
import os
import time
import datetime
import argparse
import warnings

import numpy as np

import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Momentum
from mindspore.communication.management import get_group_size, init, get_rank

from src.loss import CriterionsFaceQA
from src.config import faceqa_1p_cfg, faceqa_8p_cfg
from src.face_qa import FaceQABackbone, BuildTrainNetwork
from src.lr_generator import warmup_step
from src.dataset import faceqa_dataset
from src.log import get_logger, AverageMeter

warnings.filterwarnings('ignore')
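
# Bind this process to the Ascend device whose id comes from the DEVICE_ID environment
# variable (expected to be exported by the launch script before train.py runs).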
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
mindspore.common.seed.set_seed(1)


def main(args):
    if args.is_distributed == 0:
        cfg = faceqa_1p_cfg
    else:
        cfg = faceqa_8p_cfg

    cfg.data_lst = args.train_label_file
    cfg.pretrained = args.pretrained

    # Init distributed
    if args.is_distributed:
        init()
        cfg.local_rank = get_rank()
        cfg.world_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    # parallel_mode 'STAND_ALONE' does not support parameter_broadcast and mirror_mean
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size,
                                      gradients_mean=True)

    mindspore.common.set_seed(1)

    # logger
    cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
    loss_meter = AverageMeter('loss')

    # Dataloader
    cfg.logger.info('start create dataloader')
    de_dataset = faceqa_dataset(imlist=cfg.data_lst, local_rank=cfg.local_rank, world_size=cfg.world_size,
                                per_batch_size=cfg.per_batch_size)
    cfg.steps_per_epoch = de_dataset.get_dataset_size()
    de_dataset = de_dataset.repeat(cfg.max_epoch)
    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
    # Show cfg
    cfg.logger.save_args(cfg)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()

    network = FaceQABackbone()
    criterion = CriterionsFaceQA()

    # load pretrained model
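    # (a checkpoint saved through the training wrapper stores backbone weights under a
    # 'network.' prefix and optimizer state under 'moments.'; the prefix is stripped and
    # the optimizer state is skipped so the weights map onto the bare backbone)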
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        cfg.logger.info('load model %s success.', cfg.pretrained)

    # optimizer and lr scheduler
    lr = warmup_step(cfg, gamma=0.9)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=cfg.momentum,
                   weight_decay=cfg.weight_decay,
                   loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
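    # (sens feeds the static loss scale into backpropagation; the Momentum optimizer above
    # divides the same loss_scale back out, so updates keep their original magnitude)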
    train_net = BuildTrainNetwork(network, criterion)
    train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale)

    # checkpoint save
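    # (the checkpoint callback is driven by hand because the script runs its own training
    # loop instead of Model.train, so RunContext and step_end are invoked manually below)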
    if cfg.local_rank == 0:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(save_checkpoint_steps=cfg.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=cfg.outputs_dir, prefix='{}'.format(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    cfg.logger.important_info('====start train====')
    for i, (data, gt) in enumerate(de_dataloader):
        # clean grad + adjust lr + put data into device + forward + backward + optimizer, return loss
        data = data.astype(np.float32)
        gt = gt.astype(np.float32)
        data = Tensor(data)
        gt = Tensor(gt)

        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
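        # (the first step reports graph compile time; later steps report loss and
        # throughput averaged across all devices since the previous log step)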
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * (i - old_progress) * cfg.world_size / time_used
            cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))

            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info('=================================================')
            cfg.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            cfg.logger.info('=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Face Quality Assessment')
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--train_label_file', type=str, default='', help='image label list file, e.g. /home/label.txt')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')

    arg = parser.parse_args()
    main(arg)
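
A typical single-device launch (paths below are placeholders) would look like: DEVICE_ID=0 python train.py --is_distributed=0 --train_label_file=/path/to/label.txt --pretrained=/path/to/pretrained.ckpt. The script reads DEVICE_ID when it starts and accepts only the three arguments defined above.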