train.py 8.6 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Face attribute train."""
  16. import os
  17. import time
  18. import datetime
  19. import argparse
  20. import mindspore
  21. import mindspore.nn as nn
  22. from mindspore import context
  23. from mindspore import Tensor
  24. from mindspore.nn.optim import Momentum
  25. from mindspore.communication.management import get_group_size, init, get_rank
  26. from mindspore.nn import TrainOneStepCell
  27. from mindspore.context import ParallelMode
  28. from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
  29. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  30. from mindspore.ops import operations as P
  31. from mindspore.common import dtype as mstype
  32. from src.FaceAttribute.resnet18 import get_resnet18
  33. from src.FaceAttribute.loss_factory import get_loss
  34. from src.dataset_train import data_generator
  35. from src.lrsche_factory import warmup_step
  36. from src.logging import get_logger, AverageMeter
  37. from src.config import config
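
# graph-mode execution on Ascend; the launch script is expected to export
# DEVICE_ID for each process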
devid = int(os.getenv('DEVICE_ID', '0'))  # assumption: fall back to device 0 when DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)


class BuildTrainNetwork(nn.Cell):
    '''Build train network.'''
    def __init__(self, my_network, my_criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = my_network
        self.criterion = my_criterion
        self.print = P.Print()

    def construct(self, input_data, label):
        logit0, logit1, logit2 = self.network(input_data)
        loss0 = self.criterion(logit0, logit1, logit2, label)
        return loss0


def parse_args():
    '''Arguments for Face Attributes.'''
    parser = argparse.ArgumentParser('Face Attributes')
    parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
    parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
    arg, _ = parser.parse_known_args()
    return arg


if __name__ == "__main__":
    mindspore.set_seed(1)
    args = parse_args()

    # init distributed
    if args.world_size != 1:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()
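
    # copy network and training hyper-parameters from the config file onto args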
    args.per_batch_size = config.per_batch_size
    args.dst_h = config.dst_h
    args.dst_w = config.dst_w
    args.workers = config.workers
    args.attri_num = config.attri_num
    args.classes = config.classes
    args.backbone = config.backbone
    args.loss_scale = config.loss_scale
    args.flat_dim = config.flat_dim
    args.fc_dim = config.fc_dim
    args.lr = config.lr
    args.lr_scale = config.lr_scale
    args.lr_epochs = config.lr_epochs
    args.weight_decay = config.weight_decay
    args.momentum = config.momentum
    args.max_epoch = config.max_epoch
    args.warmup_epochs = config.warmup_epochs
    args.log_interval = config.log_interval
    args.ckpt_path = config.ckpt_path

    # single device uses a larger batch; distributed runs scale up the base lr
    if args.world_size == 1:
        args.per_batch_size = 256
    else:
        args.lr = args.lr * 4.
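
    # data parallel across devices when distributed, stand-alone otherwise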
    if args.world_size != 1:
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)

    # model and log save path
    args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.local_rank)
    loss_meter = AverageMeter('loss')

    # dataloader
    args.logger.info('start create dataloader')
    de_dataloader, steps_per_epoch, num_classes = data_generator(args)
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes
    args.logger.info('end create dataloader')
    args.logger.save_args(args)

    # backbone and loss
    args.logger.important_info('start create network')
    create_network_start = time.time()
    network = get_resnet18(args)
    criterion = get_loss()

    # load pretrained model, dropping optimizer state ('moments.') and stripping
    # the 'network.' wrapper prefix so the keys match the bare backbone
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))

    # optimizer and lr scheduler
    lr = warmup_step(args, gamma=0.1)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    train_net = BuildTrainNetwork(network, criterion)

    # mixed precision training: keep the loss computation in fp32
    criterion.add_flags_recursive(fp32=True)

    # package forward, loss, and backward into a single training step
    train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
    context.reset_auto_parallel_context()

    # checkpoint (saved by rank 0 only)
    if args.local_rank == 0:
        ckpt_max_num = args.max_epoch
        train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 0
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1
    i = 0
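
    # training loop: one batch per step; checkpointing and logging on rank 0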
    for _, (data, gt_classes) in enumerate(de_dataloader):
        data_tensor = Tensor(data, dtype=mstype.float32)
        gt_tensor = Tensor(gt_classes, dtype=mstype.int32)
        loss = train_net(data_tensor, gt_tensor)
        loss_meter.update(loss.asnumpy()[0])

        # save ckpt
        if args.local_rank == 0:
            cb_params.cur_step_num = i + 1
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)
        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            cb_params.cur_epoch_num += 1

        # save log
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone, time_for_graph_compile))

        if i % args.log_interval == 0 and args.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
            args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
            args.logger.info('=================================================')
            args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            args.logger.info('=================================================')
            t_epoch = time.time()

        i += 1

    args.logger.info('--------- trains out ---------')
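
# Example launch, single device (hypothetical paths; DEVICE_ID is read by
# set_context at the top of this script):
#   DEVICE_ID=0 python train.py --mindrecord_path /home/data.mindrecord --world_size 1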