You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

quant_utils.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Quantization utils."""
  16. import numpy as np
  17. from mindspore._checkparam import Validator
  18. from ... import nn
  19. __all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers"]
  20. def cal_quantization_params(input_min,
  21. input_max,
  22. quant_min,
  23. quant_max,
  24. data_type,
  25. symmetric=False):
  26. r"""
  27. Calculate quantization params for scale and zero point.
  28. Args:
  29. input_min (numpy.ndarray): The dimension of channel or 1.
  30. input_max (numpy.ndarray): The dimension of channel or 1.
  31. quant_min (int): The minimum quantization integer.
  32. quant_max (int): The maximum quantization integer.
  33. data_type (numpy type) : Can be numpy int8, numpy uint8.
  34. symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
  35. Returns:
  36. scale (numpy.ndarray): quantization param.
  37. zero point (numpy.ndarray): quantization param.
  38. """
  39. input_max = np.maximum(0.0, input_max)
  40. input_min = np.minimum(0.0, input_min)
  41. if input_min.shape != input_max.shape:
  42. raise ValueError("input min shape should equal to input max.")
  43. if len(input_min.shape) > 1:
  44. raise ValueError("input min and max shape should be one dim.")
  45. if (input_min > input_max).all():
  46. raise ValueError("input_min min should less than input max.")
  47. if (input_max == input_min).all():
  48. return np.ones(input_min.shape), np.zeros(input_min.shape)
  49. # calculate scale
  50. if symmetric:
  51. input_max = np.maximum(-input_min, input_max)
  52. input_min = -input_max
  53. scale = (input_max - input_min) / (quant_max - quant_min)
  54. # calculate zero point
  55. if data_type == np.int8 and symmetric:
  56. zp = np.zeros(input_min.shape)
  57. else:
  58. zp_double = quant_min - input_min / scale
  59. zp = np.floor(zp_double + 0.5)
  60. return scale, zp
  61. def get_quant_min_max(data_type, num_bits=8, narrow_range=False):
  62. """Calculate quantization params for minimum/maximum quantization integer"""
  63. if data_type == np.int8:
  64. quant_min = 0 - 2 ** (num_bits - 1)
  65. quant_max = 2 ** (num_bits - 1) - 1
  66. elif data_type == np.uint8:
  67. quant_min = 0
  68. quant_max = 2 ** num_bits - 1
  69. else:
  70. raise ValueError("Unsupported datatype({})".format(data_type))
  71. if narrow_range:
  72. quant_min = quant_min + 1
  73. return quant_min, quant_max
  74. def weight2int(data, scale, zero_point, quant_min, quant_max):
  75. r"""
  76. Calculate int8/uint8 weight from fp32. the formula is defined as:
  77. .. math::
  78. int8/uint8 = round(float/scale) + offset
  79. Args:
  80. data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
  81. scale (numpy.ndarray): The dimension of channel or 1.
  82. zero_point (numpy.ndarray): The dimension of channel or 1.
  83. quant_min (int): The minimum quantization integer.
  84. quant_max (int): The maximum quantization integer.
  85. Returns:
  86. weight (numpy.ndarray): The dimension of channel or 1.
  87. """
  88. if scale.shape != zero_point.shape:
  89. raise ValueError("`scale` and `zero_point` should have the same shape.")
  90. if scale.shape[0] < 0:
  91. raise ValueError("`scale` and `zero_point` shape should greater than zero.")
  92. if len(scale.shape) >= 1 and scale.shape[0] > 1:
  93. # for perchannel
  94. if scale.shape[0] == data.shape[0]:
  95. # `Conv2d` or `Dense` op weight
  96. shape_list = [-1] + [1] * len(data.shape[1:])
  97. scale = scale.reshape(shape_list)
  98. zero_point = zero_point.reshape(shape_list)
  99. elif scale.shape[0] == data.shape[1]:
  100. # `DepthwiseConv2d` op weight
  101. shape_list = [1, -1] + [1] * len(data.shape[2:])
  102. scale = scale.reshape(shape_list)
  103. zero_point = zero_point.reshape(shape_list)
  104. else:
  105. raise ValueError("Unsupported weight shape({})".format(data.shape))
  106. weight_int = np.round((data / scale) + zero_point)
  107. weight_int[weight_int > quant_max] = quant_max
  108. weight_int[weight_int < quant_min] = quant_min
  109. return weight_int
  110. def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
  111. """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMaxObserver`."""
  112. minq = cell.minq.data.asnumpy()
  113. maxq = cell.maxq.data.asnumpy()
  114. # make sure maxq > 0 and minq <= 0
  115. if cell.mode == 'LEARNED_SCALE':
  116. maxq = np.abs(maxq)
  117. minq = -np.abs(minq)
  118. quant_min, quant_max = get_quant_min_max(data_type, num_bits=cell.num_bits, narrow_range=cell.narrow_range)
  119. symmetric = cell.symmetric and not cell.neg_trunc
  120. scale, zp = cal_quantization_params(
  121. minq, maxq,
  122. quant_min, quant_max, data_type,
  123. symmetric=symmetric)
  124. return scale, zp, maxq, minq
  125. def fold_batchnorm(weight, cell_quant):
  126. r"""
  127. Fold the batchnorm in `Conv2dBnFoldQuant` to weight.
  128. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  129. Args:
  130. weight (numpy.ndarray): Weight of `cell_quant`.
  131. cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.
  132. Returns:
  133. weight (numpy.ndarray): Folded weight.
  134. bias (numpy.ndarray): Folded bias.
  135. """
  136. variance = cell_quant.moving_variance.data.asnumpy()
  137. mean = cell_quant.moving_mean.data.asnumpy()
  138. gamma = cell_quant.gamma.data.asnumpy()
  139. beta = cell_quant.beta.data.asnumpy()
  140. epsilon = cell_quant.eps
  141. sigma = np.sqrt(variance + epsilon)
  142. if gamma.shape[0] == weight.shape[0]:
  143. # `Conv2d` or `Dense` op weight
  144. shape_list = [-1] + [1] * len(weight.shape[1:])
  145. _gamma = gamma.reshape(shape_list)
  146. _sigma = sigma.reshape(shape_list)
  147. elif gamma.shape[0] == weight.shape[1]:
  148. # `DepthwiseConv2d` op weight
  149. shape_list = [1, -1] + [1] * len(weight.shape[2:])
  150. _gamma = gamma.reshape(shape_list)
  151. _sigma = sigma.reshape(shape_list)
  152. else:
  153. raise ValueError("Unsupported weight shape({})".format(weight.shape))
  154. weight = weight * _gamma / _sigma
  155. bias = beta - gamma * mean / sigma
  156. return weight, bias
  157. def without_fold_batchnorm(weight, cell_quant):
  158. r"""
  159. Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight.
  160. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
  161. Args:
  162. weight (numpy.ndarray): Weight of `cell_quant`.
  163. cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.
  164. Returns:
  165. weight (numpy.ndarray): whihout folded weight.
  166. bias (numpy.ndarray): without folded bias.
  167. """
  168. variance = cell_quant.batchnorm.moving_variance.data.asnumpy()
  169. mean = cell_quant.batchnorm.moving_mean.data.asnumpy()
  170. gamma = cell_quant.batchnorm.gamma.data.asnumpy()
  171. beta = cell_quant.batchnorm.beta.data.asnumpy()
  172. epsilon = cell_quant.batchnorm.eps
  173. sigma = np.sqrt(variance + epsilon)
  174. if gamma.shape[0] == weight.shape[0]:
  175. # `Conv2d` or `Dense` op weight
  176. shape_list = [-1] + [1] * len(weight.shape[1:])
  177. _gamma = gamma.reshape(shape_list)
  178. _sigma = sigma.reshape(shape_list)
  179. elif gamma.shape[0] == weight.shape[1]:
  180. # `DepthwiseConv2d` op weight
  181. shape_list = [1, -1] + [1] * len(weight.shape[2:])
  182. _gamma = gamma.reshape(shape_list)
  183. _sigma = sigma.reshape(shape_list)
  184. else:
  185. raise ValueError("Unsupported weight shape({})".format(weight.shape))
  186. weight = weight * _gamma / _sigma
  187. bias = beta - gamma * mean / sigma
  188. return weight, bias
  189. def compute_kl_threshold(data, bitwidth):
  190. r"""
  191. Using KL-J Distance to calculate the clip threshold.
  192. Args:
  193. - **data** (NumpyArray) - Data observed to calculate the threshold for quantization,
  194. - **bitwidth** (QuantDtype) - The datatype of quantization.
  195. Outputs:
  196. Tensor with Shape 1. Threshold to calculate the data.
  197. """
  198. data_max = np.abs(data).max()
  199. if data_max < 1e-5:
  200. return 1e-5
  201. hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(0, data_max), density=True)
  202. # For the sake of high efficiency, we limit the maximum number of bins to 1024 in `sqrt` mode, If it exceeds the
  203. # largest size, turn to use the default bins config.
  204. largest_bin_size = 1024
  205. if hist.shape[0] > largest_bin_size:
  206. hist, bin_edges = np.histogram(np.abs(data), range=(0, data_max), density=True)
  207. hist = hist / np.sum(hist)
  208. cumsum = np.cumsum(hist)
  209. bit_pow_range = pow(2, int(bitwidth.num_bits) - 1)
  210. threshold = []
  211. scaling_factor = []
  212. kl = []
  213. if bit_pow_range + 1 > len(bin_edges) - 1:
  214. th_layer_out = bin_edges[-1]
  215. return float(th_layer_out)
  216. for i in range(bit_pow_range + 1, len(bin_edges), 1):
  217. threshold_tmp = (i + 0.5) * (bin_edges[1] - bin_edges[0])
  218. threshold = np.concatenate((threshold, [threshold_tmp]))
  219. scaling_factor_tmp = threshold_tmp / (bit_pow_range - 1)
  220. scaling_factor = np.concatenate((scaling_factor, [scaling_factor_tmp]))
  221. # forward interpolation
  222. cumsum_tmp = np.copy(cumsum)
  223. cumsum_tmp[(i - 1):] = 1
  224. fwd_x = np.linspace(0.0, 1.0, bit_pow_range)
  225. fwd_xp = np.linspace(0.0, 1.0, i)
  226. fwd_fp = cumsum_tmp[:i]
  227. forward_interp = np.interp(fwd_x, fwd_xp, fwd_fp)
  228. # backward interpolation
  229. bwd_x = np.linspace(0.0, 1.0, i)
  230. bwd_xp = np.linspace(0.0, 1.0, bit_pow_range)
  231. bwd_fp = forward_interp
  232. backward_interp = np.interp(bwd_x, bwd_xp, bwd_fp)
  233. cumsum_tmp[:i] = backward_interp
  234. kl_tmp = np.sum((cumsum - cumsum_tmp) * np.log2(cumsum / cumsum_tmp)) # Kullback-Leibler-J
  235. kl = np.concatenate((kl, [kl_tmp]))
  236. th_layer_out = threshold[np.argmin(kl)]
  237. threshold = float(th_layer_out)
  238. if threshold < 1e-5:
  239. threshold = 1e-5
  240. return threshold
  241. def query_quant_layers(network):
  242. r"""
  243. Query the network's quantization strategy of each quantized layer and print it to the screen, note that all the
  244. quantization layers are queried before graph compile optimization in the graph mode, thus, some redundant quantized
  245. layers, which not exist in practical execution, may appear.
  246. Args:
  247. network (Cell): input network
  248. """
  249. network = Validator.check_isinstance("network", network, nn.Cell)
  250. tplt = "{0:60}\t{1:10}"
  251. for cell_and_name in network.cells_and_names():
  252. cell_name = cell_and_name[0]
  253. cell = cell_and_name[1]
  254. if isinstance(cell, nn.FakeQuantWithMinMaxObserver):
  255. print(tplt.format(cell_name, cell.quant_dtype))
  256. def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
  257. r"""
  258. Load fp32 model parameters into quantization model.
  259. Args:
  260. quant_model(Cell): Quantization model.
  261. params_dict(dict): Parameter dict that stores fp32 parameters.
  262. quant_new_params(list): Parameters that exist in quantization network but not in non-quantization
  263. network. Default: None.
  264. Raises:
  265. TypeError: If `quant_new_params` is not None and is not list.
  266. ValueError: If there are parameters in the `quant_model` that are neither in `params_dict`
  267. nor in `quant_new_params`.
  268. """
  269. if quant_new_params is not None and not isinstance(quant_new_params, list):
  270. raise TypeError("quant_new_params must be list or None.")
  271. iterable_dict = {
  272. 'minq': iter(list(filter(lambda item: item[0].endswith('minq'), params_dict.items()))),
  273. 'maxq': iter(list(filter(lambda item: item[0].endswith('maxq'), params_dict.items()))),
  274. 'quant_max': iter(list(filter(lambda item: item[0].endswith('quant_max'), params_dict.items())))
  275. }
  276. for param in params_dict.items():
  277. key_name = param[0].split(".")[-1]
  278. if key_name not in iterable_dict:
  279. iterable_dict[key_name] = iter(list(filter(lambda item, value=key_name: item[0].endswith(value),
  280. params_dict.items())))
  281. for name, param in quant_model.parameters_and_names():
  282. key_name = name.split(".")[-1]
  283. if key_name not in iterable_dict.keys():
  284. if key_name not in quant_new_params:
  285. raise ValueError(f"Can't find match parameter in ckpt, param name = {name}")
  286. continue
  287. value_param = next(iterable_dict[key_name], None)
  288. if value_param:
  289. param.set_data(value_param[1].data)
  290. print(f'init model param {name} with checkpoint param {value_param[0]}')
  291. # Perform KL_init when learned scale quantization is executed.
  292. for cell_and_name in quant_model.cells_and_names():
  293. cell = cell_and_name[1]
  294. if isinstance(cell, (nn.Conv2dBnFoldQuantOneConv, nn.Conv2dBnFoldQuant, nn.Conv2dBnWithoutFoldQuant,
  295. nn.Conv2dQuant, nn.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE":
  296. subcell_weight_para = cell.weight.data.asnumpy()
  297. if hasattr(cell, 'gamma'):
  298. scale_factor = (cell.gamma.data.asnumpy() /
  299. np.sqrt(cell.moving_variance.data.asnumpy() + 1e-5))
  300. subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
  301. if cell.fake_quant_weight.per_channel:
  302. max_init = [compute_kl_threshold(weight_para_each, cell.fake_quant_weight.quant_dtype)
  303. for weight_para_each in subcell_weight_para]
  304. min_init = [-x for x in max_init]
  305. else:
  306. max_init = [compute_kl_threshold(subcell_weight_para, cell.fake_quant_weight.quant_dtype)]
  307. min_init = [-x for x in max_init]
  308. cell.fake_quant_weight.reset(quant_dtype=cell.fake_quant_weight.quant_dtype,
  309. min_init=min_init, max_init=max_init)