You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

quant_export.py 12 kB

4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Export for quantization."""
  16. import copy
  17. import numpy as np
  18. from ... import nn, ops
  19. from ..._checkparam import Validator
  20. from ...common import Tensor
  21. from ...common import dtype as mstype
  22. from ...common.api import _executor
  23. from ...nn.layer import quant
  24. from ...ops import operations as P
  25. from ...ops.operations import _inner_ops as inner
  26. from ..quant import quant_utils
  27. from ..quant.qat import QuantizationAwareTraining, _AddFakeQuantInput, _AddFakeQuantAfterSubCell
  28. __all__ = ["ExportToQuantInferNetwork"]
  29. class ExportToQuantInferNetwork:
  30. """
  31. Convert quantization aware network to infer network.
  32. Args:
  33. network (Cell): MindSpore quantization aware training network.
  34. inputs (Tensor): Input tensors of the `quantization aware training network`.
  35. mean (int, float): The mean of input data after preprocessing, used for quantizing the first layer of network.
  36. Default: 127.5.
  37. std_dev (int, float): The variance of input data after preprocessing, used for quantizing the first layer
  38. of network. Default: 127.5.
  39. is_mindir (bool): Whether export MINDIR format. Default: False.
  40. Returns:
  41. Cell, Infer network.
  42. """
  43. __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
  44. def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
  45. network = Validator.check_isinstance('network', network, (nn.Cell,))
  46. self.input_scale = 1 / std_dev
  47. self.input_zero_point = round(mean)
  48. self.data_type = mstype.int8
  49. self.network = copy.deepcopy(network)
  50. self.network_bk = copy.deepcopy(network)
  51. self.all_parameters = {p.name: p for p in self.network.get_parameters()}
  52. self.get_inputs_table(inputs)
  53. self.mean = mean
  54. self.std_dev = std_dev
  55. self.is_mindir = is_mindir
  56. self.upcell = None
  57. self.upname = None
  58. def get_inputs_table(self, inputs):
  59. """Get the input quantization parameters of quantization cell for quant export."""
  60. phase_name = 'export_quant'
  61. graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
  62. self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)
  63. def run(self):
  64. """Start to convert."""
  65. self.network.update_cell_prefix()
  66. network = self.network
  67. if isinstance(network, _AddFakeQuantInput):
  68. network = network.network
  69. network = self._convert_quant2deploy(network)
  70. return network
  71. def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
  72. """convert network's quant subcell to deploy subcell"""
  73. # Calculate the scale and zero point
  74. w_minq_name = cell_core.fake_quant_weight.minq.name
  75. w_maxq_name = cell_core.fake_quant_weight.maxq.name
  76. np_type = mstype.dtype_to_nptype(self.data_type)
  77. param_dict = dict()
  78. param_dict["filter_maxq"] = None
  79. param_dict["filter_minq"] = None
  80. param_dict["output_maxq"] = None
  81. param_dict["output_minq"] = None
  82. param_dict["input_maxq"] = None
  83. param_dict["input_minq"] = None
  84. param_dict["mean"] = self.mean
  85. param_dict["std_dev"] = self.std_dev
  86. param_dict["symmetric"] = cell_core.fake_quant_weight.symmetric
  87. scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
  88. quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
  89. if fake_quant_a_out is not None:
  90. _, _, param_dict["output_maxq"], param_dict["output_minq"] = \
  91. quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
  92. info = self.quant_info_table.get(w_minq_name, None)
  93. if not info:
  94. info = self.quant_info_table.get(w_maxq_name, None)
  95. if info:
  96. _, minq_name = info
  97. if minq_name == 'input':
  98. scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
  99. self.input_scale, self.input_zero_point, 'None', 'None'
  100. else:
  101. fake_quant_a_in_prefix = minq_name[:-5]
  102. cells = self.network_bk.cells_and_names()
  103. for cell in cells:
  104. if cell[0].endswith(fake_quant_a_in_prefix):
  105. fake_quant_a_in = cell[1]
  106. break
  107. scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
  108. quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
  109. else:
  110. # skip quant layer
  111. scale_a_in, zp_a_in = 1.0, 0.0
  112. # Build the `Quant` `Dequant` op.
  113. # Quant only support perlayer version. Need check here.
  114. quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
  115. scale_deq = scale_a_in * scale_w
  116. dequant_op = inner.Dequant()
  117. if isinstance(activation, _AddFakeQuantAfterSubCell):
  118. activation = activation.subcell
  119. elif hasattr(activation, "get_origin"):
  120. activation = activation.get_origin()
  121. # get the `weight` and `bias`
  122. weight = cell_core.weight.data.asnumpy()
  123. bias = None
  124. if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
  125. if cell_core.has_bias:
  126. bias = cell_core.bias.data.asnumpy()
  127. elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv)):
  128. weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
  129. elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
  130. weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)
  131. weight_b = weight
  132. bias_b = bias
  133. # apply the quant
  134. weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, cell_core.fake_quant_weight.num_bits,
  135. cell_core.fake_quant_weight.narrow_range)
  136. if bias is not None:
  137. bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
  138. # fuse parameter
  139. # |--------|47:40|--------|39:32|--------|31:0|
  140. # offset_w [8] shift_N [8] deq_scale [32]
  141. float32_deq_scale = scale_deq.astype(np.float32)
  142. uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32)
  143. scale_length = scale_deq.size # channel
  144. dequant_param = np.zeros(scale_length, dtype=np.uint64)
  145. for index in range(scale_length):
  146. dequant_param[index] += uint32_deq_scale[index]
  147. scale_deq = Tensor(dequant_param, mstype.uint64)
  148. # get op
  149. if isinstance(cell_core, quant.DenseQuant):
  150. op_core = P.MatMul()
  151. weight = np.transpose(weight)
  152. weight_b = np.transpose(weight_b)
  153. else:
  154. op_core = cell_core.conv
  155. weight = Tensor(weight, self.data_type)
  156. weight_b = Tensor(weight_b)
  157. if bias_b is not None:
  158. bias_b = Tensor(bias_b, mstype.float32)
  159. if self.is_mindir:
  160. block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
  161. else:
  162. block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
  163. return block
  164. def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
  165. """add output quant info for quant op for export mindir."""
  166. if self.is_mindir:
  167. np_type = mstype.dtype_to_nptype(self.data_type)
  168. _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type)
  169. origin_op.add_prim_attr('output_maxq', Tensor(maxq))
  170. origin_op.add_prim_attr('output_minq', Tensor(minq))
  171. def _convert_quant2deploy(self, network):
  172. """Convert network's all quant subcell to deploy subcell."""
  173. cells = network.name_cells()
  174. change = False
  175. for name in cells:
  176. subcell = cells[name]
  177. if subcell == network:
  178. continue
  179. if isinstance(subcell, nn.Conv2dBnAct):
  180. network, change = self._convert_subcell(network, change, name, subcell)
  181. elif isinstance(subcell, nn.DenseBnAct):
  182. network, change = self._convert_subcell(network, change, name, subcell, conv=False)
  183. elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv,
  184. quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)):
  185. network, change = self._convert_subcell(network, change, name, subcell, core=False)
  186. elif isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
  187. if self.upcell:
  188. self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act)
  189. activation = subcell.get_origin()
  190. network.insert_child_to_cell(name, activation)
  191. change = True
  192. elif isinstance(subcell, nn.TensorAddQuant):
  193. if isinstance(subcell.add, _AddFakeQuantAfterSubCell):
  194. add_op = subcell.add.subcell
  195. subcell.__delattr__("add")
  196. subcell.__setattr__("add", add_op)
  197. add_op = subcell.add
  198. if add_op:
  199. self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
  200. subcell.__delattr__("fake_quant_act")
  201. subcell.__setattr__("fake_quant_act", P.identity())
  202. elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver):
  203. if self.upcell:
  204. self._add_output_min_max_for_op(self.upcell.core_op, subcell)
  205. network.__delattr__(name)
  206. network.__setattr__(name, P.identity())
  207. elif isinstance(subcell, _AddFakeQuantAfterSubCell):
  208. op = subcell.subcell
  209. if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
  210. self._add_output_min_max_for_op(op, subcell.fake_quant_act)
  211. network.__delattr__(name)
  212. network.__setattr__(name, op)
  213. change = True
  214. else:
  215. self.upcell, self.upname = None, None
  216. self._convert_quant2deploy(subcell)
  217. if isinstance(network, nn.SequentialCell) and change:
  218. network.cell_list = list(network.cells())
  219. return network
  220. def _convert_subcell(self, network, change, name, subcell, core=True, conv=True):
  221. """Convert subcell to ant subcell."""
  222. new_subcell = None
  223. fake_quant_act = None
  224. if core:
  225. cell_core = subcell.conv if conv else subcell.dense
  226. activation = subcell.activation
  227. if hasattr(activation, 'fake_quant_act'):
  228. fake_quant_act = activation.fake_quant_act
  229. else:
  230. cell_core = subcell
  231. activation = None
  232. if cell_core is not None and hasattr(cell_core, "fake_quant_weight"):
  233. new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
  234. if new_subcell:
  235. prefix = subcell.param_prefix
  236. new_subcell.update_parameters_name(prefix + '.')
  237. self.upcell = None if core else new_subcell
  238. self.upname = None if core else name
  239. network.insert_child_to_cell(name, new_subcell)
  240. change = True
  241. return network, change