You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

util.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Util class or function."""
  16. from mindspore.train.serialization import load_checkpoint
  17. import mindspore.nn as nn
  18. class AverageMeter:
  19. """Computes and stores the average and current value"""
  20. def __init__(self, name, fmt=':f', tb_writer=None):
  21. self.name = name
  22. self.fmt = fmt
  23. self.reset()
  24. self.tb_writer = tb_writer
  25. self.cur_step = 1
  26. self.val = 0
  27. self.avg = 0
  28. self.sum = 0
  29. self.count = 0
  30. def reset(self):
  31. self.val = 0
  32. self.avg = 0
  33. self.sum = 0
  34. self.count = 0
  35. def update(self, val, n=1):
  36. self.val = val
  37. self.sum += val * n
  38. self.count += n
  39. self.avg = self.sum / self.count
  40. if self.tb_writer is not None:
  41. self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
  42. self.cur_step += 1
  43. def __str__(self):
  44. fmtstr = '{name}:{avg' + self.fmt + '}'
  45. return fmtstr.format(**self.__dict__)
  46. def load_backbone(net, ckpt_path, args):
  47. """Load darknet53 backbone checkpoint."""
  48. param_dict = load_checkpoint(ckpt_path)
  49. yolo_backbone_prefix = 'feature_map.backbone'
  50. darknet_backbone_prefix = 'network.backbone'
  51. find_param = []
  52. not_found_param = []
  53. for name, cell in net.cells_and_names():
  54. if name.startswith(yolo_backbone_prefix):
  55. name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
  56. if isinstance(cell, (nn.Conv2d, nn.Dense)):
  57. darknet_weight = '{}.weight'.format(name)
  58. darknet_bias = '{}.bias'.format(name)
  59. if darknet_weight in param_dict:
  60. cell.weight.default_input = param_dict[darknet_weight].data
  61. find_param.append(darknet_weight)
  62. else:
  63. not_found_param.append(darknet_weight)
  64. if darknet_bias in param_dict:
  65. cell.bias.default_input = param_dict[darknet_bias].data
  66. find_param.append(darknet_bias)
  67. else:
  68. not_found_param.append(darknet_bias)
  69. elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
  70. darknet_moving_mean = '{}.moving_mean'.format(name)
  71. darknet_moving_variance = '{}.moving_variance'.format(name)
  72. darknet_gamma = '{}.gamma'.format(name)
  73. darknet_beta = '{}.beta'.format(name)
  74. if darknet_moving_mean in param_dict:
  75. cell.moving_mean.default_input = param_dict[darknet_moving_mean].data
  76. find_param.append(darknet_moving_mean)
  77. else:
  78. not_found_param.append(darknet_moving_mean)
  79. if darknet_moving_variance in param_dict:
  80. cell.moving_variance.default_input = param_dict[darknet_moving_variance].data
  81. find_param.append(darknet_moving_variance)
  82. else:
  83. not_found_param.append(darknet_moving_variance)
  84. if darknet_gamma in param_dict:
  85. cell.gamma.default_input = param_dict[darknet_gamma].data
  86. find_param.append(darknet_gamma)
  87. else:
  88. not_found_param.append(darknet_gamma)
  89. if darknet_beta in param_dict:
  90. cell.beta.default_input = param_dict[darknet_beta].data
  91. find_param.append(darknet_beta)
  92. else:
  93. not_found_param.append(darknet_beta)
  94. args.logger.info('================found_param {}========='.format(len(find_param)))
  95. args.logger.info(find_param)
  96. args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
  97. args.logger.info(not_found_param)
  98. args.logger.info('=====load {} successfully ====='.format(ckpt_path))
  99. return net
  100. def default_wd_filter(x):
  101. """default weight decay filter."""
  102. parameter_name = x.name
  103. if parameter_name.endswith('.bias'):
  104. # all bias not using weight decay
  105. return False
  106. if parameter_name.endswith('.gamma'):
  107. # bn weight bias not using weight decay, be carefully for now x not include BN
  108. return False
  109. if parameter_name.endswith('.beta'):
  110. # bn weight bias not using weight decay, be carefully for now x not include BN
  111. return False
  112. return True
  113. def get_param_groups(network):
  114. """Param groups for optimizer."""
  115. decay_params = []
  116. no_decay_params = []
  117. for x in network.trainable_params():
  118. parameter_name = x.name
  119. if parameter_name.endswith('.bias'):
  120. # all bias not using weight decay
  121. no_decay_params.append(x)
  122. elif parameter_name.endswith('.gamma'):
  123. # bn weight bias not using weight decay, be carefully for now x not include BN
  124. no_decay_params.append(x)
  125. elif parameter_name.endswith('.beta'):
  126. # bn weight bias not using weight decay, be carefully for now x not include BN
  127. no_decay_params.append(x)
  128. else:
  129. decay_params.append(x)
  130. return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
  131. class ShapeRecord:
  132. """Log image shape."""
  133. def __init__(self):
  134. self.shape_record = {
  135. 320: 0,
  136. 352: 0,
  137. 384: 0,
  138. 416: 0,
  139. 448: 0,
  140. 480: 0,
  141. 512: 0,
  142. 544: 0,
  143. 576: 0,
  144. 608: 0,
  145. 'total': 0
  146. }
  147. def set(self, shape):
  148. if len(shape) > 1:
  149. shape = shape[0]
  150. shape = int(shape)
  151. self.shape_record[shape] += 1
  152. self.shape_record['total'] += 1
  153. def show(self, logger):
  154. for key in self.shape_record:
  155. rate = self.shape_record[key] / float(self.shape_record['total'])
  156. logger.info('shape {}: {:.2f}%'.format(key, rate*100))