You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

quant_utils.py 6.7 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Quantization utils."""
  16. import numpy as np
  17. def cal_quantization_params(input_min,
  18. input_max,
  19. data_type,
  20. num_bits=8,
  21. symmetric=False,
  22. narrow_range=False):
  23. r"""
  24. Calculate quantization params for scale and zero point.
  25. Args:
  26. input_min (numpy.ndarray): The dimension of channel or 1.
  27. input_max (numpy.ndarray): The dimension of channel or 1.
  28. data_type (numpy type) : Can ben numpy int8, numpy uint8.
  29. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
  30. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
  31. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
  32. Returns:
  33. scale (numpy.ndarray): quantization param.
  34. zero point (numpy.ndarray): quantization param.
  35. """
  36. input_max = np.maximum(0.0, input_max)
  37. input_min = np.minimum(0.0, input_min)
  38. if input_min.shape != input_max.shape:
  39. raise ValueError("input min shape should equal to input max.")
  40. if len(input_min.shape) > 1:
  41. raise ValueError("input min and max shape should be one dim.")
  42. if input_min > input_max:
  43. raise ValueError("input_min min should less than input max.")
  44. if (input_max == input_min).all():
  45. # scale = 1.0, zp = 0.0
  46. return np.ones(input_min.shape), np.zeros(input_min.shape)
  47. if data_type == np.int8:
  48. quant_min = 0 - 2 ** (num_bits - 1)
  49. quant_max = 2 ** (num_bits - 1)
  50. else:
  51. quant_min = 0
  52. quant_max = 2 ** num_bits - 1
  53. if narrow_range:
  54. quant_min = quant_min + 1
  55. # calculate scale
  56. if symmetric:
  57. input_max = np.maximum(-input_min, input_max)
  58. input_min = -input_max
  59. scale = (input_max - input_min) / (quant_max - quant_min)
  60. # calculate zero point
  61. if symmetric:
  62. zp = np.zeros(input_min.shape)
  63. else:
  64. zp_from_min = quant_min - input_min / scale
  65. zp_from_max = quant_max - input_max / scale
  66. zp_from_min_error = np.abs(quant_min) + np.abs(input_min / scale)
  67. zp_from_max_error = np.abs(quant_max) + np.abs(input_max / scale)
  68. zp_double = zp_from_min if zp_from_min_error < zp_from_max_error else zp_from_max
  69. if zp_double < quant_min:
  70. zp = quant_min
  71. elif zp_double > quant_max:
  72. zp = quant_max
  73. else:
  74. zp = np.floor(zp_double + 0.5)
  75. return scale, zp
  76. def weight2int(data,
  77. scale,
  78. zero_point):
  79. r"""
  80. Calculate int8/uint8 weight from fp32. the formula is defined as:
  81. .. math::
  82. int8/uint8 = round(float/scale) + offset
  83. Args:
  84. data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
  85. scale (numpy.ndarray): The dimension of channel or 1.
  86. zero_point (numpy.ndarray): The dimension of channel or 1.
  87. Returns:
  88. weight (numpy.ndarray): The dimension of channel or 1.
  89. """
  90. if scale.shape != zero_point.shape:
  91. raise ValueError("scale and zero_point should have the same shape.")
  92. if scale.shape[0] > 0:
  93. scale = scale.reshape(1, -1)
  94. zero_point = zero_point.reshape(1, -1)
  95. return np.round((data/scale) + zero_point)
  96. def scale_zp_from_fack_quant_cell(cell, data_type):
  97. r"""
  98. Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
  99. Args:
  100. cell (Cell): `mindspore.nn.layer.FakeQuantWithMinMax`
  101. data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
  102. Returns:
  103. scale (numpy.ndarray): quantization param.
  104. zero point (numpy.ndarray): quantization param.
  105. """
  106. minq = cell.minq.data.asnumpy()
  107. maxq = cell.maxq.data.asnumpy()
  108. op = cell.fake_quant_infer
  109. scale, zp = cal_quantization_params(
  110. minq, maxq, data_type,
  111. num_bits=op.num_bits,
  112. symmetric=op.symmetric,
  113. narrow_range=op.narrow_range)
  114. return scale, zp
  115. def scale_zp_from_data(op, minq, maxq, data_type):
  116. r"""
  117. Get calculate quantization params for scale and zero point.
  118. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  119. Args:
  120. op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
  121. `mindspore.ops.operation.FakeQuantPerChannel`
  122. minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
  123. maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
  124. data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
  125. Returns:
  126. scale (numpy.ndarray): quantization param.
  127. zero point (numpy.ndarray): quantization param.
  128. """
  129. minq = minq.data.asnumpy()
  130. maxq = maxq.data.asnumpy()
  131. scale, zp = cal_quantization_params(
  132. minq, maxq, data_type,
  133. num_bits=op.num_bits,
  134. symmetric=op.symmetric,
  135. narrow_range=op.narrow_range)
  136. return scale, zp
  137. def fold_batchnorm(weight, cell_quant):
  138. r"""
  139. Fold the batchnorm in `Conv2dBatchNormQuant` to weight.
  140. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  141. Args:
  142. weight (numpy.ndarray): Weight of `cell_quant`.
  143. cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`.
  144. Returns:
  145. weight (numpy.ndarray): Folded weight.
  146. bias (numpy.ndarray): Folded bias.
  147. """
  148. variance = cell_quant.moving_variance.data.asnumpy()
  149. mean = cell_quant.moving_mean.data.asnumpy()
  150. gamma = cell_quant.gamma.data.asnumpy()
  151. beta = cell_quant.beta.data.asnumpy()
  152. epsilon = cell_quant.eps
  153. sigma = np.sqrt(variance + epsilon)
  154. gamma = gamma.reshape(-1, 1, 1, 1)
  155. sigma = sigma.reshape(-1, 1, 1, 1)
  156. mean = mean.reshape(-1, 1, 1, 1)
  157. weight = weight * gamma / sigma
  158. bias = beta - gamma * mean / sigma
  159. return weight, bias