You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

qat.py 34 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Quantization aware training
  17. User can use quantization aware to train a model. MindSpore supports quantization aware training,
  18. which models quantization errors in both the forward and backward passes using fake-quantization
  19. operations. Note that the entire computation is carried out in floating point. At the end of quantization
  20. aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
  21. """
  22. import re
  23. import mindspore.context as context
  24. import numpy as np
  25. from ... import nn, ops
  26. from ..._checkparam import Validator, Rel
  27. from ...nn.layer import quant
  28. from ...ops import functional as F
  29. from ..common import QuantDtype
  30. from .quantizer import Quantizer, OptimizeOption
  31. from .quant_utils import compute_KL_threshold
  32. __all__ = ["QuantizationAwareTraining", "create_quant_config"]
  33. def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver),
  34. quant_delay=(0, 0),
  35. quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
  36. per_channel=(False, False),
  37. symmetric=(False, False),
  38. narrow_range=(False, False),
  39. mode="DEFAULT"):
  40. r"""
  41. Config the observer type of weights and data flow with quant params.
  42. Args:
  43. quant_observer (Union[Observer, list, tuple]): The observer type to do quantization. The first element
  44. represents weights and second element represents data flow.
  45. Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver)
  46. quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized during
  47. eval. The first element represents weights and second element represents data flow. Default: (0, 0)
  48. quant_dtype (Union[QuantDtype, list, tuple]): Datatype to use for quantize weights and activations. The first
  49. element represents weights and second element represents data flow.
  50. Default: (QuantDtype.INT8, QuantDtype.INT8)
  51. per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True`
  52. then base on per channel otherwise base on per layer. The first element represents weights
  53. and second element represents data flow, and second element must be `False` now. Default: (False, False)
  54. symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True` then
  55. base on symmetric otherwise base on asymmetric. The first element represents weights and second
  56. element represents data flow. Default: (False, False)
  57. narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not.
  58. The first element represents weights and the second element represents data flow. Default: (False, False)
  59. mode (String): Optional quantization mode, currently only `DEFAULT`(QAT) and `LEARNED_SCALE` are supported.
  60. Default: ("DEFAULT")
  61. Returns:
  62. QuantConfig, Contains the observer type of weight and activation.
  63. """
  64. if per_channel[-1]:
  65. raise ValueError("Arg 'per_channel' second element must be 'False'.")
  66. weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
  67. per_channel=per_channel[0], symmetric=symmetric[0],
  68. narrow_range=narrow_range[0], mode=mode)
  69. act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
  70. per_channel=per_channel[-1], symmetric=symmetric[-1],
  71. narrow_range=narrow_range[-1], mode=mode)
  72. return quant.QuantConfig(weight=weight_observer, activation=act_observer)
  73. class _AddFakeQuantInput(nn.Cell):
  74. """
  75. Add FakeQuant OP at input of the network. Only support one input case.
  76. """
  77. def __init__(self, network, quant_delay=0):
  78. super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
  79. self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
  80. quant_delay=quant_delay, ema=True)
  81. self.fake_quant_input.update_parameters_name('fake_quant_input.')
  82. self.network = network
  83. def construct(self, data):
  84. data = self.fake_quant_input(data)
  85. output = self.network(data)
  86. return output
  87. class _AddFakeQuantAfterSubCell(nn.Cell):
  88. """
  89. Add FakeQuant OP after of the sub Cell.
  90. """
  91. def __init__(self, subcell, **kwargs):
  92. super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
  93. self.subcell = subcell
  94. self.mode = "DEFAULT"
  95. self.max_init = 6
  96. self.min_init = -6
  97. if OptimizeOption.LEARNED_SCALE in kwargs["optimize_option"]:
  98. self.mode = "LEARNED_SCALE"
  99. self.max_init = 16
  100. self.min_init = -16
  101. self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=self.min_init,
  102. max_init=self.max_init,
  103. ema=True,
  104. quant_dtype=kwargs["quant_dtype"],
  105. quant_delay=kwargs["quant_delay"],
  106. per_channel=kwargs["per_channel"],
  107. symmetric=kwargs["symmetric"],
  108. narrow_range=kwargs["narrow_range"],
  109. mode=self.mode)
  110. def construct(self, *data):
  111. output = self.subcell(*data)
  112. output = self.fake_quant_act(output)
  113. return output
class QuantizationAwareTraining(Quantizer):
    r"""
    Quantizer for quantization aware training.

    Args:
        bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True.
        freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance.
            Default: 1e7.
        quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized
            during eval. The first element represents weights and second element represents data flow.
            Default: (0, 0)
        quant_dtype (Union[QuantDtype, list, tuple]): Datatype to use for quantize weights and activations. The
            first element represents weights and second element represents data flow. It is necessary to consider
            the precision support of hardware devices in the practical quantization infer scenario.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represents weights
            and second element represents data flow, and second element must be `False` now.
            Default: (False, False)
        symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric or not. If `True`
            then base on symmetric otherwise base on asymmetric. The first element represents weights and second
            element represents data flow. Default: (False, False)
        narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not.
            The first element represents weights and the second element represents data flow.
            Default: (False, False)
        optimize_option (Union[OptimizeOption, list, tuple]): Specifies the quant algorithm and options, currently
            only support QAT and LEARNED_SCALE (Note that, if both QAT and LEARNED_SCALE are configured,
            LEARNED_SCALE has a higher priority. LEARNED_SCALE currently only work under some constraints, which
            includes: freeze_bn=0, quant_delay=0, symmetric=True, narrow_range=True, More specifically, for
            operators such as ReLu and ReLu6, which only have positive values, we add a negative truncation to
            optimize this scenario, and narrow_range will automatically match to False).
            Default: OptimizeOption.QAT
        one_conv_fold (bool): Flag to used one conv bn fold ops for simulation inference operation. Default: True.

    Examples:
        >>> class LeNet5(nn.Cell):
        ...     def __init__(self, num_class=10, channel=1):
        ...         super(LeNet5, self).__init__()
        ...         self.type = "fusion"
        ...         self.num_class = num_class
        ...
        ...         # change `nn.Conv2d` to `nn.Conv2dBnAct`
        ...         self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        ...         self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        ...         # change `nn.Dense` to `nn.DenseBnAct`
        ...         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        ...         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        ...         self.fc3 = nn.DenseBnAct(84, self.num_class)
        ...
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...
        ...     def construct(self, x):
        ...         x = self.conv1(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.conv2(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.flatten(x)
        ...         x = self.fc1(x)
        ...         x = self.fc2(x)
        ...         x = self.fc3(x)
        ...         return x
        ...
        >>> net = LeNet5()
        >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
        >>> net_qat = quantizer.quantize(net)
    """

    # Names of the primitive ops that get a FakeQuant cell inserted after them
    # during `_convert_subcells2quant`.
    __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]

    def __init__(self,
                 bn_fold=True,
                 freeze_bn=10000000,
                 quant_delay=(0, 0),
                 quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                 per_channel=(False, False),
                 symmetric=(False, False),
                 narrow_range=(False, False),
                 optimize_option=OptimizeOption.QAT,
                 one_conv_fold=True):
        """Init for QuantizationAwareTraining quantizer"""
        super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)

        def convert2list(name, value):
            # Normalize a scalar argument to a one-element list; pairs pass
            # through unchanged. More than two elements is rejected.
            if not isinstance(value, list) and not isinstance(value, tuple):
                value = [value]
            elif len(value) > 2:
                raise ValueError("input `{}` len should less then 2".format(name))
            return value

        quant_delay = convert2list("quant delay", quant_delay)
        quant_dtype = convert2list("quant dtype", quant_dtype)
        per_channel = convert2list("per channel", per_channel)
        symmetric = convert2list("symmetric", symmetric)
        narrow_range = convert2list("narrow range", narrow_range)

        # Index 0 of each pair configures weights; index -1 configures the
        # data flow (activations). A scalar argument configures both.
        self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
        self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
        self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
        self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
        self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
        self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
        self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
        self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
        self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
        self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
        self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
        self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
        self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold")

        # Dispatch table: fused cell type -> conversion method.
        self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv,
                                    nn.DenseBnAct: self._convert_dense}

        self.mode = "DEFAULT"
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            self.mode = "LEARNED_SCALE"
            # LEARNED_SCALE only works under these constraints (see class docstring).
            if not self.weight_symmetric or not self.act_symmetric:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only support "
                                 "symmetric=(True, True) for quant")
            if not self.weight_range or not self.act_range:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only support narrow_range=(True, True) "
                                 "for quant")
            if self.freeze_bn != 0:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only support freeze_bn equal to 0, "
                                 "but get freeze_bn={}".format(self.freeze_bn))
            if self.weight_qdelay != 0 or self.act_qdelay != 0:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only support quant_delay=(0, 0)")

        self.quant_config = create_quant_config(quant_delay=quant_delay,
                                                quant_dtype=quant_dtype,
                                                per_channel=per_channel,
                                                symmetric=symmetric,
                                                narrow_range=narrow_range,
                                                mode=self.mode)
        # Numerical stabilizer added to moving_variance before sqrt when
        # folding BatchNorm scale into weights.
        self.eps = 1e-5

    def _convert_op_name(self, name):
        """Convert a CamelCase primitive name to snake_case (e.g. 'RealDiv' -> 'real_div')."""
        pattern = re.compile(r'([A-Z]{1})')
        name_new = re.sub(pattern, r'_\1', name).lower()
        if name_new[0] == '_':
            name_new = name_new[1:]
        return name_new

    def quantize(self, network):
        """
        Quant API to convert input network to a quantization aware training network

        Args:
            network (Cell): network to be quantized.

        Examples:
            >>> net = Net()
            >>> quantizer = QuantizationAwareTraining()
            >>> net_qat = quantizer.quantize(net)
        """
        support_device = ["Ascend", "GPU"]
        if context.get_context('device_target') not in support_device:
            raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

        if OptimizeOption.QAT in self.optimize_option or OptimizeOption.LEARNED_SCALE in self.optimize_option:
            network.update_cell_prefix()
            network = self._convert_subcells2quant(network)
            network.update_cell_type("quant")
        return network

    def _convert_subcells2quant(self, network):
        """
        convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                # Replace the fused cell with its quantized counterpart,
                # keeping the original parameter prefix.
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                # Recurse into nested containers.
                self._convert_subcells2quant(subcell)
        # SequentialCell caches its child list; refresh it after replacement.
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

        # add FakeQuant OP after OP in white list, but not including those wrapped in the below quantization cell.
        if isinstance(network, (nn.FakeQuantWithMinMaxObserver,
                                nn.Conv2dBnFoldQuantOneConv,
                                nn.Conv2dBnFoldQuant,
                                nn.Conv2dBnWithoutFoldQuant,
                                nn.Conv2dQuant,
                                nn.DenseQuant,
                                nn.ActQuant,
                                nn.TensorAddQuant,
                                nn.MulQuant)):
            return network

        # Collect whitelisted primitive attributes (Add/Sub/Mul/RealDiv) held
        # directly on this cell; private (underscore) attributes are skipped.
        add_list = []
        for name in network.__dict__:
            if name[0] == '_':
                continue
            attr = network.__dict__[name]
            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            prefix = name
            add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                                  quant_dtype=self.act_dtype,
                                                  quant_delay=self.act_qdelay,
                                                  per_channel=self.act_channel,
                                                  symmetric=self.act_symmetric,
                                                  narrow_range=self.act_range,
                                                  optimize_option=self.optimize_option)
            prefix = self._convert_op_name(prim_op.name)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
            add_quant.update_parameters_name(prefix + '.')
            # Remove the raw primitive attribute and re-register the wrapped
            # cell under the same name so `construct` keeps working.
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _convert_conv(self, subcell):
        """
        convert Conv2d cell to quant cell
        """
        min_init = -6
        max_init = 6
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            # For LEARNED_SCALE, derive the weight observer's initial range
            # from the (BN-folded) weights via a KL-threshold search.
            subcell_weight_para = subcell.conv.weight.data.asnumpy()
            if subcell.has_bn:
                # Fold the BatchNorm scale (gamma / sqrt(var + eps)) into the
                # per-output-channel weights before computing the threshold.
                scale_factor = (subcell.batchnorm.gamma.data.asnumpy() /
                                np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps))
                subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
            min_init, max_init = self._KL_init(subcell_weight_para, self.weight_dtype)
        self.quant_config = self.quant_config._replace(
            weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init))

        conv_inner = subcell.conv
        if subcell.has_bn:
            bn_inner = subcell.batchnorm
            if self.bn_fold:
                if self.one_conv_fold:
                    conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels,
                                                                conv_inner.out_channels,
                                                                kernel_size=conv_inner.kernel_size,
                                                                stride=conv_inner.stride,
                                                                pad_mode=conv_inner.pad_mode,
                                                                padding=conv_inner.padding,
                                                                dilation=conv_inner.dilation,
                                                                group=conv_inner.group,
                                                                eps=bn_inner.eps,
                                                                momentum=1 - bn_inner.momentum,
                                                                has_bias=conv_inner.has_bias,
                                                                bias_init=conv_inner.bias_init,
                                                                quant_config=self.quant_config,
                                                                quant_dtype=self.weight_dtype,
                                                                fake=True)
                else:
                    conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
                                                         conv_inner.out_channels,
                                                         kernel_size=conv_inner.kernel_size,
                                                         stride=conv_inner.stride,
                                                         pad_mode=conv_inner.pad_mode,
                                                         padding=conv_inner.padding,
                                                         dilation=conv_inner.dilation,
                                                         group=conv_inner.group,
                                                         eps=bn_inner.eps,
                                                         momentum=1 - bn_inner.momentum,
                                                         has_bias=conv_inner.has_bias,
                                                         bias_init=conv_inner.bias_init,
                                                         freeze_bn=self.freeze_bn,
                                                         quant_config=self.quant_config,
                                                         quant_dtype=self.weight_dtype,
                                                         fake=True)
                # change original network Batch Normalization OP parameters to quant network
                conv_inner.gamma = subcell.batchnorm.gamma
                conv_inner.beta = subcell.batchnorm.beta
                conv_inner.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.moving_variance = subcell.batchnorm.moving_variance
            else:
                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
                                                            conv_inner.out_channels,
                                                            kernel_size=conv_inner.kernel_size,
                                                            stride=conv_inner.stride,
                                                            pad_mode=conv_inner.pad_mode,
                                                            padding=conv_inner.padding,
                                                            dilation=conv_inner.dilation,
                                                            group=conv_inner.group,
                                                            eps=bn_inner.eps,
                                                            momentum=1 - bn_inner.momentum,
                                                            has_bias=conv_inner.has_bias,
                                                            bias_init=conv_inner.bias_init,
                                                            quant_config=self.quant_config,
                                                            quant_dtype=self.weight_dtype)
                # change original network Batch Normalization OP parameters to quant network
                conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
                conv_inner.batchnorm.beta = subcell.batchnorm.beta
                conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
            # The BN parameters now live inside the quant conv cell; drop the
            # stand-alone batchnorm so it is not applied twice.
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels,
                                           kernel_size=conv_inner.kernel_size, stride=conv_inner.stride,
                                           pad_mode=conv_inner.pad_mode, padding=conv_inner.padding,
                                           dilation=conv_inner.dilation, group=conv_inner.group,
                                           has_bias=conv_inner.has_bias, quant_config=self.quant_config,
                                           quant_dtype=self.weight_dtype)
        # change original network Conv2D OP parameters to quant network
        conv_inner.weight = subcell.conv.weight
        if subcell.conv.has_bias:
            conv_inner.bias = subcell.conv.bias
        subcell.conv = conv_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            # No activation to quantize: insert an identity wrapped in a
            # FakeQuant so the cell's output is still fake-quantized.
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay, per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric, narrow_range=self.act_range,
                                                           optimize_option=self.optimize_option)
        return subcell

    def _convert_dense(self, subcell):
        """
        convert dense cell to quant cell
        """
        min_init = -6
        max_init = 6
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            # Same KL-based range initialization as `_convert_conv`.
            subcell_weight_para = subcell.dense.weight.data.asnumpy()
            if subcell.has_bn:
                scale_factor = (subcell.batchnorm.gamma.data.asnumpy() /
                                np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps))
                # NOTE(review): reshape(-1, 1, 1, 1) looks conv-shaped; for a 2-D
                # dense weight this would broadcast oddly — confirm the BN-on-dense path.
                subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
            min_init, max_init = self._KL_init(subcell_weight_para, self.weight_dtype)
        self.quant_config = self.quant_config._replace(
            weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init))

        dense_inner = subcell.dense
        dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                       dense_inner.out_channels,
                                       has_bias=dense_inner.has_bias,
                                       quant_config=self.quant_config,
                                       quant_dtype=self.weight_dtype)
        # change original network Dense OP parameters to quant network
        dense_inner.weight = subcell.dense.weight
        if subcell.dense.has_bias:
            dense_inner.bias = subcell.dense.bias
        subcell.dense = dense_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                           quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay,
                                                           per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric,
                                                           narrow_range=self.act_range,
                                                           optimize_option=self.optimize_option)
        return subcell

    def _convert_activation(self, activation):
        """
        convert activation cell to quant cell
        """
        act_class = activation.__class__
        act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
        # These activations also need a FakeQuant *before* them (fake_before=True).
        act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
        if act_class in act_list:
            return quant.ActQuant(activation=activation,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        if act_class in act_list_with_fake_before:
            return quant.ActQuant(activation=activation,
                                  ema=True,
                                  fake_before=True,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        raise ValueError("Unsupported activation in auto quant: ", act_class)

    def _KL_init(self, subcell_weight_para, weight_dtype):
        """
        Calculate the value of max_init and min_init with compute_KL_threshold.
        """
        if self.weight_channel:
            # Per-channel: one symmetric (min, max) pair per output channel.
            max_init = [compute_KL_threshold(weight_para_each, weight_dtype)
                        for weight_para_each in subcell_weight_para]
            min_init = [-x for x in max_init]
        else:
            # Per-layer: a single symmetric pair for the whole tensor.
            max_init = [compute_KL_threshold(subcell_weight_para, weight_dtype)]
            min_init = [-x for x in max_init]
        return min_init, max_init

    def set_mixed_bits(self, network, strategy):
        r"""
        Set network's quantization strategy, this function is currently only valid for `LEARNED_SCALE`
        optimize_option.

        Inputs:
            network (Cell): input network
            strategy (List): the quantization strategy for layers that need to be quantified (eg. [[8], [8],
                ..., [6], [4], [8]]), currently only the quant_dtype for weights of the dense layer and the
                convolution layer is supported.

        Outputs:
            network (Cell)
        """
        if OptimizeOption.LEARNED_SCALE not in self.optimize_option:
            raise ValueError("The `set_mixed_bits` function is currently only valid for `LEARNED_SCALE` "
                             "optimize_option.")

        # First pass: record the indices of quantizable (Conv2dBnAct /
        # DenseBnAct) cells in traversal order.
        self.quantizable_idx = []
        pass_cell = None
        # NOTE(review): `pass_cell` is always None here, so the `is not
        # pass_cell` guard never filters anything — looks like a vestigial hook.
        for i, cell_and_name in enumerate(network.cells_and_names()):
            cell = cell_and_name[1]
            if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)) and cell is not pass_cell:
                self.quantizable_idx.append(i)

        # NOTE(review): `assert` is stripped under `python -O`; a ValueError
        # would be a more robust way to validate `strategy` length.
        assert len(self.quantizable_idx) == len(strategy)
        quantizable_layer_bit_dict = {idx: bit for idx, bit in zip(self.quantizable_idx, strategy)}
        # Map a bit-width (int) back to its QuantDtype enum member.
        type_map = {
            QuantDtype.INT2.num_bits: QuantDtype.INT2,
            QuantDtype.INT3.num_bits: QuantDtype.INT3,
            QuantDtype.INT4.num_bits: QuantDtype.INT4,
            QuantDtype.INT5.num_bits: QuantDtype.INT5,
            QuantDtype.INT6.num_bits: QuantDtype.INT6,
            QuantDtype.INT7.num_bits: QuantDtype.INT7,
            QuantDtype.INT8.num_bits: QuantDtype.INT8
        }
        # Second pass: apply the per-layer bit-width and re-initialize each
        # weight fake-quant observer's range via the KL threshold.
        for i, cell_and_name in enumerate(network.cells_and_names()):
            cell = cell_and_name[1]
            if i not in self.quantizable_idx:
                continue
            else:
                if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                    cell.weight_dtype = type_map[quantizable_layer_bit_dict[i][0]]
                    if isinstance(cell, nn.Conv2dBnAct):
                        subcell_weight_para = cell.conv.weight.data.asnumpy()
                        # After conversion the folded conv carries the BN
                        # parameters itself; fold its scale into the weights.
                        if hasattr(cell.conv, 'gamma'):
                            scale_factor = (cell.conv.gamma.data.asnumpy() /
                                            np.sqrt(cell.conv.moving_variance.data.asnumpy() + self.eps))
                            subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
                        min_init, max_init = self._KL_init(subcell_weight_para, cell.weight_dtype)
                        cell.conv.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                          min_init=min_init,
                                                          max_init=max_init)
                    elif isinstance(cell, nn.DenseBnAct):
                        subcell_weight_para = cell.dense.weight.data.asnumpy()
                        if hasattr(cell.dense, 'gamma'):
                            scale_factor = (cell.dense.gamma.data.asnumpy() /
                                            np.sqrt(cell.dense.moving_variance.data.asnumpy() + self.eps))
                            subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
                        min_init, max_init = self._KL_init(subcell_weight_para, cell.weight_dtype)
                        cell.dense.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                           min_init=min_init,
                                                           max_init=max_init)
        return network