You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

quant_utils.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Quantization utils."""
  16. import numpy as np
  17. __all__ = ["load_nonquant_param_into_quant_net"]
  18. def cal_quantization_params(input_min,
  19. input_max,
  20. data_type,
  21. num_bits=8,
  22. symmetric=False,
  23. narrow_range=False):
  24. r"""
  25. Calculate quantization params for scale and zero point.
  26. Args:
  27. input_min (numpy.ndarray): The dimension of channel or 1.
  28. input_max (numpy.ndarray): The dimension of channel or 1.
  29. data_type (numpy type) : Can be numpy int8, numpy uint8.
  30. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
  31. symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
  32. narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
  33. Returns:
  34. scale (numpy.ndarray): quantization param.
  35. zero point (numpy.ndarray): quantization param.
  36. """
  37. input_max = np.maximum(0.0, input_max)
  38. input_min = np.minimum(0.0, input_min)
  39. if input_min.shape != input_max.shape:
  40. raise ValueError("input min shape should equal to input max.")
  41. if len(input_min.shape) > 1:
  42. raise ValueError("input min and max shape should be one dim.")
  43. if (input_min > input_max).all():
  44. raise ValueError("input_min min should less than input max.")
  45. if (input_max == input_min).all():
  46. return np.ones(input_min.shape), np.zeros(input_min.shape)
  47. if data_type == np.int8:
  48. quant_min = 0 - 2 ** (num_bits - 1)
  49. quant_max = 2 ** (num_bits - 1) - 1
  50. elif data_type == np.uint8:
  51. quant_min = 0
  52. quant_max = 2 ** num_bits - 1
  53. else:
  54. raise ValueError("Unsupported datatype({})".format(data_type))
  55. if narrow_range:
  56. quant_min = quant_min + 1
  57. # calculate scale
  58. if symmetric:
  59. input_max = np.maximum(-input_min, input_max)
  60. input_min = -input_max
  61. scale = (input_max - input_min) / (quant_max - quant_min)
  62. # calculate zero point
  63. if symmetric:
  64. zp = np.zeros(input_min.shape)
  65. else:
  66. zp_double = quant_min - input_min / scale
  67. zp = np.floor(zp_double + 0.5)
  68. return scale, zp
  69. def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
  70. r"""
  71. Calculate int8/uint8 weight from fp32. the formula is defined as:
  72. .. math::
  73. int8/uint8 = round(float/scale) + offset
  74. Args:
  75. data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
  76. scale (numpy.ndarray): The dimension of channel or 1.
  77. zero_point (numpy.ndarray): The dimension of channel or 1.
  78. data_type (numpy type) : Can be numpy int8, numpy uint8.
  79. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
  80. narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
  81. Returns:
  82. weight (numpy.ndarray): The dimension of channel or 1.
  83. """
  84. if scale.shape != zero_point.shape:
  85. raise ValueError("`scale` and `zero_point` should have the same shape.")
  86. if scale.shape[0] < 0:
  87. raise ValueError("`scale` and `zero_point` shape should greater than zero.")
  88. if len(scale.shape) >= 1 and scale.shape[0] > 1:
  89. # for perchannel
  90. if scale.shape[0] == data.shape[0]:
  91. # `Conv2d` or `Dense` op weight
  92. shape_list = [-1] + [1] * len(data.shape[1:])
  93. scale = scale.reshape(shape_list)
  94. zero_point = zero_point.reshape(shape_list)
  95. elif scale.shape[0] == data.shape[1]:
  96. # `DepthwiseConv2d` op weight
  97. shape_list = [1, -1] + [1] * len(data.shape[2:])
  98. scale = scale.reshape(shape_list)
  99. zero_point = zero_point.reshape(shape_list)
  100. else:
  101. raise ValueError("Unsupported weight shape({})".format(data.shape))
  102. if data_type == np.int8:
  103. quant_min = 0 - 2 ** (num_bits - 1)
  104. quant_max = 2 ** (num_bits - 1) - 1
  105. elif data_type == np.uint8:
  106. quant_min = 0
  107. quant_max = 2 ** num_bits - 1
  108. else:
  109. raise ValueError("Unsupported weight datatype({})".format(data_type))
  110. if narrow_range:
  111. quant_min = quant_min + 1
  112. weight_int = np.round((data / scale) + zero_point)
  113. weight_int[weight_int > quant_max] = quant_max
  114. weight_int[weight_int < quant_min] = quant_min
  115. return weight_int
  116. def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
  117. """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
  118. minq = cell.minq.data.asnumpy()
  119. maxq = cell.maxq.data.asnumpy()
  120. op = cell.fake_quant_infer
  121. scale, zp = cal_quantization_params(
  122. minq, maxq, data_type,
  123. num_bits=op.num_bits,
  124. symmetric=op.symmetric,
  125. narrow_range=op.narrow_range)
  126. return scale, zp, maxq, minq
  127. def scale_zp_from_data(op, minq, maxq, data_type):
  128. r"""
  129. Get calculate quantization params for scale and zero point.
  130. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  131. Args:
  132. op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
  133. `mindspore.ops.operation.FakeQuantPerChannel`
  134. minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
  135. maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
  136. data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.
  137. Returns:
  138. scale (numpy.ndarray): quantization param.
  139. zero point (numpy.ndarray): quantization param.
  140. """
  141. minq = minq.data.asnumpy()
  142. maxq = maxq.data.asnumpy()
  143. scale, zp = cal_quantization_params(
  144. minq, maxq, data_type,
  145. num_bits=op.num_bits,
  146. symmetric=op.symmetric,
  147. narrow_range=op.narrow_range)
  148. return scale, zp
  149. def scale_zp_max_min_from_data(op, minq, maxq, data_type):
  150. """Get calculate quantization params for scale, zero point, max and min."""
  151. minq = minq.data.asnumpy()
  152. maxq = maxq.data.asnumpy()
  153. scale, zp = cal_quantization_params(
  154. minq, maxq, data_type,
  155. num_bits=op.num_bits,
  156. symmetric=op.symmetric,
  157. narrow_range=op.narrow_range)
  158. return scale, zp, maxq, minq
  159. def fold_batchnorm(weight, cell_quant):
  160. r"""
  161. Fold the batchnorm in `Conv2dBnFoldQuant` to weight.
  162. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  163. Args:
  164. weight (numpy.ndarray): Weight of `cell_quant`.
  165. cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.
  166. Returns:
  167. weight (numpy.ndarray): Folded weight.
  168. bias (numpy.ndarray): Folded bias.
  169. """
  170. variance = cell_quant.moving_variance.data.asnumpy()
  171. mean = cell_quant.moving_mean.data.asnumpy()
  172. gamma = cell_quant.gamma.data.asnumpy()
  173. beta = cell_quant.beta.data.asnumpy()
  174. epsilon = cell_quant.eps
  175. sigma = np.sqrt(variance + epsilon)
  176. if gamma.shape[0] == weight.shape[0]:
  177. # `Conv2d` or `Dense` op weight
  178. shape_list = [-1] + [1] * len(weight.shape[1:])
  179. _gamma = gamma.reshape(shape_list)
  180. _sigma = sigma.reshape(shape_list)
  181. elif gamma.shape[0] == weight.shape[1]:
  182. # `DepthwiseConv2d` op weight
  183. shape_list = [1, -1] + [1] * len(weight.shape[2:])
  184. _gamma = gamma.reshape(shape_list)
  185. _sigma = sigma.reshape(shape_list)
  186. else:
  187. raise ValueError("Unsupported weight shape({})".format(weight.shape))
  188. weight = weight * _gamma / _sigma
  189. bias = beta - gamma * mean / sigma
  190. return weight, bias
  191. def without_fold_batchnorm(weight, cell_quant):
  192. r"""
  193. Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight.
  194. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  195. Args:
  196. weight (numpy.ndarray): Weight of `cell_quant`.
  197. cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.
  198. Returns:
  199. weight (numpy.ndarray): whihout folded weight.
  200. bias (numpy.ndarray): without folded bias.
  201. """
  202. variance = cell_quant.batchnorm.moving_variance.data.asnumpy()
  203. mean = cell_quant.batchnorm.moving_mean.data.asnumpy()
  204. gamma = cell_quant.batchnorm.gamma.data.asnumpy()
  205. beta = cell_quant.batchnorm.beta.data.asnumpy()
  206. epsilon = cell_quant.batchnorm.eps
  207. sigma = np.sqrt(variance + epsilon)
  208. if gamma.shape[0] == weight.shape[0]:
  209. # `Conv2d` or `Dense` op weight
  210. shape_list = [-1] + [1] * len(weight.shape[1:])
  211. _gamma = gamma.reshape(shape_list)
  212. _sigma = sigma.reshape(shape_list)
  213. elif gamma.shape[0] == weight.shape[1]:
  214. # `DepthwiseConv2d` op weight
  215. shape_list = [1, -1] + [1] * len(weight.shape[2:])
  216. _gamma = gamma.reshape(shape_list)
  217. _sigma = sigma.reshape(shape_list)
  218. else:
  219. raise ValueError("Unsupported weight shape({})".format(weight.shape))
  220. weight = weight * _gamma / _sigma
  221. bias = beta - gamma * mean / sigma
  222. return weight, bias
  223. def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
  224. r"""
  225. Load fp32 model parameters into quantization model.
  226. Args:
  227. quant_model: quantization model.
  228. params_dict: parameter dict that stores fp32 parameters.
  229. quant_new_params: parameters that exist in quantitative network but not in unquantitative network.
  230. Returns:
  231. None
  232. """
  233. iterable_dict = {
  234. 'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
  235. 'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
  236. 'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
  237. 'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
  238. 'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
  239. 'moving_variance': iter(
  240. [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
  241. 'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
  242. 'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
  243. }
  244. for name, param in quant_model.parameters_and_names():
  245. key_name = name.split(".")[-1]
  246. if key_name not in iterable_dict.keys():
  247. if quant_new_params is not None and key_name in quant_new_params:
  248. continue
  249. raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
  250. value_param = next(iterable_dict[key_name], None)
  251. if value_param is not None:
  252. param.set_data(value_param[1].data)
  253. print(f'init model param {name} with checkpoint param {value_param[0]}')