
qat.py 23 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. Quantization aware training
  17. User can use quantization aware to train a model. MindSpore supports quantization aware training,
  18. which models quantization errors in both the forward and backward passes using fake-quantization
  19. operations. Note that the entire computation is carried out in floating point. At the end of quantization
  20. aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
  21. """
import re

import mindspore.context as context
from ... import nn, ops
from ..._checkparam import Validator, Rel
from ...nn.layer import quant
from ...ops import functional as F
from ..common import QuantDtype
from .quantizer import Quantizer, OptimizeOption

__all__ = ["QuantizationAwareTraining"]

_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
                   nn.ReLU6: quant.ActQuant,
                   nn.Sigmoid: quant.ActQuant,
                   nn.LeakyReLU: quant.LeakyReLUQuant,
                   nn.HSigmoid: quant.HSigmoidQuant,
                   nn.HSwish: quant.HSwishQuant}


def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
                     quant_delay=(0, 0),
                     quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                     per_channel=(False, False),
                     symmetric=(False, False),
                     narrow_range=(False, False)):
  44. r"""
  45. Configs the oberser type of weights and data flow with quant params.
  46. Args:
  47. quant_observer (Observer, list or tuple): The oberser type to do quantization. The first element represent
  48. weights and second element represent data flow.
  49. Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
  50. quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
  51. eval. The first element represent weights and second element represent data flow. Default: (0, 0)
  52. quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
  53. element represent weights and second element represent data flow.
  54. Default: (QuantDtype.INT8, QuantDtype.INT8)
  55. per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
  56. then base on per channel otherwise base on per layer. The first element represent weights
  57. and second element represent data flow. Default: (False, False)
  58. symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
  59. symmetric otherwise base on asymmetric. The first element represent weights and second
  60. element represent data flow. Default: (False, False)
  61. narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
  62. The first element represents weights and the second element represents data flow. Default: (False, False)
  63. Returns:
  64. QuantConfig, Contains the oberser type of weight and activation.
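
    Examples:
        >>> # An illustrative sketch (hypothetical values; defaults otherwise):
        >>> # per-channel symmetric quantization for weights, per-layer
        >>> # asymmetric quantization for activations.
        >>> qconfig = get_quant_config(per_channel=(True, False), symmetric=(True, False))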
  65. """
    weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
                                                     per_channel=per_channel[0], symmetric=symmetric[0],
                                                     narrow_range=narrow_range[0])
    # the activation (data flow) observer comes from the second element of each pair
    act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
                                                   per_channel=per_channel[-1], symmetric=symmetric[-1],
                                                   narrow_range=narrow_range[-1])
    return quant.QuantConfig(weight=weight_observer, activation=act_observer)


class _AddFakeQuantInput(nn.Cell):
    """
    Add a FakeQuant OP at the input of the network. Only the single-input case is supported.
    """

    def __init__(self, network, quant_delay=0):
        super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
        self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
                                                                  quant_delay=quant_delay, ema=True)
        self.fake_quant_input.update_parameters_name('fake_quant_input.')
        self.network = network

    def construct(self, data):
        data = self.fake_quant_input(data)
        output = self.network(data)
        return output


class _AddFakeQuantAfterSubCell(nn.Cell):
    """
    Add a FakeQuant OP after the output of the sub Cell.
    """

    def __init__(self, subcell, **kwargs):
        super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
        self.subcell = subcell
        self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6,
                                                                max_init=6,
                                                                ema=True,
                                                                quant_dtype=kwargs["quant_dtype"],
                                                                quant_delay=kwargs["quant_delay"],
                                                                per_channel=kwargs["per_channel"],
                                                                symmetric=kwargs["symmetric"],
                                                                narrow_range=kwargs["narrow_range"])

    def construct(self, *data):
        output = self.subcell(*data)
        output = self.fake_quant_act(output)
        return output
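
# Illustrative sketch (not part of the original file): wrapping a white-listed
# primitive so that its floating-point output is fake-quantized. The keyword
# values below are hypothetical.
#
#     add_fq = _AddFakeQuantAfterSubCell(ops.TensorAdd(),
#                                        quant_dtype=QuantDtype.INT8,
#                                        quant_delay=0, per_channel=False,
#                                        symmetric=False, narrow_range=False)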


class QuantizationAwareTraining(Quantizer):
    r"""
    Quantizer for quantization aware training.

    Args:
        bn_fold (bool): Whether to use bn fold ops for simulating the inference operation. Default: True.
        freeze_bn (int): Number of steps after which the BatchNorm OP parameters use the total mean and
            variance. Default: 1e7.
        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized
            during eval. The first element applies to weights and the second element applies to data flow.
            Default: (0, 0)
        quant_dtype (QuantDtype, list or tuple): Datatype used to quantize weights and activations. The first
            element applies to weights and the second element applies to data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8)
        per_channel (bool, list or tuple): Quantization granularity. If `True`, quantization is per channel;
            otherwise it is per layer. The first element applies to weights and the second element applies
            to data flow. Default: (False, False)
        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric. If `True`, the
            algorithm is symmetric; otherwise it is asymmetric. The first element applies to weights and the
            second element applies to data flow. Default: (False, False)
        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range.
            The first element applies to weights and the second element applies to data flow.
            Default: (False, False)
        optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options; currently
            only QAT is supported. Default: OptimizeOption.QAT

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self, num_class=10, channel=1):
        >>>         super(Net, self).__init__()
        >>>         self.type = "fusion"
        >>>         self.num_class = num_class
        >>>
        >>>         # change `nn.Conv2d` to `nn.Conv2dBnAct`
        >>>         self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        >>>         self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        >>>         # change `nn.Dense` to `nn.DenseBnAct`
        >>>         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        >>>         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        >>>         self.fc3 = nn.DenseBnAct(84, self.num_class)
        >>>
        >>>         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        >>>         self.flatten = nn.Flatten()
        >>>
        >>>     def construct(self, x):
        >>>         x = self.conv1(x)
        >>>         x = self.max_pool2d(x)
        >>>         x = self.conv2(x)
        >>>         x = self.max_pool2d(x)
        >>>         x = self.flatten(x)
        >>>         x = self.fc1(x)
        >>>         x = self.fc2(x)
        >>>         x = self.fc3(x)
        >>>         return x
        >>>
        >>> net = Net()
        >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
        >>> net_qat = quantizer.quantize(net)
    """

    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self,
                 bn_fold=True,
                 freeze_bn=10000000,
                 quant_delay=(0, 0),
                 quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                 per_channel=(False, False),
                 symmetric=(False, False),
                 narrow_range=(False, False),
                 optimize_option=OptimizeOption.QAT):
        """Init for QuantizationAwareTraining quantizer"""
        super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)

        def convert2list(name, value):
            if not isinstance(value, list) and not isinstance(value, tuple):
                value = [value]
            elif len(value) > 2:
                raise ValueError("the length of input `{}` should not be greater than 2".format(name))
            return value

        quant_delay = convert2list("quant delay", quant_delay)
        quant_dtype = convert2list("quant dtype", quant_dtype)
        per_channel = convert2list("per channel", per_channel)
        symmetric = convert2list("symmetric", symmetric)
        narrow_range = convert2list("narrow range", narrow_range)

        self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
        self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
        self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
        self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
        self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
        self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
        self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
        self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
        self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
        self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
        self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
        self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
                                    quant.DenseBnAct: self._convert_dense}
        self.quant_config = get_quant_config(quant_delay=quant_delay,
                                             quant_dtype=quant_dtype,
                                             per_channel=per_channel,
                                             symmetric=symmetric,
                                             narrow_range=narrow_range)

    def _convert_op_name(self, name):
        pattern = re.compile(r'([A-Z]{1})')
        name_new = re.sub(pattern, r'_\1', name).lower()
        if name_new[0] == '_':
            name_new = name_new[1:]
        return name_new
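
    # e.g. "TensorAdd" -> "tensor_add", "RealDiv" -> "real_div": illustrative of
    # what _convert_op_name produces for white-listed op names.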

    def quantize(self, network):
        """
        Quant API to convert the input network to a quantization aware training network.

        Args:
            network (Cell): network to be quantized.

        Examples:
            >>> net = Net()
            >>> quantizer = QuantizationAwareTraining()
            >>> net_qat = quantizer.quantize(net)
        """
        support_device = ["Ascend", "GPU"]
        if context.get_context('device_target') not in support_device:
            raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

        if OptimizeOption.QAT in self.optimize_option:
            network.update_cell_prefix()
            network = self._convert_subcells2quant(network)
            network.update_cell_type("quant")
        return network

    def _convert_subcells2quant(self, network):
        """
        convert sub cells like `Conv2dBnAct` and `DenseBnAct` to quant cells
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                self._convert_subcells2quant(subcell)
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

        # add a FakeQuant OP after each OP in the white list
        add_list = []
        for name in network.__dict__:
            if name[0] == '_':
                continue
            attr = network.__dict__[name]
            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            prefix = name
            add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                                  quant_dtype=self.act_dtype,
                                                  quant_delay=self.act_qdelay,
                                                  per_channel=self.act_channel,
                                                  symmetric=self.act_symmetric,
                                                  narrow_range=self.act_range)
            prefix = self._convert_op_name(prim_op.name)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
            add_quant.update_parameters_name(prefix + '.')
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _convert_conv(self, subcell):
        """
        convert Conv2d cell to quant cell
        """
        conv_inner = subcell.conv
        if subcell.has_bn:
            if self.bn_fold:
                bn_inner = subcell.batchnorm
                conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
                                                     conv_inner.out_channels,
                                                     kernel_size=conv_inner.kernel_size,
                                                     stride=conv_inner.stride,
                                                     pad_mode=conv_inner.pad_mode,
                                                     padding=conv_inner.padding,
                                                     dilation=conv_inner.dilation,
                                                     group=conv_inner.group,
                                                     eps=bn_inner.eps,
                                                     momentum=bn_inner.momentum,
                                                     has_bias=conv_inner.has_bias,
                                                     bias_init=conv_inner.bias_init,
                                                     freeze_bn=self.freeze_bn,
                                                     quant_config=self.quant_config,
                                                     quant_dtype=self.weight_dtype,
                                                     fake=True)
                # transfer the original network's BatchNorm OP parameters to the quant network
                conv_inner.gamma = subcell.batchnorm.gamma
                conv_inner.beta = subcell.batchnorm.beta
                conv_inner.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.moving_variance = subcell.batchnorm.moving_variance
                del subcell.batchnorm
                subcell.batchnorm = None
                subcell.has_bn = False
            else:
                bn_inner = subcell.batchnorm
                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
                                                            conv_inner.out_channels,
                                                            kernel_size=conv_inner.kernel_size,
                                                            stride=conv_inner.stride,
                                                            pad_mode=conv_inner.pad_mode,
                                                            padding=conv_inner.padding,
                                                            dilation=conv_inner.dilation,
                                                            group=conv_inner.group,
                                                            eps=bn_inner.eps,
                                                            momentum=bn_inner.momentum,
                                                            has_bias=conv_inner.has_bias,
                                                            bias_init=conv_inner.bias_init,
                                                            quant_config=self.quant_config,
                                                            quant_dtype=self.weight_dtype)
                # transfer the original network's BatchNorm OP parameters to the quant network
                conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
                conv_inner.batchnorm.beta = subcell.batchnorm.beta
                conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
                del subcell.batchnorm
                subcell.batchnorm = None
                subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
                                           conv_inner.out_channels,
                                           kernel_size=conv_inner.kernel_size,
                                           stride=conv_inner.stride,
                                           pad_mode=conv_inner.pad_mode,
                                           padding=conv_inner.padding,
                                           dilation=conv_inner.dilation,
                                           group=conv_inner.group,
                                           has_bias=conv_inner.has_bias,
                                           quant_config=self.quant_config,
                                           quant_dtype=self.weight_dtype)
        # transfer the original network's Conv2d OP parameters to the quant network
        conv_inner.weight = subcell.conv.weight
        if subcell.conv.has_bias:
            conv_inner.bias = subcell.conv.bias
        subcell.conv = conv_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                           quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay,
                                                           per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric,
                                                           narrow_range=self.act_range)
        return subcell

    def _convert_dense(self, subcell):
        """
        convert Dense cell to quant cell
        """
        dense_inner = subcell.dense
        dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                       dense_inner.out_channels,
                                       has_bias=dense_inner.has_bias,
                                       quant_config=self.quant_config,
                                       quant_dtype=self.weight_dtype)
        # transfer the original network's Dense OP parameters to the quant network
        dense_inner.weight = subcell.dense.weight
        if subcell.dense.has_bias:
            dense_inner.bias = subcell.dense.bias
        subcell.dense = dense_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                           quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay,
                                                           per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric,
                                                           narrow_range=self.act_range)
        return subcell

    def _convert_activation(self, activation):
        act_class = activation.__class__
        if act_class not in _ACTIVATION_MAP:
            raise ValueError("Unsupported activation in auto quant: ", act_class)
        return _ACTIVATION_MAP[act_class](activation=activation,
                                          quant_config=self.quant_config,
                                          quant_dtype=self.act_dtype)