You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

quant_utils.py 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Quantization utils."""
  16. import numpy as np
  17. def cal_quantization_params(input_min,
  18. input_max,
  19. data_type,
  20. num_bits=8,
  21. symmetric=False,
  22. narrow_range=False):
  23. r"""
  24. Calculate quantization params for scale and zero point.
  25. Args:
  26. input_min (numpy.ndarray): The dimension of channel or 1.
  27. input_max (numpy.ndarray): The dimension of channel or 1.
  28. data_type (numpy type) : Can ben numpy int8, numpy uint8.
  29. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
  30. symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
  31. narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
  32. Returns:
  33. scale (numpy.ndarray): quantization param.
  34. zero point (numpy.ndarray): quantization param.
  35. """
  36. input_max = np.maximum(0.0, input_max)
  37. input_min = np.minimum(0.0, input_min)
  38. if input_min.shape != input_max.shape:
  39. raise ValueError("input min shape should equal to input max.")
  40. if len(input_min.shape) > 1:
  41. raise ValueError("input min and max shape should be one dim.")
  42. if (input_min > input_max).all():
  43. raise ValueError("input_min min should less than input max.")
  44. if (input_max == input_min).all():
  45. # scale = 1.0, zp = 0.0
  46. return np.ones(input_min.shape), np.zeros(input_min.shape)
  47. if data_type == np.int8:
  48. quant_min = 0 - 2 ** (num_bits - 1)
  49. quant_max = 2 ** (num_bits - 1)
  50. else:
  51. quant_min = 0
  52. quant_max = 2 ** num_bits - 1
  53. if narrow_range:
  54. quant_min = quant_min + 1
  55. # calculate scale
  56. if symmetric:
  57. input_max = np.maximum(-input_min, input_max)
  58. input_min = -input_max
  59. scale = (input_max - input_min) / (quant_max - quant_min)
  60. # calculate zero point
  61. if symmetric:
  62. zp = np.zeros(input_min.shape)
  63. else:
  64. zp_from_min = quant_min - input_min / scale
  65. zp_from_max = quant_max - input_max / scale
  66. zp_from_min_error = np.abs(quant_min) + np.abs(input_min / scale)
  67. zp_from_max_error = np.abs(quant_max) + np.abs(input_max / scale)
  68. zp_double = zp_from_min if zp_from_min_error < zp_from_max_error else zp_from_max
  69. if zp_double < quant_min:
  70. zp = quant_min
  71. elif zp_double > quant_max:
  72. zp = quant_max
  73. else:
  74. zp = np.floor(zp_double + 0.5)
  75. return scale, zp
  76. def weight2int(data, scale, zero_point):
  77. r"""
  78. Calculate int8/uint8 weight from fp32. the formula is defined as:
  79. .. math::
  80. int8/uint8 = round(float/scale) + offset
  81. Args:
  82. data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
  83. scale (numpy.ndarray): The dimension of channel or 1.
  84. zero_point (numpy.ndarray): The dimension of channel or 1.
  85. Returns:
  86. weight (numpy.ndarray): The dimension of channel or 1.
  87. """
  88. if scale.shape != zero_point.shape:
  89. raise ValueError("`scale` and `zero_point` should have the same shape.")
  90. if scale.shape[0] < 0:
  91. raise ValueError("`scale` and `zero_point` shape should greater than zero.")
  92. if len(scale.shape) >= 1 and scale.shape[0] > 1:
  93. # for perchannel
  94. if scale.shape[0] == data.shape[0]:
  95. # `Conv2d` or `Dense` op weight
  96. shape_list = [-1] + [1] * len(data.shape[1:])
  97. scale = scale.reshape(shape_list)
  98. zero_point = zero_point.reshape(shape_list)
  99. elif scale.shape[0] == data.shape[1]:
  100. # `DepthwiseConv2d` op weight
  101. shape_list = [1, -1] + [1] * len(data.shape[2:])
  102. scale = scale.reshape(shape_list)
  103. zero_point = zero_point.reshape(shape_list)
  104. else:
  105. raise ValueError("Unsupported weight shape({})".format(data.shape))
  106. return np.round((data / scale) + zero_point)
  107. def scale_zp_from_fack_quant_cell(cell, data_type):
  108. r"""
  109. Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
  110. Args:
  111. cell (Cell): `mindspore.nn.layer.FakeQuantWithMinMax`
  112. data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
  113. Returns:
  114. scale (numpy.ndarray): quantization param.
  115. zero point (numpy.ndarray): quantization param.
  116. """
  117. minq = cell.minq.data.asnumpy()
  118. maxq = cell.maxq.data.asnumpy()
  119. op = cell.fake_quant_infer
  120. scale, zp = cal_quantization_params(
  121. minq, maxq, data_type,
  122. num_bits=op.num_bits,
  123. symmetric=op.symmetric,
  124. narrow_range=op.narrow_range)
  125. return scale, zp
  126. def scale_zp_from_data(op, minq, maxq, data_type):
  127. r"""
  128. Get calculate quantization params for scale and zero point.
  129. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  130. Args:
  131. op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
  132. `mindspore.ops.operation.FakeQuantPerChannel`
  133. minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
  134. maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
  135. data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
  136. Returns:
  137. scale (numpy.ndarray): quantization param.
  138. zero point (numpy.ndarray): quantization param.
  139. """
  140. minq = minq.data.asnumpy()
  141. maxq = maxq.data.asnumpy()
  142. scale, zp = cal_quantization_params(
  143. minq, maxq, data_type,
  144. num_bits=op.num_bits,
  145. symmetric=op.symmetric,
  146. narrow_range=op.narrow_range)
  147. return scale, zp
  148. def fold_batchnorm(weight, cell_quant):
  149. r"""
  150. Fold the batchnorm in `Conv2dBnFoldQuant` to weight.
  151. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  152. Args:
  153. weight (numpy.ndarray): Weight of `cell_quant`.
  154. cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.
  155. Returns:
  156. weight (numpy.ndarray): Folded weight.
  157. bias (numpy.ndarray): Folded bias.
  158. """
  159. variance = cell_quant.moving_variance.data.asnumpy()
  160. mean = cell_quant.moving_mean.data.asnumpy()
  161. gamma = cell_quant.gamma.data.asnumpy()
  162. beta = cell_quant.beta.data.asnumpy()
  163. epsilon = cell_quant.eps
  164. sigma = np.sqrt(variance + epsilon)
  165. if gamma.shape[0] == weight.shape[0]:
  166. # `Conv2d` or `Dense` op weight
  167. shape_list = [-1] + [1] * len(weight.shape[1:])
  168. _gamma = gamma.reshape(shape_list)
  169. _sigma = sigma.reshape(shape_list)
  170. elif gamma.shape[0] == weight.shape[1]:
  171. # `DepthwiseConv2d` op weight
  172. shape_list = [1, -1] + [1] * len(weight.shape[2:])
  173. _gamma = gamma.reshape(shape_list)
  174. _sigma = sigma.reshape(shape_list)
  175. else:
  176. raise ValueError("Unsupported weight shape({})".format(weight.shape))
  177. weight = weight * _gamma / _sigma
  178. bias = beta - gamma * mean / sigma
  179. return weight, bias
  180. def without_fold_batchnorm(weight, cell_quant):
  181. r"""
  182. Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight.
  183. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  184. Args:
  185. weight (numpy.ndarray): Weight of `cell_quant`.
  186. cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.
  187. Returns:
  188. weight (numpy.ndarray): whihout folded weight.
  189. bias (numpy.ndarray): without folded bias.
  190. """
  191. variance = cell_quant.batchnorm.moving_variance.data.asnumpy()
  192. mean = cell_quant.batchnorm.moving_mean.data.asnumpy()
  193. gamma = cell_quant.batchnorm.gamma.data.asnumpy()
  194. beta = cell_quant.batchnorm.beta.data.asnumpy()
  195. epsilon = cell_quant.batchnorm.eps
  196. sigma = np.sqrt(variance + epsilon)
  197. if gamma.shape[0] == weight.shape[0]:
  198. # `Conv2d` or `Dense` op weight
  199. shape_list = [-1] + [1] * len(weight.shape[1:])
  200. _gamma = gamma.reshape(shape_list)
  201. _sigma = sigma.reshape(shape_list)
  202. elif gamma.shape[0] == weight.shape[1]:
  203. # `DepthwiseConv2d` op weight
  204. shape_list = [1, -1] + [1] * len(weight.shape[2:])
  205. _gamma = gamma.reshape(shape_list)
  206. _sigma = sigma.reshape(shape_list)
  207. else:
  208. raise ValueError("Unsupported weight shape({})".format(weight.shape))
  209. weight = weight * _gamma / _sigma
  210. bias = beta - gamma * mean / sigma
  211. return weight, bias
  212. def load_nonquant_param_into_quant_net(quant_model, params_dict):
  213. """
  214. load fp32 model parameters to quantization model.
  215. Args:
  216. quant_model: quantization model
  217. params_dict: f32 param
  218. Returns:
  219. None
  220. """
  221. iterable_dict = {
  222. 'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
  223. 'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
  224. 'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
  225. 'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
  226. 'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
  227. 'moving_variance': iter(
  228. [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
  229. 'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
  230. 'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
  231. }
  232. for name, param in quant_model.parameters_and_names():
  233. key_name = name.split(".")[-1]
  234. if key_name not in iterable_dict.keys():
  235. raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
  236. value_param = next(iterable_dict[key_name], None)
  237. if value_param is not None:
  238. param.set_parameter_data(value_param[1].data)
  239. print(f'init model param {name} with checkpoint param {value_param[0]}')