# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition train."""
import os
import time
import argparse
import datetime
import warnings
import random

import numpy as np

import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.nn.optim import SGD
from mindspore.nn import TrainOneStepCell
from mindspore.communication.management import get_group_size, init, get_rank

from src.dataset import get_de_dataset
from src.config import reid_1p_cfg_ascend, reid_1p_cfg, reid_8p_cfg_ascend, reid_8p_cfg_gpu
from src.lr_generator import step_lr
from src.log import get_logger, AverageMeter
from src.reid import SphereNet, CombineMarginFCFp16, BuildTrainNetworkWithHead, CombineMarginFC
from src.loss import CrossEntropy

warnings.filterwarnings('ignore')
random.seed(1)
np.random.seed(1)


def init_argument():
    """Parse command-line arguments and initialize the training configuration."""
    parser = argparse.ArgumentParser(description='Face Recognition For Tracking')
    parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
                        help='device_target')
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--data_dir', type=str, default='', help='image folders')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    args = parser.parse_args()

    graph_path = os.path.join('./graphs_graphmode', datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                        save_graphs=True, save_graphs_path=graph_path)
    if args.device_target == 'Ascend':
        devid = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=devid)

    if args.is_distributed == 0:
        if args.device_target == 'Ascend':
            cfg = reid_1p_cfg_ascend
        else:
            cfg = reid_1p_cfg
    else:
        if args.device_target == 'Ascend':
            cfg = reid_8p_cfg_ascend
        else:
            cfg = reid_8p_cfg_gpu

    cfg.pretrained = args.pretrained
    cfg.data_dir = args.data_dir

    # Init distributed
    if args.is_distributed:
        init()
        cfg.local_rank = get_rank()
        cfg.world_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    # parallel_mode 'STAND_ALONE' does not support parameter_broadcast and mirror_mean
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size, gradients_mean=True)

    mindspore.common.set_seed(1)

    # logger
    cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)

    # Show cfg
    cfg.logger.save_args(cfg)
    return cfg, args


def main():
    """Face Recognition for Tracking training entry point."""
    cfg, args = init_argument()
    loss_meter = AverageMeter('loss')

    # dataloader
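    # get_de_dataset returns the dataset, its steps-per-epoch and the raw class
    # count; the class count is rounded up to the next multiple of 16 below
    # (presumably to keep the fp16 classification head's shapes aligned).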
    cfg.logger.info('start create dataloader')
    de_dataset, steps_per_epoch, class_num = get_de_dataset(cfg)
    cfg.steps_per_epoch = steps_per_epoch
    cfg.logger.info('step per epoch: %s', cfg.steps_per_epoch)
    de_dataloader = de_dataset.create_tuple_iterator()
    cfg.logger.info('class num original: %s', class_num)
    if class_num % 16 != 0:
        class_num = (class_num // 16 + 1) * 16
    cfg.class_num = class_num
    cfg.logger.info('change the class num to: %s', cfg.class_num)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()

    network = SphereNet(num_layers=cfg.net_depth, feature_dim=cfg.embedding_size, shape=cfg.input_size)
    if args.device_target == 'CPU':
        head = CombineMarginFC(embbeding_size=cfg.embedding_size, classnum=cfg.class_num)
    else:
        head = CombineMarginFCFp16(embbeding_size=cfg.embedding_size, classnum=cfg.class_num)
    criterion = CrossEntropy()

    # load the pretrained model
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        cfg.logger.info('load model %s success', cfg.pretrained)

    # mixed precision training
    if args.device_target == 'CPU':
        network.add_flags_recursive(fp32=True)
        head.add_flags_recursive(fp32=True)
    else:
        network.add_flags_recursive(fp16=True)
        head.add_flags_recursive(fp16=True)
    criterion.add_flags_recursive(fp32=True)

    train_net = BuildTrainNetworkWithHead(network, head, criterion)

    # optimizer and lr scheduler
    lr = step_lr(lr=cfg.lr, epoch_size=cfg.epoch_size, steps_per_epoch=cfg.steps_per_epoch,
                 max_epoch=cfg.max_epoch, gamma=cfg.lr_gamma)
    opt = SGD(params=train_net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
              weight_decay=cfg.weight_decay, loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale)

    # checkpoint save
    if cfg.local_rank == 0:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(save_checkpoint_steps=cfg.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=cfg.outputs_dir, prefix='{}'.format(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    cfg.logger.important_info('====start train====')
    for i, total_data in enumerate(de_dataloader):
        data, gt = total_data
        data = Tensor(data)
        gt = Tensor(gt)

        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
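        # The first iteration also covers graph compilation, so its duration is
        # reported separately; steady-state throughput (imgs/sec) is averaged
        # over the steps since the last log and scaled by the device count.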
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * (i - old_progress) * cfg.world_size / time_used
            cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr={}'.format(epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info('=================================================')
            cfg.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            cfg.logger.info('=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')


if __name__ == "__main__":
    main()