You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

less_batch_normalization.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """less Batch Normalization"""
  16. import numpy as np
  17. from mindspore.nn.cell import Cell
  18. from mindspore.nn.layer import Dense
  19. from mindspore.ops import operations as P
  20. from mindspore.common import Tensor, Parameter
  21. from mindspore.common import dtype as mstype
  22. from mindspore.common.initializer import initializer
  23. __all__ = ["CommonHeadLastFN", "LessBN"]
  24. class CommonHeadLastFN(Cell):
  25. r"""
  26. The last full Normalization layer.
  27. This layer implements the operation as:
  28. .. math::
  29. \text{inputs} = \text{norm}(\text{inputs})
  30. \text{kernel} = \text{norm}(\text{kernel})
  31. \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias}),
  32. Args:
  33. in_channels (int): The number of channels in the input space.
  34. out_channels (int): The number of channels in the output space.
  35. weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
  36. is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
  37. bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
  38. same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
  39. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
  40. Supported Platforms:
  41. ``Ascend`` ``GPU`` ``CPU``
  42. Examples:
  43. >>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
  44. >>> net = CommonHeadLastFN(3, 4)
  45. >>> output = net(input)
  46. """
  47. def __init__(self,
  48. in_channels,
  49. out_channels,
  50. weight_init='normal',
  51. bias_init='zeros',
  52. has_bias=True):
  53. super(CommonHeadLastFN, self).__init__()
  54. weight_shape = [out_channels, in_channels]
  55. self.weight = Parameter(initializer(weight_init, weight_shape), requires_grad=True, name='weight')
  56. self.x_norm = P.L2Normalize(axis=1)
  57. self.w_norm = P.L2Normalize(axis=1)
  58. self.fc = P.MatMul(transpose_a=False, transpose_b=True)
  59. self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32), requires_grad=True, name='multiplier')
  60. self.has_bias = has_bias
  61. if self.has_bias:
  62. bias_shape = [out_channels]
  63. self.bias_add = P.BiasAdd()
  64. self.bias = Parameter(initializer(bias_init, bias_shape), requires_grad=True, name='bias')
  65. def construct(self, x):
  66. x = self.x_norm(x)
  67. w = self.w_norm(self.weight)
  68. x = self.fc(x, w)
  69. if self.has_bias:
  70. x = self.bias_add(x, self.bias)
  71. x = self.multiplier * x
  72. return x
  73. class LessBN(Cell):
  74. """
  75. Reduce the number of BN automatically to improve the network performance
  76. and ensure the network accuracy.
  77. Args:
  78. network (Cell): Network to be modified.
  79. fn_flag (bool): Replace FC with FN. default: False.
  80. Examples:
  81. >>> network = boost.LessBN(network)
  82. """
  83. def __init__(self, network, fn_flag=False):
  84. super(LessBN, self).__init__()
  85. self.network = network
  86. self.network.set_boost("less_bn")
  87. self.network.update_cell_prefix()
  88. if fn_flag:
  89. self._convert_to_less_bn_net(self.network)
  90. self.network.add_flags(defer_inline=True)
  91. def _convert_dense(self, subcell):
  92. """
  93. convert dense cell to FN cell
  94. """
  95. prefix = subcell.param_prefix
  96. new_subcell = CommonHeadLastFN(subcell.in_channels,
  97. subcell.out_channels,
  98. subcell.weight,
  99. subcell.bias,
  100. False)
  101. new_subcell.update_parameters_name(prefix + '.')
  102. return new_subcell
  103. def _convert_to_less_bn_net(self, net):
  104. """
  105. convert network to less_bn network
  106. """
  107. cells = net.name_cells()
  108. dense_name = []
  109. dense_list = []
  110. for name in cells:
  111. subcell = cells[name]
  112. if subcell == net:
  113. continue
  114. elif isinstance(subcell, (Dense)):
  115. dense_name.append(name)
  116. dense_list.append(subcell)
  117. else:
  118. self._convert_to_less_bn_net(subcell)
  119. if dense_list:
  120. new_subcell = self._convert_dense(dense_list[-1])
  121. net.insert_child_to_cell(dense_name[-1], new_subcell)
  122. def construct(self, *inputs):
  123. return self.network(*inputs)