|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """auxiliary functions for train, to print and preload"""
-
- import math
- import logging
- import os
- import sys
- from datetime import datetime
- import numpy as np
-
- from mindspore.train.serialization import load_checkpoint
- import mindspore.nn as nn
-
def load_backbone(net, ckpt_path, args):
    """
    Load pretrained MobileNetV2 backbone weights into the CenterFace network.

    Cell names under the 'base' prefix of `net` are remapped onto the
    'network.backbone…features.<idx>' naming used by the MobileNetV2
    checkpoint, then matching weights are copied in place.

    Args:
        net: CenterFace network whose backbone cells start with prefix 'base'.
        ckpt_path (str): path to the MobileNetV2 checkpoint file.
        args: object with a `.logger` attribute used for progress reporting.

    Returns:
        The same `net`, with backbone parameters loaded in place.

    Raises:
        ValueError: if a backbone cell name has fewer than 4 dot-separated parts.
    """
    param_dict = load_checkpoint(ckpt_path)
    centerface_backbone_prefix = 'base'
    mobilev2_backbone_prefix = 'network.backbone'
    find_param = []
    not_found_param = []

    def replace_names(name, replace_name, replace_idx):
        """Map a CenterFace cell name onto the checkpoint's 'features.<idx>' naming.

        `replace_idx` is bumped each time a new <block>.<sub> pair is seen, so
        consecutive cells of one backbone block share one 'features' index.
        """
        names = name.split('.')
        if len(names) < 4:
            # NOTE: the original raised a bare string, which is a TypeError in
            # Python 3 -- raise a proper exception instead.
            raise ValueError("centerface_backbone_prefix name too short")
        tmp = names[2] + '.' + names[3]
        if replace_name != tmp:
            replace_name = tmp
            replace_idx += 1
        name = name.replace(replace_name, 'features' + '.' + str(replace_idx))
        return name, replace_name, replace_idx

    def copy_params(cell, name, attrs):
        """Copy each `<name>.<attr>` entry of param_dict into `cell.<attr>`,
        recording hits and misses."""
        for attr in attrs:
            key = '{}.{}'.format(name, attr)
            if key in param_dict:
                getattr(cell, attr).set_data(param_dict[key].data)
                find_param.append(key)
            else:
                not_found_param.append(key)

    replace_name = 'need_fp1.0'  # sentinel that never matches a real block pair
    replace_idx = 0
    for name, cell in net.cells_and_names():
        if name.startswith(centerface_backbone_prefix):
            name = name.replace(centerface_backbone_prefix, mobilev2_backbone_prefix)
            if isinstance(cell, (nn.Conv2d, nn.Dense)):
                name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx)
                copy_params(cell, name, ('weight', 'bias'))
            elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
                name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx)
                copy_params(cell, name, ('moving_mean', 'moving_variance', 'gamma', 'beta'))

    args.logger.info('================found_param {}========='.format(len(find_param)))
    args.logger.info(find_param)
    args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
    args.logger.info(not_found_param)
    args.logger.info('=====load {} successfully ====='.format(ckpt_path))

    return net
-
def get_param_groups(network):
    """
    Split trainable parameters into weight-decay and no-weight-decay groups.

    Biases and BatchNorm gamma/beta parameters (recognized by their name
    suffix) are excluded from weight decay, following common practice.

    Args:
        network: object exposing `trainable_params()` returning parameters
            with a `.name` attribute.

    Returns:
        list[dict]: two optimizer param groups -- first the no-decay group
        with `weight_decay=0.0`, then the decay group.
    """
    decay_params = []
    no_decay_params = []
    # str.endswith accepts a tuple, so one test replaces the three
    # duplicated elif branches of the original.
    no_decay_suffixes = ('.bias', '.gamma', '.beta')
    for param in network.trainable_params():
        if param.name.endswith(no_decay_suffixes):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
-
-
class DistributedSampler():
    """
    Yield a per-rank shard of dataset indices for distributed training.

    The index list is padded (by repeating leading indices) so that every
    rank receives exactly `ceil(len(dataset) / group_size)` samples, then
    sharded round-robin: rank r takes positions r, r+group_size, ...

    Args:
        dataset: sized object (supports `len()`).
        rank (int): this process's rank within the group.
        group_size (int): total number of ranks.
        shuffle (bool): permute indices each epoch when True.
        seed (int): base RNG seed; advanced on every `__iter__` call.
    """
    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_length = len(self.dataset)
        # every rank gets the same number of samples; the total is padded up
        self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
        self.total_size = self.num_samples * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            # advance the seed each epoch so every epoch sees a new permutation
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_length).tolist()
        else:
            # bug fix: the original used len(self.dataset.classes), which both
            # assumed a '.classes' attribute and could disagree with
            # dataset_length used everywhere else in this class.
            indices = list(range(self.dataset_length))

        # pad with leading indices so the list splits evenly across ranks
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # round-robin shard for this rank
        indices = indices[self.rank::self.group_size]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
-
-
class AverageMeter():
    """Track the latest value and a running (sum / count) average.

    If `tb_writer` is given, every update is also forwarded to it as a
    scalar under `name`, with an internally incremented step counter.
    """

    def __init__(self, name, fmt=':f', tb_writer=None):
        self.name = name
        self.fmt = fmt
        self.reset()
        self.tb_writer = tb_writer
        self.cur_step = 1

    def reset(self):
        """Zero all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        writer = self.tb_writer
        if writer is not None:
            writer.add_scalar(self.name, self.val, self.cur_step)
            self.cur_step += 1

    def __str__(self):
        # e.g. fmt=':.2f' yields the template '{name}:{avg:.2f}'
        template = '{name}:{avg%s}' % self.fmt
        return template.format(name=self.name, avg=self.avg)
-
class LOGGER(logging.Logger):
    """
    Rank-aware logger for distributed training.

    One rank per node (rank % 8 == 0) logs to stdout; any rank may also
    attach a per-rank log file via `setup_logging_file`.
    """
    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        # bug fix: remember the rank here so important_info() works even when
        # setup_logging_file() was never called (the original only set
        # self.rank there, causing an AttributeError otherwise).
        self.rank = rank
        self.log_fn = None
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, rank=0):
        """
        Attach a FileHandler writing to '<log_dir>/<timestamp>_rank_<rank>.log'.
        """
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        """Log every attribute (name and value) of an argparse.Namespace-like
        object, one per line."""
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            # bug fix: the original logged only the key name; restore the
            # value using lazy %-formatting.
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        """Log `msg` wrapped in an asterisk banner; no-op unless rank == 0."""
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*'*70 + '\n')*line_width
            important_msg += ('*'*line_width + '\n')*2
            important_msg += '*'*line_width + ' '*8 + msg + '\n'
            important_msg += ('*'*line_width + '\n')*2
            important_msg += ('*'*70 + '\n')*line_width
            self.info(important_msg, *args, **kwargs)
-
def get_logger(path, rank):
    """Build the 'centerface' LOGGER for `rank` and attach its per-rank log
    file under `path`."""
    centerface_logger = LOGGER("centerface", rank)
    centerface_logger.setup_logging_file(path, rank)
    return centerface_logger
|