You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

quant_utils.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Quantization utils."""
  16. import numpy as np
  17. def cal_quantization_params(input_min,
  18. input_max,
  19. data_type,
  20. num_bits=8,
  21. symmetric=False,
  22. narrow_range=False):
  23. r"""
  24. Calculate quantization params for scale and zero point.
  25. Args:
  26. input_min (numpy.ndarray): The dimension of channel or 1.
  27. input_max (numpy.ndarray): The dimension of channel or 1.
  28. data_type (numpy type) : Can be numpy int8, numpy uint8.
  29. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
  30. symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
  31. narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
  32. Returns:
  33. scale (numpy.ndarray): quantization param.
  34. zero point (numpy.ndarray): quantization param.
  35. """
  36. input_max = np.maximum(0.0, input_max)
  37. input_min = np.minimum(0.0, input_min)
  38. if input_min.shape != input_max.shape:
  39. raise ValueError("input min shape should equal to input max.")
  40. if len(input_min.shape) > 1:
  41. raise ValueError("input min and max shape should be one dim.")
  42. if (input_min > input_max).all():
  43. raise ValueError("input_min min should less than input max.")
  44. if (input_max == input_min).all():
  45. return np.ones(input_min.shape), np.zeros(input_min.shape)
  46. if data_type == np.int8:
  47. quant_min = 0 - 2 ** (num_bits - 1)
  48. quant_max = 2 ** (num_bits - 1) - 1
  49. elif data_type == np.uint8:
  50. quant_min = 0
  51. quant_max = 2 ** num_bits - 1
  52. else:
  53. raise ValueError("Unsupported datatype({})".format(data_type))
  54. if narrow_range:
  55. quant_min = quant_min + 1
  56. # calculate scale
  57. if symmetric:
  58. input_max = np.maximum(-input_min, input_max)
  59. input_min = -input_max
  60. scale = (input_max - input_min) / (quant_max - quant_min)
  61. # calculate zero point
  62. if symmetric:
  63. zp = np.zeros(input_min.shape)
  64. else:
  65. zp_double = quant_min - input_min / scale
  66. zp = np.floor(zp_double + 0.5)
  67. return scale, zp
  68. def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
  69. r"""
  70. Calculate int8/uint8 weight from fp32. the formula is defined as:
  71. .. math::
  72. int8/uint8 = round(float/scale) + offset
  73. Args:
  74. data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
  75. scale (numpy.ndarray): The dimension of channel or 1.
  76. zero_point (numpy.ndarray): The dimension of channel or 1.
  77. data_type (numpy type) : Can be numpy int8, numpy uint8.
  78. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
  79. narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
  80. Returns:
  81. weight (numpy.ndarray): The dimension of channel or 1.
  82. """
  83. if scale.shape != zero_point.shape:
  84. raise ValueError("`scale` and `zero_point` should have the same shape.")
  85. if scale.shape[0] < 0:
  86. raise ValueError("`scale` and `zero_point` shape should greater than zero.")
  87. if len(scale.shape) >= 1 and scale.shape[0] > 1:
  88. # for perchannel
  89. if scale.shape[0] == data.shape[0]:
  90. # `Conv2d` or `Dense` op weight
  91. shape_list = [-1] + [1] * len(data.shape[1:])
  92. scale = scale.reshape(shape_list)
  93. zero_point = zero_point.reshape(shape_list)
  94. elif scale.shape[0] == data.shape[1]:
  95. # `DepthwiseConv2d` op weight
  96. shape_list = [1, -1] + [1] * len(data.shape[2:])
  97. scale = scale.reshape(shape_list)
  98. zero_point = zero_point.reshape(shape_list)
  99. else:
  100. raise ValueError("Unsupported weight shape({})".format(data.shape))
  101. if data_type == np.int8:
  102. quant_min = 0 - 2 ** (num_bits - 1)
  103. quant_max = 2 ** (num_bits - 1) - 1
  104. elif data_type == np.uint8:
  105. quant_min = 0
  106. quant_max = 2 ** num_bits - 1
  107. else:
  108. raise ValueError("Unsupported weight datatype({})".format(data_type))
  109. if narrow_range:
  110. quant_min = quant_min + 1
  111. weight_int = np.round((data / scale) + zero_point)
  112. weight_int[weight_int > quant_max] = quant_max
  113. weight_int[weight_int < quant_min] = quant_min
  114. return weight_int
  115. def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
  116. """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
  117. minq = cell.minq.data.asnumpy()
  118. maxq = cell.maxq.data.asnumpy()
  119. op = cell.fake_quant_infer
  120. scale, zp = cal_quantization_params(
  121. minq, maxq, data_type,
  122. num_bits=op.num_bits,
  123. symmetric=op.symmetric,
  124. narrow_range=op.narrow_range)
  125. return scale, zp, maxq, minq
  126. def scale_zp_from_data(op, minq, maxq, data_type):
  127. r"""
  128. Get calculate quantization params for scale and zero point.
  129. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  130. Args:
  131. op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
  132. `mindspore.ops.operation.FakeQuantPerChannel`
  133. minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
  134. maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
  135. data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.
  136. Returns:
  137. scale (numpy.ndarray): quantization param.
  138. zero point (numpy.ndarray): quantization param.
  139. """
  140. minq = minq.data.asnumpy()
  141. maxq = maxq.data.asnumpy()
  142. scale, zp = cal_quantization_params(
  143. minq, maxq, data_type,
  144. num_bits=op.num_bits,
  145. symmetric=op.symmetric,
  146. narrow_range=op.narrow_range)
  147. return scale, zp
  148. def scale_zp_max_min_from_data(op, minq, maxq, data_type):
  149. """Get calculate quantization params for scale, zero point, max and min."""
  150. minq = minq.data.asnumpy()
  151. maxq = maxq.data.asnumpy()
  152. scale, zp = cal_quantization_params(
  153. minq, maxq, data_type,
  154. num_bits=op.num_bits,
  155. symmetric=op.symmetric,
  156. narrow_range=op.narrow_range)
  157. return scale, zp, maxq, minq
  158. def fold_batchnorm(weight, cell_quant):
  159. r"""
  160. Fold the batchnorm in `Conv2dBnFoldQuant` to weight.
  161. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  162. Args:
  163. weight (numpy.ndarray): Weight of `cell_quant`.
  164. cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.
  165. Returns:
  166. weight (numpy.ndarray): Folded weight.
  167. bias (numpy.ndarray): Folded bias.
  168. """
  169. variance = cell_quant.moving_variance.data.asnumpy()
  170. mean = cell_quant.moving_mean.data.asnumpy()
  171. gamma = cell_quant.gamma.data.asnumpy()
  172. beta = cell_quant.beta.data.asnumpy()
  173. epsilon = cell_quant.eps
  174. sigma = np.sqrt(variance + epsilon)
  175. if gamma.shape[0] == weight.shape[0]:
  176. # `Conv2d` or `Dense` op weight
  177. shape_list = [-1] + [1] * len(weight.shape[1:])
  178. _gamma = gamma.reshape(shape_list)
  179. _sigma = sigma.reshape(shape_list)
  180. elif gamma.shape[0] == weight.shape[1]:
  181. # `DepthwiseConv2d` op weight
  182. shape_list = [1, -1] + [1] * len(weight.shape[2:])
  183. _gamma = gamma.reshape(shape_list)
  184. _sigma = sigma.reshape(shape_list)
  185. else:
  186. raise ValueError("Unsupported weight shape({})".format(weight.shape))
  187. weight = weight * _gamma / _sigma
  188. bias = beta - gamma * mean / sigma
  189. return weight, bias
  190. def without_fold_batchnorm(weight, cell_quant):
  191. r"""
  192. Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight.
  193. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  194. Args:
  195. weight (numpy.ndarray): Weight of `cell_quant`.
  196. cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.
  197. Returns:
  198. weight (numpy.ndarray): whihout folded weight.
  199. bias (numpy.ndarray): without folded bias.
  200. """
  201. variance = cell_quant.batchnorm.moving_variance.data.asnumpy()
  202. mean = cell_quant.batchnorm.moving_mean.data.asnumpy()
  203. gamma = cell_quant.batchnorm.gamma.data.asnumpy()
  204. beta = cell_quant.batchnorm.beta.data.asnumpy()
  205. epsilon = cell_quant.batchnorm.eps
  206. sigma = np.sqrt(variance + epsilon)
  207. if gamma.shape[0] == weight.shape[0]:
  208. # `Conv2d` or `Dense` op weight
  209. shape_list = [-1] + [1] * len(weight.shape[1:])
  210. _gamma = gamma.reshape(shape_list)
  211. _sigma = sigma.reshape(shape_list)
  212. elif gamma.shape[0] == weight.shape[1]:
  213. # `DepthwiseConv2d` op weight
  214. shape_list = [1, -1] + [1] * len(weight.shape[2:])
  215. _gamma = gamma.reshape(shape_list)
  216. _sigma = sigma.reshape(shape_list)
  217. else:
  218. raise ValueError("Unsupported weight shape({})".format(weight.shape))
  219. weight = weight * _gamma / _sigma
  220. bias = beta - gamma * mean / sigma
  221. return weight, bias
  222. def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
  223. """
  224. load fp32 model parameters to quantization model.
  225. Args:
  226. quant_model: quantization model.
  227. params_dict: f32 param.
  228. quant_new_params:parameters that exist in quantative network but not in unquantative network.
  229. Returns:
  230. None
  231. """
  232. iterable_dict = {
  233. 'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
  234. 'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
  235. 'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
  236. 'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
  237. 'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
  238. 'moving_variance': iter(
  239. [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
  240. 'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
  241. 'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
  242. }
  243. for name, param in quant_model.parameters_and_names():
  244. key_name = name.split(".")[-1]
  245. if key_name not in iterable_dict.keys():
  246. if quant_new_params is not None and key_name in quant_new_params:
  247. continue
  248. raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
  249. value_param = next(iterable_dict[key_name], None)
  250. if value_param is not None:
  251. param.set_data(value_param[1].data)
  252. print(f'init model param {name} with checkpoint param {value_param[0]}')