
initializer.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter init."""
import math
from functools import reduce

import numpy as np
from mindspore.common import initializer as init
from mindspore.common.initializer import Initializer as MeInitializer
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.nn as nn

from .util import load_backbone


def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky ReLU        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the name of the non-linear function
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
            # True/False are instances of int, hence the explicit bool check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
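

# Example (sketch, not in the original file): for Kaiming initialization with a
# LeakyReLU slope of 0.1, calculate_gain('leaky_relu', 0.1) returns
# sqrt(2 / (1 + 0.1 ** 2)) ~= 1.407.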


def _assignment(arr, num):
    """Assign the value of `num` to `arr` in place."""
    if arr.shape == ():
        arr = arr.reshape((1,))
        arr[:] = num
        arr = arr.reshape(())
    else:
        if isinstance(num, np.ndarray):
            arr[:] = num[:]
        else:
            arr[:] = num
    return arr


def _calculate_correct_fan(array, mode):
    """Return the fan value of `array` for the given mode ('fan_in' or 'fan_out')."""
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Return an array with the same shape as `arr`, with values drawn according
    to the method described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015),
    using a uniform distribution. The resulting values are sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        arr: an n-dimensional numpy array
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backward pass.
        nonlinearity: the name of the non-linear function, recommended to use
            only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = np.empty((3, 5))
        >>> w[:] = kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(arr, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, arr.shape)
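

# Usage sketch (illustrative shapes, not from the original file):
#
#   w = np.empty((64, 32, 3, 3), dtype=np.float32)   # (out_ch, in_ch, kH, kW)
#   w[:] = kaiming_uniform_(w, a=math.sqrt(5))       # fan_in = 32 * 3 * 3 = 288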


def _calculate_fan_in_and_fan_out(arr):
    """Calculate fan in and fan out."""
    dimensions = len(arr.shape)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
    num_input_fmaps = arr.shape[1]
    num_output_fmaps = arr.shape[0]
    receptive_field_size = 1
    if dimensions > 2:
        receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:])
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out
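

# Example (sketch): for a Dense weight of shape (out_features, in_features) = (1000, 2048),
# receptive_field_size stays 1, so fan_in = 2048 and fan_out = 1000; for a conv weight of
# shape (64, 32, 3, 3), receptive_field_size = 9, so fan_in = 288 and fan_out = 576.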


class KaimingUniform(MeInitializer):
    """Kaiming uniform initializer."""

    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(KaimingUniform, self).__init__()
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity

    def _initialize(self, arr):
        tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
        _assignment(arr, tmp)
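

# Usage sketch (this mirrors how the class is consumed below; the shape and
# dtype are illustrative only):
#
#   weight = init.initializer(KaimingUniform(a=math.sqrt(5)),
#                             (64, 32, 3, 3), mindspore.float32)
#
# MindSpore's `init.initializer` is expected to invoke `_initialize` on a numpy
# buffer of the requested shape, which `_assignment` then fills in place.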


def default_recurisive_init(custom_cell):
    """Recursively initialize Conv2d and Dense parameters (Kaiming-uniform weights, uniform biases)."""
    for _, cell in custom_cell.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                  cell.weight.shape,
                                                  cell.weight.dtype))
            if cell.bias is not None:
                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
                bound = 1 / math.sqrt(fan_in)
                cell.bias.set_data(init.initializer(init.Uniform(bound),
                                                    cell.bias.shape,
                                                    cell.bias.dtype))
        elif isinstance(cell, nn.Dense):
            cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                  cell.weight.shape,
                                                  cell.weight.dtype))
            if cell.bias is not None:
                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
                bound = 1 / math.sqrt(fan_in)
                cell.bias.set_data(init.initializer(init.Uniform(bound),
                                                    cell.bias.shape,
                                                    cell.bias.dtype))
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass
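

# Usage sketch (hypothetical `net`; any nn.Cell containing Conv2d/Dense layers works):
#
#   net = build_yolov4_network()       # hypothetical constructor
#   default_recurisive_init(net)       # Kaiming-uniform weights, uniform biases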


def load_yolov4_params(args, network):
    """Load yolov4 cspdarknet parameters from checkpoint."""
    if args.pretrained_backbone:
        network = load_backbone(network, args.pretrained_backbone, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Pre-trained backbone not loaded, please be careful')

    if args.resume_yolov4:
        param_dict = load_checkpoint(args.resume_yolov4)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('yolo_network.'):
                param_dict_new[key[13:]] = values
                args.logger.info('in resume {}'.format(key))
            else:
                param_dict_new[key] = values
                args.logger.info('in resume {}'.format(key))
        args.logger.info('resume finished')
        load_param_into_net(network, param_dict_new)
        args.logger.info('load_model {} success'.format(args.resume_yolov4))

    if args.filter_weight:
        if args.pretrained_checkpoint:
            param_dict = load_checkpoint(args.pretrained_checkpoint)
            for key in list(param_dict.keys()):
                if key in args.checkpoint_filter_list:
                    args.logger.info('filter {}'.format(key))
                    del param_dict[key]
            load_param_into_net(network, param_dict)
            args.logger.info('load_model {} success'.format(args.pretrained_checkpoint))
        else:
            args.logger.warning('filter_weight is set but pretrained_checkpoint is not loaded, please be careful')
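

# Usage sketch (`args` is assumed to be the training script's parsed arguments,
# providing pretrained_backbone, resume_yolov4, filter_weight,
# checkpoint_filter_list and a logger):
#
#   default_recurisive_init(network)
#   load_yolov4_params(args, network)


if __name__ == '__main__':
    # Minimal self-check sketch (not part of the original file): exercises the
    # numpy-level helpers without needing a network or a checkpoint.
    demo = np.empty((64, 32, 3, 3), dtype=np.float32)
    fan_in, fan_out = _calculate_fan_in_and_fan_out(demo)
    print('fan_in={}, fan_out={}'.format(fan_in, fan_out))  # 288, 576
    bound = calculate_gain('leaky_relu', math.sqrt(5)) * math.sqrt(3.0 / fan_in)
    sample = kaiming_uniform_(demo, a=math.sqrt(5))
    print('max |w| = {:.4f}, bound = {:.4f}'.format(np.abs(sample).max(), bound))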