You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

quant_export.py 23 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Export for quantization."""
  16. import copy
  17. import numpy as np
  18. import mindspore.context as context
  19. from ... import log as logger
  20. from ... import nn, ops
  21. from ..._checkparam import Validator
  22. from ...common import Tensor
  23. from ...common import dtype as mstype
  24. from ...common.api import _executor
  25. from ...nn.layer import quant
  26. from ...ops import operations as P
  27. from ...ops.operations import _inner_ops as inner
  28. from ...train import serialization
  29. from ..quant import quant_utils
  30. from ..quant.qat import QuantizationAwareTraining, _AddFakeQuantInput, _AddFakeQuantAfterSubCell
  31. __all__ = ["export", "manual_export"]
class ExportToQuantInferNetwork:
    """
    Convert quantization aware network to infer network.

    Args:
        network (Cell): MindSpore network API `convert_quant_network`.
        inputs (Tensor): Input tensors of the `quantization aware training network`.
        mean (int, float): Input data mean. Default: 127.5.
        std_dev (int, float): Input data variance. Default: 127.5.
        is_mindir (bool): Whether is MINDIR format. Default: False.

    Returns:
        Cell, Infer network.
    """
    # Primitive names whose fake-quant wrappers are unwrapped during conversion.
    # NOTE(review): conversion code below reads QuantizationAwareTraining.__quant_op_name__,
    # not this attribute — confirm whether this copy is still needed.
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
        network = Validator.check_isinstance('network', network, (nn.Cell,))
        # Input quantization parameters derived from the dataset normalization:
        # scale = 1 / std_dev, zero point = round(mean).
        self.input_scale = 1 / std_dev
        self.input_zero_point = round(mean)
        self.data_type = mstype.int8
        # Deep-copy so the caller's training network is left untouched.
        self.network = copy.deepcopy(network)
        self.all_parameters = {p.name: p for p in self.network.get_parameters()}
        self.get_inputs_table(inputs)
        self.mean = mean
        self.std_dev = std_dev
        self.is_mindir = is_mindir

    def get_inputs_table(self, inputs):
        """Get the support info for quant export."""
        # Compile (without conversion) so the executor can report, per fake-quant
        # `minq` parameter, which op feeds it — used later to locate input quant info.
        phase_name = 'export_quant'
        graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
        self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)

    def run(self):
        """Start to convert."""
        self.network.update_cell_prefix()
        network = self.network
        # Strip the input fake-quant wrapper; inference inputs are quantized explicitly.
        if isinstance(network, _AddFakeQuantInput):
            network = network.network
        network = self._convert_quant2deploy(network)
        return network

    def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
        """Convert one quant subcell of the network into a deploy subcell.

        Returns a `QuantMindirBlock`/`QuantBlock`, or None when the input
        fake-quant info for this cell cannot be found.
        """
        # Calculate the scale and zero point
        w_minq_name = cell_core.fake_quant_weight.minq.name
        np_type = mstype.dtype_to_nptype(self.data_type)
        param_dict = dict()
        param_dict["filter_maxq"] = None
        param_dict["filter_minq"] = None
        param_dict["output_maxq"] = None
        param_dict["output_minq"] = None
        param_dict["input_maxq"] = None
        param_dict["input_minq"] = None
        param_dict["mean"] = self.mean
        param_dict["std_dev"] = self.std_dev
        param_dict["symmetric"] = cell_core.fake_quant_weight.symmetric
        # Weight quantization parameters come straight from the weight fake-quant cell.
        scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
            quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
        if fake_quant_a_out is not None:
            _, _, param_dict["output_maxq"], param_dict["output_minq"] = \
                quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
        # Look up which fake-quant feeds this cell's input (keyed by weight minq name).
        info = self.quant_info_table.get(w_minq_name, None)
        if info:
            fake_quant_a_in_op, minq_name = info
            if minq_name == 'input':
                # Graph input: use the dataset-normalization scale/zero-point.
                scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
                    self.input_scale, self.input_zero_point, 'None', 'None'
            else:
                # `minq_name` ends with "minq"; swap the suffix to find the paired maxq.
                maxq = self.all_parameters[minq_name[:-4] + "maxq"]
                minq = self.all_parameters[minq_name]
                if self.is_mindir:
                    scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
                        quant_utils.scale_zp_max_min_from_data(fake_quant_a_in_op, minq, maxq, np_type)
                else:
                    scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fake_quant_a_in_op, minq, maxq, np_type)
        else:
            logger.warning(f"Can not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
            return None
        # Build the `Quant` `Dequant` op.
        # Quant only support perlayer version. Need check here.
        quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
        scale_deq = scale_a_in * scale_w
        dequant_op = inner.Dequant()
        # Unwrap the activation to its underlying cell/op.
        if isinstance(activation, _AddFakeQuantAfterSubCell):
            activation = activation.subcell
        elif hasattr(activation, "get_origin"):
            activation = activation.get_origin()
        # get the `weight` and `bias`
        weight = cell_core.weight.data.asnumpy()
        bias = None
        if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
            if cell_core.has_bias:
                bias = cell_core.bias.data.asnumpy()
        elif isinstance(cell_core, quant.Conv2dBnFoldQuant):
            # Fold BatchNorm statistics into the conv weight/bias.
            weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
        elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
            weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)
        # Keep float copies for the MINDIR block, which stores unquantized params.
        weight_b = weight
        bias_b = bias
        # apply the quant
        fake_quant_weight_op = cell_core.fake_quant_weight.fake_quant_infer
        weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, fake_quant_weight_op.num_bits,
                                        fake_quant_weight_op.narrow_range)
        if bias is not None:
            # Bias is requantized to int32 in the combined input*weight scale.
            bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
        # fuse parameter
        # |--------|47:40|--------|39:32|--------|31:0|
        # offset_w [8]    shift_N [8]    deq_scale [32]
        # Pack each per-channel float32 dequant scale bit-for-bit into the low 32
        # bits of a uint64 (offset_w/shift_N fields are left zero here).
        float32_deq_scale = scale_deq.astype(np.float32)
        uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32)
        scale_length = scale_deq.size  # channel
        dequant_param = np.zeros(scale_length, dtype=np.uint64)
        for index in range(scale_length):
            dequant_param[index] += uint32_deq_scale[index]
        scale_deq = Tensor(dequant_param, mstype.uint64)
        # get op
        if isinstance(cell_core, quant.DenseQuant):
            op_core = P.MatMul()
            # MatMul expects the weight transposed relative to DenseQuant's layout.
            weight = np.transpose(weight)
            weight_b = np.transpose(weight_b)
        else:
            op_core = cell_core.conv
        weight = Tensor(weight, self.data_type)
        weight_b = Tensor(weight_b)
        if bias_b is not None:
            bias_b = Tensor(bias_b, mstype.float32)
        if self.is_mindir:
            block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
        else:
            block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
        return block

    def _convert_quant2deploy(self, network):
        """Convert network's all quant subcell to deploy subcell (recursively, in place)."""
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            cell_core = None
            fake_quant_act = None
            activation = None
            if isinstance(subcell, quant.Conv2dBnAct):
                cell_core = subcell.conv
                activation = subcell.activation
                fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
            elif isinstance(subcell, quant.DenseBnAct):
                cell_core = subcell.dense
                activation = subcell.activation
                fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
            if cell_core is not None:
                # Replace the training subcell with its deploy equivalent.
                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
                if new_subcell:
                    prefix = subcell.param_prefix
                    new_subcell.update_parameters_name(prefix + '.')
                    network.insert_child_to_cell(name, new_subcell)
                    change = True
            elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                # Unwrap fake-quant around supported primitives; for MINDIR keep the
                # output range as primitive attributes.
                op = subcell.subcell
                if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
                    if self.is_mindir:
                        op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
                        op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
                    network.__delattr__(name)
                    network.__setattr__(name, op)
                    change = True
            else:
                self._convert_quant2deploy(subcell)
        # SequentialCell caches its children; refresh after replacement.
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())
        return network
  199. def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='AIR'):
  200. """
  201. Exports MindSpore quantization predict model to deploy with AIR.
  202. Args:
  203. network (Cell): MindSpore network produced by `convert_quant_network`.
  204. inputs (Tensor): Inputs of the `quantization aware training network`.
  205. file_name (str): File name of model to export.
  206. mean (int, float): Input data mean. Default: 127.5.
  207. std_dev (int, float): Input data variance. Default: 127.5.
  208. file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
  209. quantization aware model. Default: 'AIR'.
  210. - AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
  211. Ascend model.
  212. - MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
  213. for MindSpore models.
  214. Recommended suffix for output file is '.mindir'.
  215. """
  216. supported_device = ["Ascend", "GPU"]
  217. supported_formats = ['AIR', 'MINDIR']
  218. mean = Validator.check_type("mean", mean, (int, float))
  219. std_dev = Validator.check_type("std_dev", std_dev, (int, float))
  220. if context.get_context('device_target') not in supported_device:
  221. raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
  222. if file_format not in supported_formats:
  223. raise ValueError('Illegal file format {}.'.format(file_format))
  224. network.set_train(False)
  225. if file_format == "MINDIR":
  226. exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
  227. else:
  228. exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
  229. deploy_net = exporter.run()
  230. serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
class ExportManualQuantNetwork:
    """
    Convert manual quantization aware network to infer network.

    Args:
        network (Cell): MindSpore network API `convert_quant_network`.
        inputs (Tensor): Input tensors of the `quantization aware training network`.
        mean (int, float): Input data mean. Default: 127.5.
        std_dev (int, float): Input data variance. Default: 127.5.
        is_mindir (bool): Whether is MINDIR format. Default: False.

    Returns:
        Cell, Infer network.
    """
    # Primitive names whose fake-quant wrappers are unwrapped during conversion.
    # NOTE(review): conversion code below reads QuantizationAwareTraining.__quant_op_name__,
    # not this attribute — confirm whether this copy is still needed.
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
        network = Validator.check_isinstance('network', network, (nn.Cell,))
        # Input quantization parameters derived from the dataset normalization:
        # scale = 1 / std_dev, zero point = round(mean).
        self.input_scale = 1 / std_dev
        self.input_zero_point = round(mean)
        self.data_type = mstype.int8
        # Deep-copy so the caller's training network is left untouched.
        self.network = copy.deepcopy(network)
        self.all_parameters = {p.name: p for p in self.network.get_parameters()}
        self.get_inputs_table(inputs)
        self.mean = mean
        self.std_dev = std_dev
        self.is_mindir = is_mindir
        # Most recently converted standalone quant cell and its name; a following
        # FakeQuantWithMinMaxObserver attaches its output range to this cell.
        self.upcell = None
        self.upname = None

    def get_inputs_table(self, inputs):
        """Get the support info for quant export."""
        # Compile (without conversion) so the executor can report, per fake-quant
        # `minq` parameter, which op feeds it — used later to locate input quant info.
        phase_name = 'export_quant'
        graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
        self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)

    def run(self):
        """Start to convert."""
        self.network.update_cell_prefix()
        network = self.network
        # Strip the input fake-quant wrapper; inference inputs are quantized explicitly.
        if isinstance(network, _AddFakeQuantInput):
            network = network.network
        network = self._convert_manual_network(network)
        return network

    def _convert_manual_network(self, network):
        """Convert network's all quant subcell to deploy subcell (recursively, in place)."""
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            if isinstance(subcell, quant.Conv2dBnAct):
                network, change = self._convert_subcell(network, change, name, subcell)
            elif isinstance(subcell, quant.DenseBnAct):
                network, change = self._convert_subcell(network, change, name, subcell, conv=False)
            elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
                                      quant.Conv2dQuant, quant.DenseQuant)):
                # Standalone quant cell with no wrapping Conv2dBnAct/DenseBnAct.
                network, change = self._convert_subcell(network, change, name, subcell, core=False)
            elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver) and self.upcell:
                # Observer following a standalone quant cell: record its output
                # range as attributes on the previously converted cell's op.
                np_type = mstype.dtype_to_nptype(self.data_type)
                _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(subcell, np_type)
                self.upcell.core_op.add_prim_attr('output_maxq', Tensor(maxq))
                self.upcell.core_op.add_prim_attr('output_minq', Tensor(minq))
                network.insert_child_to_cell(self.upname, self.upcell)
            elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                # Unwrap fake-quant around supported primitives; for MINDIR keep the
                # output range as primitive attributes.
                op = subcell.subcell
                if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
                    if self.is_mindir:
                        op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
                        op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
                    network.__delattr__(name)
                    network.__setattr__(name, op)
                    change = True
            else:
                # Unrelated cell breaks the quant-cell/observer pairing.
                self.upcell, self.upname = None, None
                self._convert_manual_network(subcell)
        # SequentialCell caches its children; refresh after replacement.
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())
        return network

    def _convert_subcell(self, network, change, name, subcell, core=True, conv=True):
        """Convert one subcell to its deploy counterpart and insert it into `network`.

        core=True means `subcell` is a Conv2dBnAct/DenseBnAct wrapper (conv selects
        which inner cell); core=False means `subcell` itself is the quant cell.
        """
        if core:
            cell_core = subcell.conv if conv else subcell.dense
            activation = subcell.activation
            fake_quant_act = activation.fake_quant_act
        else:
            cell_core = subcell
            activation = None
            fake_quant_act = None
        new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
        if new_subcell:
            prefix = subcell.param_prefix
            new_subcell.update_parameters_name(prefix + '.')
            # Standalone cells are remembered so a trailing observer can attach
            # its output range (see _convert_manual_network).
            self.upcell = None if core else new_subcell
            self.upname = None if core else name
            network.insert_child_to_cell(name, new_subcell)
            change = True
        return network, change

    def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
        """Convert one quant subcell of the network into a deploy subcell.

        Returns a `QuantMindirBlock`/`QuantBlock`; cells whose input fake-quant
        info is missing are treated as unquantized (scale 1, zero point 0).
        """
        w_minq_name = cell_core.fake_quant_weight.minq.name
        np_type = mstype.dtype_to_nptype(self.data_type)
        param_dict = dict()
        param_dict["filter_maxq"] = None
        param_dict["filter_minq"] = None
        param_dict["output_maxq"] = None
        param_dict["output_minq"] = None
        param_dict["input_maxq"] = None
        param_dict["input_minq"] = None
        param_dict["mean"] = self.mean
        param_dict["std_dev"] = self.std_dev
        param_dict["symmetric"] = cell_core.fake_quant_weight.symmetric
        # Weight quantization parameters come straight from the weight fake-quant cell.
        scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
            quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
        if fake_quant_a_out is not None:
            _, _, param_dict["output_maxq"], param_dict["output_minq"] = \
                quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
        # Look up which fake-quant feeds this cell's input (keyed by weight minq name).
        info = self.quant_info_table.get(w_minq_name, None)
        if info:
            fack_quant_a_in_op, minq_name = info
            if minq_name == 'input':
                # Graph input: use the dataset-normalization scale/zero-point.
                scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
                    self.input_scale, self.input_zero_point, 'None', 'None'
            else:
                # `minq_name` ends with "minq"; swap the suffix to find the paired maxq.
                maxq = self.all_parameters[minq_name[:-4] + "maxq"]
                minq = self.all_parameters[minq_name]
                scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
                    quant_utils.scale_zp_max_min_from_data(fack_quant_a_in_op, minq, maxq, np_type)
        else:
            # skip quant layer
            scale_a_in, zp_a_in = 1, 0
        # Build the `Quant` `Dequant` op.
        # Quant only support perlayer version. Need check here.
        quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
        scale_deq = scale_a_in * scale_w
        dequant_op = inner.Dequant()
        # Unwrap the activation to its underlying cell/op.
        if isinstance(activation, _AddFakeQuantAfterSubCell):
            activation = activation.subcell
        elif hasattr(activation, "get_origin"):
            activation = activation.get_origin()
        # get the `weight` and `bias`
        weight = cell_core.weight.data.asnumpy()
        bias = None
        if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
            if cell_core.has_bias:
                bias = cell_core.bias.data.asnumpy()
        elif isinstance(cell_core, quant.Conv2dBnFoldQuant):
            # Fold BatchNorm statistics into the conv weight/bias.
            weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
        elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
            weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)
        # Keep float copies for the MINDIR block, which stores unquantized params.
        weight_b = weight
        bias_b = bias
        # apply the quant
        fake_quant_weight_op = cell_core.fake_quant_weight.fake_quant_infer
        weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, fake_quant_weight_op.num_bits,
                                        fake_quant_weight_op.narrow_range)
        if bias is not None:
            # Bias is requantized to int32 in the combined input*weight scale.
            bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
        # Pack each per-channel float32 dequant scale bit-for-bit into the low 32
        # bits of a uint64 fused dequant parameter.
        float32_deq_scale = scale_deq.astype(np.float32)
        uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32)
        scale_length = scale_deq.size  # channel
        dequant_param = np.zeros(scale_length, dtype=np.uint64)
        for index in range(scale_length):
            dequant_param[index] += uint32_deq_scale[index]
        scale_deq = Tensor(dequant_param, mstype.uint64)
        # get op
        if isinstance(cell_core, quant.DenseQuant):
            op_core = P.MatMul()
            # MatMul expects the weight transposed relative to DenseQuant's layout.
            weight = np.transpose(weight)
            weight_b = np.transpose(weight_b)
        else:
            op_core = cell_core.conv
        weight = Tensor(weight, self.data_type)
        weight_b = Tensor(weight_b)
        if bias_b is not None:
            bias_b = Tensor(bias_b, mstype.float32)
        if self.is_mindir:
            block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
        else:
            block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
        return block
  408. def manual_export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='MINDIR'):
  409. """
  410. Manual exports MindSpore quantization predict model to deploy wiAIR and MINDIR.
  411. Args:
  412. network (Cell): MindSpore network produced by `convert_quant_network`.
  413. inputs (Tensor): Inputs of the `quantization aware training network`.
  414. file_name (str): File name of model to export.
  415. mean (int, float): Input data mean. Default: 127.5.
  416. std_dev (int, float): Input data variance. Default: 127.5.
  417. file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
  418. quantization aware model. Default: 'AIR'.
  419. - AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
  420. Ascend model.
  421. - MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
  422. for MindSpore models.
  423. Recommended suffix for output file is '.mindir'.
  424. """
  425. supported_device = ["Ascend", "GPU"]
  426. supported_formats = ['AIR', 'MINDIR']
  427. mean = Validator.check_type("mean", mean, (int, float))
  428. std_dev = Validator.check_type("std_dev", std_dev, (int, float))
  429. if context.get_context('device_target') not in supported_device:
  430. raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
  431. if file_format not in supported_formats:
  432. raise ValueError('Illegal file format {}.'.format(file_format))
  433. network.set_train(False)
  434. if file_format == "MINDIR":
  435. exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
  436. else:
  437. exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=False)
  438. deploy_net = exporter.run()
  439. serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)