You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

qat.py 25 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Quantization aware training
  17. User can use quantization aware to train a model. MindSpore supports quantization aware training,
  18. which models quantization errors in both the forward and backward passes using fake-quantization
  19. operations. Note that the entire computation is carried out in floating point. At the end of quantization
  20. aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
  21. """
  22. import re
  23. import mindspore.context as context
  24. from ... import nn, ops
  25. from ..._checkparam import Validator, Rel
  26. from ...nn.layer import quant
  27. from ...ops import functional as F
  28. from ..common import QuantDtype
  29. from .quantizer import Quantizer, OptimizeOption
  30. __all__ = ["QuantizationAwareTraining", "create_quant_config"]
  31. def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver),
  32. quant_delay=(0, 0),
  33. quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
  34. per_channel=(False, False),
  35. symmetric=(False, False),
  36. narrow_range=(False, False)):
  37. r"""
  38. Configs the observer type of weights and data flow with quant params.
  39. Args:
  40. quant_observer (Observer, list or tuple): The observer type to do quantization. The first element represent
  41. weights and second element represent data flow.
  42. Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver)
  43. quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
  44. eval. The first element represent weights and second element represent data flow. Default: (0, 0)
  45. quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
  46. element represent weights and second element represent data flow.
  47. Default: (QuantDtype.INT8, QuantDtype.INT8)
  48. per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
  49. then base on per channel otherwise base on per layer. The first element represent weights
  50. and second element represent data flow. Default: (False, False)
  51. symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
  52. symmetric otherwise base on asymmetric. The first element represent weights and second
  53. element represent data flow. Default: (False, False)
  54. narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
  55. The first element represents weights and the second element represents data flow. Default: (False, False)
  56. Returns:
  57. QuantConfig, Contains the observer type of weight and activation.
  58. """
  59. weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
  60. per_channel=per_channel[0], symmetric=symmetric[0],
  61. narrow_range=narrow_range[0])
  62. act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
  63. per_channel=per_channel[-1], symmetric=symmetric[-1],
  64. narrow_range=narrow_range[-1])
  65. return quant.QuantConfig(weight=weight_observer, activation=act_observer)
  66. class _AddFakeQuantInput(nn.Cell):
  67. """
  68. Add FakeQuant OP at input of the network. Only support one input case.
  69. """
  70. def __init__(self, network, quant_delay=0):
  71. super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
  72. self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
  73. quant_delay=quant_delay, ema=True)
  74. self.fake_quant_input.update_parameters_name('fake_quant_input.')
  75. self.network = network
  76. def construct(self, data):
  77. data = self.fake_quant_input(data)
  78. output = self.network(data)
  79. return output
  80. class _AddFakeQuantAfterSubCell(nn.Cell):
  81. """
  82. Add FakeQuant OP after of the sub Cell.
  83. """
  84. def __init__(self, subcell, **kwargs):
  85. super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
  86. self.subcell = subcell
  87. self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6,
  88. max_init=6,
  89. ema=True,
  90. quant_dtype=kwargs["quant_dtype"],
  91. quant_delay=kwargs["quant_delay"],
  92. per_channel=kwargs["per_channel"],
  93. symmetric=kwargs["symmetric"],
  94. narrow_range=kwargs["narrow_range"])
  95. def construct(self, *data):
  96. output = self.subcell(*data)
  97. output = self.fake_quant_act(output)
  98. return output
class QuantizationAwareTraining(Quantizer):
    r"""
    Quantizer for quantization aware training.

    Args:
        bn_fold (bool): Flag to use bn fold ops for simulation inference operation. Default: True.
        freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance.
            Default: 1e7.
        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
            eval. The first element represents weights and the second element represents data flow. Default: (0, 0)
        quant_dtype (QuantDtype, list or tuple): Datatype to use for quantized weights and activations. The first
            element represents weights and the second element represents data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represents weights
            and the second element represents data flow. Default: (False, False)
        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
            symmetric otherwise base on asymmetric. The first element represents weights and the second
            element represents data flow. Default: (False, False)
        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
            The first element represents weights and the second element represents data flow. Default: (False, False)
        optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only
            support QAT. Default: OptimizeOption.QAT
        one_conv_fold (bool): Flag to use one conv bn fold ops for simulation inference operation. Default: True.

    Examples:
        >>> class LeNet5(nn.Cell):
        ...     def __init__(self, num_class=10, channel=1):
        ...         super(LeNet5, self).__init__()
        ...         self.type = "fusion"
        ...         self.num_class = num_class
        ...
        ...         # change `nn.Conv2d` to `nn.Conv2dBnAct`
        ...         self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        ...         self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        ...         # change `nn.Dense` to `nn.DenseBnAct`
        ...         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        ...         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        ...         self.fc3 = nn.DenseBnAct(84, self.num_class)
        ...
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...
        ...     def construct(self, x):
        ...         x = self.conv1(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.conv2(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.flatten(x)
        ...         x = self.fc1(x)
        ...         x = self.fc2(x)
        ...         x = self.fc3(x)
        ...         return x
        ...
        >>> net = LeNet5()
        >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
        >>> net_qat = quantizer.quantize(net)
    """
    # Primitive op names that get a FakeQuant appended after them (see
    # _convert_subcells2quant).
    __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]

    def __init__(self,
                 bn_fold=True,
                 freeze_bn=10000000,
                 quant_delay=(0, 0),
                 quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                 per_channel=(False, False),
                 symmetric=(False, False),
                 narrow_range=(False, False),
                 optimize_option=OptimizeOption.QAT,
                 one_conv_fold=True):
        """Init for QuantizationAwareTraining quantizer"""
        super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)

        def convert2list(name, value):
            # Normalize a scalar option into a one-element list; reject anything
            # longer than the (weight, activation) pair.
            if not isinstance(value, list) and not isinstance(value, tuple):
                value = [value]
            elif len(value) > 2:
                raise ValueError("input `{}` len should less then 2".format(name))
            return value

        quant_delay = convert2list("quant delay", quant_delay)
        quant_dtype = convert2list("quant dtype", quant_dtype)
        per_channel = convert2list("per channel", per_channel)
        symmetric = convert2list("symmetric", symmetric)
        narrow_range = convert2list("narrow range", narrow_range)

        # Index 0 of each pair configures weights; index -1 configures
        # activations (data flow). Validators both check and return the value.
        self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
        self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
        self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
        self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
        self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
        self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
        self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
        self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
        self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
        self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
        self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
        self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
        self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold")
        # Dispatch table: fused cell type -> conversion method.
        self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv,
                                    nn.DenseBnAct: self._convert_dense}
        self.quant_config = create_quant_config(quant_delay=quant_delay,
                                                quant_dtype=quant_dtype,
                                                per_channel=per_channel,
                                                symmetric=symmetric,
                                                narrow_range=narrow_range)

    def _convert_op_name(self, name):
        """Convert a CamelCase op name to snake_case (e.g. 'RealDiv' -> 'real_div')."""
        pattern = re.compile(r'([A-Z]{1})')
        name_new = re.sub(pattern, r'_\1', name).lower()
        # re.sub prefixes every uppercase letter with '_', so a leading capital
        # produces a leading '_' that must be stripped.
        if name_new[0] == '_':
            name_new = name_new[1:]
        return name_new

    def quantize(self, network):
        """
        Quant API to convert input network to a quantization aware training network.

        Args:
            network (Cell): network to be quantized.

        Returns:
            Cell, the converted quantization aware training network.

        Raises:
            KeyError: If the current device target is neither Ascend nor GPU.

        Examples:
            >>> net = Net()
            >>> quantizer = QuantizationAwareTraining()
            >>> net_qat = quantizer.quantize(net)
        """
        support_device = ["Ascend", "GPU"]
        if context.get_context('device_target') not in support_device:
            raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

        if OptimizeOption.QAT in self.optimize_option:
            network.update_cell_prefix()
            network = self._convert_subcells2quant(network)
            network.update_cell_type("quant")
        return network

    def _convert_subcells2quant(self, network):
        """
        Recursively convert sub cells like `Conv2dBnAct` and `DenseBnAct` to quant cells.
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                # Replace the fused cell with its quantized equivalent, keeping
                # the original parameter-name prefix.
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                self._convert_subcells2quant(subcell)
        # SequentialCell caches its children; refresh the cache after swaps.
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

        # add FakeQuant OP after OP in white list (__quant_op_name__)
        add_list = []
        for name in network.__dict__:
            if name[0] == '_':
                continue
            attr = network.__dict__[name]
            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            prefix = name
            add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                                  quant_dtype=self.act_dtype,
                                                  quant_delay=self.act_qdelay,
                                                  per_channel=self.act_channel,
                                                  symmetric=self.act_symmetric,
                                                  narrow_range=self.act_range)
            prefix = self._convert_op_name(prim_op.name)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
            add_quant.update_parameters_name(prefix + '.')
            # Remove the raw Primitive attribute and re-register the wrapped op
            # as a child cell under the same name.
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _convert_conv(self, subcell):
        """
        Convert a Conv2dBnAct cell to its quant counterpart.

        Picks one of four quant conv cells depending on `has_bn`, `bn_fold`
        and `one_conv_fold`, then copies the original parameters across.
        """
        conv_inner = subcell.conv
        if subcell.has_bn:
            if self.bn_fold:
                bn_inner = subcell.batchnorm
                if self.one_conv_fold:
                    conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels,
                                                                conv_inner.out_channels,
                                                                kernel_size=conv_inner.kernel_size,
                                                                stride=conv_inner.stride,
                                                                pad_mode=conv_inner.pad_mode,
                                                                padding=conv_inner.padding,
                                                                dilation=conv_inner.dilation,
                                                                group=conv_inner.group,
                                                                eps=bn_inner.eps,
                                                                momentum=bn_inner.momentum,
                                                                has_bias=conv_inner.has_bias,
                                                                bias_init=conv_inner.bias_init,
                                                                quant_config=self.quant_config,
                                                                quant_dtype=self.weight_dtype,
                                                                fake=True)
                else:
                    conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
                                                         conv_inner.out_channels,
                                                         kernel_size=conv_inner.kernel_size,
                                                         stride=conv_inner.stride,
                                                         pad_mode=conv_inner.pad_mode,
                                                         padding=conv_inner.padding,
                                                         dilation=conv_inner.dilation,
                                                         group=conv_inner.group,
                                                         eps=bn_inner.eps,
                                                         momentum=bn_inner.momentum,
                                                         has_bias=conv_inner.has_bias,
                                                         bias_init=conv_inner.bias_init,
                                                         freeze_bn=self.freeze_bn,
                                                         quant_config=self.quant_config,
                                                         quant_dtype=self.weight_dtype,
                                                         fake=True)
                # change original network BatchNorm OP parameters to quant network
                conv_inner.gamma = subcell.batchnorm.gamma
                conv_inner.beta = subcell.batchnorm.beta
                conv_inner.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.moving_variance = subcell.batchnorm.moving_variance
                # The BN is folded into the conv, so drop it from the subcell.
                del subcell.batchnorm
                subcell.batchnorm = None
                subcell.has_bn = False
            else:
                bn_inner = subcell.batchnorm
                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
                                                            conv_inner.out_channels,
                                                            kernel_size=conv_inner.kernel_size,
                                                            stride=conv_inner.stride,
                                                            pad_mode=conv_inner.pad_mode,
                                                            padding=conv_inner.padding,
                                                            dilation=conv_inner.dilation,
                                                            group=conv_inner.group,
                                                            eps=bn_inner.eps,
                                                            momentum=bn_inner.momentum,
                                                            has_bias=conv_inner.has_bias,
                                                            bias_init=conv_inner.bias_init,
                                                            quant_config=self.quant_config,
                                                            quant_dtype=self.weight_dtype)
                # change original network BatchNorm OP parameters to quant network
                # (unfolded variant keeps an inner batchnorm, so copy into it).
                conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
                conv_inner.batchnorm.beta = subcell.batchnorm.beta
                conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
                del subcell.batchnorm
                subcell.batchnorm = None
                subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
                                           conv_inner.out_channels,
                                           kernel_size=conv_inner.kernel_size,
                                           stride=conv_inner.stride,
                                           pad_mode=conv_inner.pad_mode,
                                           padding=conv_inner.padding,
                                           dilation=conv_inner.dilation,
                                           group=conv_inner.group,
                                           has_bias=conv_inner.has_bias,
                                           quant_config=self.quant_config,
                                           quant_dtype=self.weight_dtype)
        # change original network Conv2D OP parameters to quant network
        conv_inner.weight = subcell.conv.weight
        if subcell.conv.has_bias:
            conv_inner.bias = subcell.conv.bias
        subcell.conv = conv_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            # No activation to quantize: append a FakeQuant on the raw output.
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                           quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay,
                                                           per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric,
                                                           narrow_range=self.act_range)
        return subcell

    def _convert_dense(self, subcell):
        """
        Convert a DenseBnAct cell to its quant counterpart.
        """
        dense_inner = subcell.dense
        dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                       dense_inner.out_channels,
                                       has_bias=dense_inner.has_bias,
                                       quant_config=self.quant_config,
                                       quant_dtype=self.weight_dtype)
        # change original network Dense OP parameters to quant network
        dense_inner.weight = subcell.dense.weight
        if subcell.dense.has_bias:
            dense_inner.bias = subcell.dense.bias
        subcell.dense = dense_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            # No activation to quantize: append a FakeQuant on the raw output.
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                           quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay,
                                                           per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric,
                                                           narrow_range=self.act_range)
        return subcell

    def _convert_activation(self, activation):
        """
        Convert an activation cell to its quant counterpart.

        Raises:
            ValueError: If the activation type is not in either supported list.
        """
        act_class = activation.__class__
        # Activations quantized only after the op.
        act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
        # Activations that also need a FakeQuant before the op.
        act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
        if act_class in act_list:
            return quant.ActQuant(activation=activation,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        if act_class in act_list_with_fake_before:
            return quant.ActQuant(activation=activation,
                                  ema=True,
                                  fake_before=True,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        raise ValueError("Unsupported activation in auto quant: ", act_class)