
train.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute train."""
import os
import time
import datetime
import argparse

import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.communication.management import get_group_size, init, get_rank
from mindspore.nn import TrainOneStepCell
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

from src.FaceAttribute.resnet18 import get_resnet18
from src.FaceAttribute.loss_factory import get_loss
from src.dataset_train import data_generator
from src.lrsche_factory import warmup_step
from src.logging import get_logger, AverageMeter
from src.config import config
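
# Bind this process to the Ascend device named by the DEVICE_ID environment
# variable and compile the network as a static graph.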
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
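

# Wrapping the backbone and the criterion in a single Cell lets
# TrainOneStepCell differentiate the whole forward pass (three attribute
# heads plus the loss) end to end.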
class BuildTrainNetwork(nn.Cell):
    '''Build train network.'''
    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion
        self.print = P.Print()

    def construct(self, input_data, label):
        logit0, logit1, logit2 = self.network(input_data)
        loss = self.criterion(logit0, logit1, logit2, label)
        return loss
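

# parse_known_args() tolerates extra flags that a distributed launch script
# may append, instead of failing on them.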
def parse_args():
    '''Argument for Face Attributes.'''
    parser = argparse.ArgumentParser('Face Attributes')
    parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
    parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
    args, _ = parser.parse_known_args()
    return args

def train():
    '''train function.'''
    # logger
    args = parse_args()

    # init distributed
    if args.world_size != 1:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()
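
    # Copy the remaining hyper-parameters from src/config.py onto args so the
    # rest of the pipeline reads a single namespace.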
    args.per_batch_size = config.per_batch_size
    args.dst_h = config.dst_h
    args.dst_w = config.dst_w
    args.workers = config.workers
    args.attri_num = config.attri_num
    args.classes = config.classes
    args.backbone = config.backbone
    args.loss_scale = config.loss_scale
    args.flat_dim = config.flat_dim
    args.fc_dim = config.fc_dim
    args.lr = config.lr
    args.lr_scale = config.lr_scale
    args.lr_epochs = config.lr_epochs
    args.weight_decay = config.weight_decay
    args.momentum = config.momentum
    args.max_epoch = config.max_epoch
    args.warmup_epochs = config.warmup_epochs
    args.log_interval = config.log_interval
    args.ckpt_path = config.ckpt_path
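
    # Standalone runs use a larger per-device batch; distributed runs keep the
    # configured batch and scale the base learning rate up instead.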
    if args.world_size == 1:
        args.per_batch_size = 256
    else:
        args.lr = args.lr * 4.
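
    # DATA_PARALLEL replicates the network on every device; gradients_mean=True
    # averages the all-reduced gradients instead of summing them.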
    if args.world_size != 1:
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)
    # model and log save path
    args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.local_rank)
    loss_meter = AverageMeter('loss')

    # dataloader
    args.logger.info('start create dataloader')
    de_dataloader, steps_per_epoch, num_classes = data_generator(args)
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes
    args.logger.info('end create dataloader')
    args.logger.save_args(args)
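
    # The ResNet-18 backbone exposes three classification heads, one per face
    # attribute; BuildTrainNetwork.construct consumes all three logits.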
    # backbone and loss
    args.logger.important_info('start create network')
    create_network_start = time.time()
    network = get_resnet18(args)
    criterion = get_loss()
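
    # Checkpoints saved from a wrapped train network prefix parameter names
    # with 'network.' and store optimizer state under 'moments.'; strip the
    # prefix and skip the optimizer state so the names match the bare backbone.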
    # load pretrain model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))
    # optimizer and lr scheduler
    lr = warmup_step(args, gamma=0.1)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    train_net = BuildTrainNetwork(network, criterion)
    # mixed precision training: run the loss computation in fp32 for stability
    criterion.add_flags_recursive(fp32=True)
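
    # Static loss scaling: sens multiplies the loss gradient at the start of
    # backprop, and the optimizer's loss_scale divides the gradients back down
    # before the update.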
    # package training process
    train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
    context.reset_auto_parallel_context()
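
    # Only rank 0 writes checkpoints; because there is no Model.train() loop,
    # the ModelCheckpoint callback is driven by hand through RunContext and
    # _InternalCallbackParam.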
    # checkpoint
    if args.local_rank == 0:
        ckpt_max_num = args.max_epoch
        train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 0
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)
    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    i = 0
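
    # Manual training loop: each call to train_net runs forward, backward, and
    # the optimizer update in one graph launch; loss.asnumpy() synchronizes the
    # result back to the host.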
    for _, (data, gt_classes) in enumerate(de_dataloader):
        data_tensor = Tensor(data, dtype=mstype.float32)
        gt_tensor = Tensor(gt_classes, dtype=mstype.int32)

        loss = train_net(data_tensor, gt_tensor)
        loss_meter.update(loss.asnumpy()[0])
        # save ckpt
        if args.local_rank == 0:
            cb_params.cur_step_num = i + 1
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            cb_params.cur_epoch_num += 1

        # save log
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone, time_for_graph_compile))
        if i % args.log_interval == 0 and args.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
            args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
            args.logger.info('=================================================')
            args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            args.logger.info('=================================================')
            t_epoch = time.time()

        i += 1

    args.logger.info('--------- trains out ---------')

if __name__ == "__main__":
    train()
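
# A minimal launch sketch (hypothetical paths; DEVICE_ID must name a free
# Ascend device, and --world_size 1 selects standalone mode):
#   DEVICE_ID=0 python train.py --mindrecord_path /path/to/data.mindrecord --world_size 1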