
utils.py 9.7 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auxiliary functions for training: logging, progress printing and backbone preloading."""
import math
import logging
import os
import sys
from datetime import datetime

import numpy as np

from mindspore.train.serialization import load_checkpoint
import mindspore.nn as nn


def load_backbone(net, ckpt_path, args):
    """Load a pretrained MobileNetV2 backbone checkpoint into the CenterFace network."""
    param_dict = load_checkpoint(ckpt_path)
    centerface_backbone_prefix = 'base'
    mobilev2_backbone_prefix = 'network.backbone'
    find_param = []
    not_found_param = []

    def replace_names(name, replace_name, replace_idx):
        # Map a CenterFace cell name onto the checkpoint's 'features.<idx>' naming;
        # the index advances whenever the block identifier (names[2].names[3]) changes.
        names = name.split('.')
        if len(names) < 4:
            raise ValueError('centerface_backbone_prefix name too short')
        tmp = names[2] + '.' + names[3]
        if replace_name != tmp:
            replace_name = tmp
            replace_idx += 1
        name = name.replace(replace_name, 'features' + '.' + str(replace_idx))
        return name, replace_name, replace_idx

    replace_name = 'need_fp1.0'
    replace_idx = 0
    for name, cell in net.cells_and_names():
        if name.startswith(centerface_backbone_prefix):
            name = name.replace(centerface_backbone_prefix, mobilev2_backbone_prefix)
            if isinstance(cell, (nn.Conv2d, nn.Dense)):
                name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx)
                mobilev2_weight = '{}.weight'.format(name)
                mobilev2_bias = '{}.bias'.format(name)
                if mobilev2_weight in param_dict:
                    cell.weight.set_data(param_dict[mobilev2_weight].data)
                    find_param.append(mobilev2_weight)
                else:
                    not_found_param.append(mobilev2_weight)
                if mobilev2_bias in param_dict:
                    cell.bias.set_data(param_dict[mobilev2_bias].data)
                    find_param.append(mobilev2_bias)
                else:
                    not_found_param.append(mobilev2_bias)
            elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
                name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx)
                mobilev2_moving_mean = '{}.moving_mean'.format(name)
                mobilev2_moving_variance = '{}.moving_variance'.format(name)
                mobilev2_gamma = '{}.gamma'.format(name)
                mobilev2_beta = '{}.beta'.format(name)
                if mobilev2_moving_mean in param_dict:
                    cell.moving_mean.set_data(param_dict[mobilev2_moving_mean].data)
                    find_param.append(mobilev2_moving_mean)
                else:
                    not_found_param.append(mobilev2_moving_mean)
                if mobilev2_moving_variance in param_dict:
                    cell.moving_variance.set_data(param_dict[mobilev2_moving_variance].data)
                    find_param.append(mobilev2_moving_variance)
                else:
                    not_found_param.append(mobilev2_moving_variance)
                if mobilev2_gamma in param_dict:
                    cell.gamma.set_data(param_dict[mobilev2_gamma].data)
                    find_param.append(mobilev2_gamma)
                else:
                    not_found_param.append(mobilev2_gamma)
                if mobilev2_beta in param_dict:
                    cell.beta.set_data(param_dict[mobilev2_beta].data)
                    find_param.append(mobilev2_beta)
                else:
                    not_found_param.append(mobilev2_beta)
    args.logger.info('================found_param {}========='.format(len(find_param)))
    args.logger.info(find_param)
    args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
    args.logger.info(not_found_param)
    args.logger.info('=====load {} successfully ====='.format(ckpt_path))
    return net
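
# Example usage (illustrative): `args` only needs a `.logger` attribute here,
# so a minimal driver could use argparse.Namespace; the checkpoint path below
# is a placeholder, not a file shipped with this repo.
#
#     from argparse import Namespace
#     args = Namespace(logger=get_logger('./outputs', rank=0))
#     net = load_backbone(net, 'pretrained_mobilenet_v2.ckpt', args)
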
def get_param_groups(network):
    """Split trainable params into no-weight-decay (bias/BN) and weight-decay groups."""
    decay_params = []
    no_decay_params = []
    for x in network.trainable_params():
        parameter_name = x.name
        if parameter_name.endswith('.bias'):
            # biases do not use weight decay
            no_decay_params.append(x)
        elif parameter_name.endswith('.gamma'):
            # BN gamma does not use weight decay; the match is by name, not by cell type
            no_decay_params.append(x)
        elif parameter_name.endswith('.beta'):
            # BN beta does not use weight decay; the match is by name, not by cell type
            no_decay_params.append(x)
        else:
            decay_params.append(x)
    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
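
# Example usage (illustrative): the returned groups plug directly into a
# MindSpore optimizer, where the per-group 'weight_decay' overrides the global
# value; the hyper-parameter values below are placeholders.
#
#     opt = nn.Momentum(params=get_param_groups(net), learning_rate=0.01,
#                       momentum=0.9, weight_decay=0.0001)
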
class DistributedSampler():
    """Sampler that partitions a dataset across `group_size` distributed ranks."""
    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_length = len(self.dataset)
        self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
        self.total_size = self.num_samples * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            # advance the seed each epoch so every epoch gets a fresh permutation
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_length).tolist()
        else:
            # note: the non-shuffle path assumes the dataset exposes a `.classes` attribute
            indices = list(range(len(self.dataset.classes)))
        # pad so every rank draws the same number of samples, then take this rank's slice
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size
        indices = indices[self.rank::self.group_size]
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples
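
# Example usage (illustrative): a sampler like this is typically handed to
# mindspore.dataset.GeneratorDataset; the dataset object and column names
# below are placeholders.
#
#     import mindspore.dataset as ds
#     sampler = DistributedSampler(train_dataset, rank, group_size, shuffle=True)
#     loader = ds.GeneratorDataset(train_dataset, ["image", "anns"], sampler=sampler)
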
class AverageMeter():
    """Computes and stores the average and current value."""
    def __init__(self, name, fmt=':f', tb_writer=None):
        self.name = name
        self.fmt = fmt
        self.reset()
        self.tb_writer = tb_writer
        self.cur_step = 1

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
        self.cur_step += 1

    def __str__(self):
        fmtstr = '{name}:{avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)
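
# Example usage (illustrative): track a running loss over a training epoch;
# `loss_value` and `batch_size` are placeholders supplied by the training loop.
#
#     loss_meter = AverageMeter('loss', fmt=':.4f')
#     loss_meter.update(loss_value, n=batch_size)   # once per step
#     print(loss_meter)                             # e.g. "loss:0.3127"
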
class LOGGER(logging.Logger):
    """Logger that prints to stdout on selected ranks and to a per-rank log file."""
    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        self.rank = rank
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, rank=0):
        """Attach a file handler writing to a timestamped, per-rank log file."""
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)
    def save_args(self, args):
        """Log every command-line argument as 'name: value'."""
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')
    def important_info(self, msg, *args, **kwargs):
        """Log `msg` inside a banner of asterisks (rank 0 only)."""
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*' * 70 + '\n') * line_width
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += ('*' * 70 + '\n') * line_width
            self.info(important_msg, *args, **kwargs)

def get_logger(path, rank):
    """Create the CenterFace logger and attach its per-rank log file."""
    logger = LOGGER("centerface", rank)
    logger.setup_logging_file(path, rank)
    return logger
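
# Example usage (illustrative): typical wiring in a training script; the
# output directory is a placeholder.
#
#     logger = get_logger('./outputs', rank=0)
#     logger.save_args(args)
#     logger.important_info('start training')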