_quant.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Aware quantization."""
  16. import numpy as np
  17. import mindspore.nn as nn
  18. import mindspore.common.dtype as mstype
  19. from mindspore.ops import operations as P
  20. from mindspore.ops import functional as F
  21. from mindspore.common.parameter import Parameter
  22. from mindspore.common.initializer import initializer
  23. from mindspore.common.tensor import Tensor
  24. from mindspore._checkparam import check_int_positive, check_bool, twice
  25. from mindspore.nn.cell import Cell
  26. from mindspore.nn.layer.conv import _Conv
  27. from mindspore.nn.layer.activation import get_activation
  28. __all__ = [
  29. 'FakeQuantWithMinMax',
  30. 'Conv2dBatchNormQuant',
  31. 'Conv2dQuant',
  32. 'DenseQuant',
  33. 'ReLUQuant',
  34. 'ReLU6Quant',
  35. 'HSwishQuant',
  36. 'HSigmoidQuant',
  37. 'TensorAddQuant',
  38. ]


class FakeQuantWithMinMax(Cell):
    r"""
    Quantization aware training op. This op provides a fake-quantization observer
    function on data, bounded by min and max.

    Args:
        min_init (int, list): Initial min value, per channel or per layer. Default: -6.
        max_init (int, list): Initial max value, per channel or per layer. Default: 6.
        num_bits (int): Number of quantization bits; 4 and 8 bits are supported. Default: 8.
        ema (bool): Whether to update min and max with an exponential moving average. Default: False.
        ema_decay (float): Decay parameter of the exponential moving average. Default: 0.999.
        per_channel (bool): Quantize per channel (True) or per layer (False). Default: False.
        channel_size (int): Declares the channel size of min and max. Default: 1.
        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

    Inputs:
        - **x** (Tensor) - The input of FakeQuantWithMinMax.

    Outputs:
        Tensor, with the same type and shape as the `x`.
    """
    def __init__(self,
                 min_init=-6,
                 max_init=6,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 per_channel=False,
                 channel_size=1,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False):
        super(FakeQuantWithMinMax, self).__init__()
        self.min_init = min_init
        self.num_bits = num_bits
        self.max_init = max_init
        self.ema = ema
        self.ema_decay = ema_decay
        self.per_channel = per_channel
        self.channel_size = channel_size
        self.quant_delay = quant_delay
        self.symmetric = symmetric
        self.narrow_range = narrow_range
        if per_channel:
            min_array = np.array([self.min_init for i in range(
                0, self.channel_size)]).astype(np.float32)
            max_array = np.array([self.max_init for i in range(
                0, self.channel_size)]).astype(np.float32)
            self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
                                                                    ema=self.ema,
                                                                    ema_decay=self.ema_decay,
                                                                    quant_delay=self.quant_delay,
                                                                    symmetric=self.symmetric,
                                                                    narrow_range=self.narrow_range,
                                                                    training=True)
            self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
                                                                    ema=self.ema,
                                                                    ema_decay=self.ema_decay,
                                                                    quant_delay=self.quant_delay,
                                                                    symmetric=self.symmetric,
                                                                    narrow_range=self.narrow_range,
                                                                    training=False)
        else:
            min_array = np.array([min_init]).reshape(1).astype(np.float32)
            max_array = np.array([max_init]).reshape(1).astype(np.float32)
            self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
                                                          ema=self.ema,
                                                          ema_decay=self.ema_decay,
                                                          quant_delay=self.quant_delay,
                                                          symmetric=self.symmetric,
                                                          narrow_range=self.narrow_range,
                                                          training=True)
            self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
                                                          ema=self.ema,
                                                          ema_decay=self.ema_decay,
                                                          quant_delay=self.quant_delay,
                                                          symmetric=self.symmetric,
                                                          narrow_range=self.narrow_range,
                                                          training=False)
        self.min = Parameter(
            Tensor(min_array), name='quant_min', requires_grad=False)
        self.max = Parameter(
            Tensor(max_array), name='quant_max', requires_grad=False)

    def extend_repr(self):
        s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
            self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
            self.quant_delay)
        return s

    def construct(self, x):
        if self.training:
            out = self.fake_quant_train(x, self.min, self.max)
        else:
            out = self.fake_quant_infer(x, self.min, self.max)
        return out
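
# Usage sketch (illustrative only; not part of the original file). Layer-wise
# fake quantization of an activation tensor; the input values are assumptions:
#
#     fake_quant = FakeQuantWithMinMax(min_init=-6, max_init=6, num_bits=8)
#     x = Tensor(np.array([[1.0, 2.5], [-2.0, 0.4]], dtype=np.float32))
#     y = fake_quant(x)  # same shape/dtype as x, values snapped to the 8-bit grid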


class Conv2dBatchNormQuant(Cell):
    r"""
    2D convolution with a folded BatchNorm layer.

    For a more detailed overview, see the Conv2d op.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies stride for all spatial dimensions with the same value.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        eps (float): Epsilon parameter of the BatchNorm op. Default: 1e-5.
        momentum (float): Momentum parameter of the BatchNorm op. Default: 0.9.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            convolution kernel. Default: None.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            beta vector. Default: None.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            gamma vector. Default: None.
        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            mean vector. Default: None.
        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            variance vector. Default: None.
        group (int): Number of groups the filter is split into. Default: 1.
        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
        freeze_bn (int): Global step at which the BatchNorm parameters are frozen. Default: 100000.
        fake (bool): Whether the Conv2dBatchNormQuant cell adds a FakeQuantWithMinMax op. Default: True.
        num_bits (int): Number of quantization bits; 4 and 8 bits are supported. Default: 8.
        per_channel (bool): Parameter passed to FakeQuantWithMinMax. Default: False.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 pad_mode,
                 padding=0,
                 eps=1e-5,
                 momentum=0.9,
                 weight_init=None,
                 beta_init=None,
                 gamma_init=None,
                 mean_init=None,
                 var_init=None,
                 group=1,
                 quant_delay=0,
                 freeze_bn=100000,
                 fake=True,
                 num_bits=8,
                 per_channel=False,
                 symmetric=False,
                 narrow_range=False):
        super(Conv2dBatchNormQuant, self).__init__()
        self.stride = stride
        self.conv = P.Conv2D(out_channel=out_channels,
                             kernel_size=kernel_size,
                             mode=1,
                             pad_mode=pad_mode,
                             pad=padding,
                             stride=stride,
                             dilation=1,
                             group=group)
        self.fake = fake
        self.freeze_bn = freeze_bn
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if weight_init is None:
            weight_init = initializer(
                'normal', [out_channels, in_channels // group, *kernel_size])
        self.weight = Parameter(weight_init, name='weight')
        if gamma_init is None:
            gamma_init = initializer('ones', [out_channels])
        self.gamma = Parameter(gamma_init, name='gamma')
        if beta_init is None:
            beta_init = initializer('zeros', [out_channels])
        self.beta = Parameter(beta_init, name='beta')
        if mean_init is None:
            mean_init = initializer('zeros', [out_channels])
        self.moving_mean = Parameter(
            mean_init, name='moving_mean', requires_grad=False)
        if var_init is None:
            var_init = initializer('ones', [out_channels])
        self.moving_variance = Parameter(
            var_init, name='moving_variance', requires_grad=False)
        self.step = Parameter(initializer(
            'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
        self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
                                                        max_init=6,
                                                        ema=False,
                                                        num_bits=num_bits,
                                                        quant_delay=quant_delay,
                                                        per_channel=per_channel,
                                                        channel_size=out_channels,
                                                        symmetric=symmetric,
                                                        narrow_range=narrow_range)
        self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps,
                                                    momentum=momentum,
                                                    is_training=True,
                                                    freeze_bn=freeze_bn)
        self.batchnorm_fold_infer = P.BatchNormFold(epsilon=eps,
                                                    momentum=momentum,
                                                    is_training=False,
                                                    freeze_bn=freeze_bn)
        self.correct_mul = P.CorrectionMul()
        self.relu = P.ReLU()
        self.batchnorm_fold2 = P.BatchNormFold2(freeze_bn=freeze_bn)
        self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
        self.one = Tensor(1, mstype.int32)
        self.assignadd = P.AssignAdd()

    def extend_repr(self):
        s = 'fake={}, freeze_bn={}'.format(self.fake, self.freeze_bn)
        return s

    def construct(self, x):
        if self.training:
            beta = self.beta
            gamma = self.gamma
            gmean = self.moving_mean
            gvar = self.moving_variance
            step = self.step
            out_conv = self.conv(x, self.weight)
            batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_train(
                out_conv, gmean, gvar, step)
            # BN fold1
            weight = self.correct_mul(self.weight, gamma, running_std)
            if self.fake:
                weight = self.fake_quant_weight(weight)
            out = self.conv(x, weight)
            # BN fold2
            out = self.batchnorm_fold2(
                out, beta, gamma, batch_std, batch_mean, running_std, running_mean, step)
            F.control_depend(out, self.assignadd(self.step, self.one))
        else:
            step = self.step
            out_conv = self.conv(x, self.weight)
            batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer(
                out_conv, self.moving_mean, self.moving_variance, step)
            weight = self.correct_mul(self.weight, self.gamma, running_std)
            if self.fake:
                weight = self.fake_quant_weight(weight)
            out = self.conv(x, weight)
            out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean,
                                             running_std, running_mean, step)
        return out
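
# Usage sketch (assumed shapes, for illustration only): a quantization-aware
# conv + folded-BatchNorm block, as used in MobileNet-style backbones:
#
#     conv_bn = Conv2dBatchNormQuant(3, 16, kernel_size=3, stride=1, pad_mode='same',
#                                    per_channel=True, quant_delay=900)
#     x = Tensor(np.ones((1, 3, 224, 224), dtype=np.float32))
#     y = conv_bn(x)  # shape (1, 16, 224, 224)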


class Conv2dQuant(_Conv):
    r"""
    2D convolution with a fake quant op on the weight.

    For a more detailed overview, see the Conv2d op.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups; `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
        num_bits (int): Number of quantization bits; 4 and 8 bits are supported. Default: 8.
        per_channel (bool): Parameter passed to FakeQuantWithMinMax. Default: False.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 quant_delay=0,
                 num_bits=8,
                 per_channel=False,
                 symmetric=False,
                 narrow_range=False):
        kernel_size = twice(kernel_size)
        super(Conv2dQuant, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation,
                                          group, has_bias, weight_init, bias_init)
        self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1,
                               pad_mode=self.pad_mode, pad=self.padding, stride=self.stride, dilation=self.dilation,
                               group=self.group)
        self.bias_add = P.BiasAdd()
        if pad_mode not in ('valid', 'same', 'pad'):
            raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
        self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
                                                        max_init=6,
                                                        ema=False,
                                                        num_bits=num_bits,
                                                        quant_delay=quant_delay,
                                                        per_channel=per_channel,
                                                        channel_size=out_channels,
                                                        symmetric=symmetric,
                                                        narrow_range=narrow_range)

    def construct(self, x):
        weight_q = self.fake_quant_weight(self.weight)
        out = self.conv2d(x, weight_q)
        if self.has_bias:
            return self.bias_add(out, self.bias)
        return out
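
# Usage sketch (assumed shapes, for illustration only): a convolution whose
# weight passes through fake quant before being applied:
#
#     conv = Conv2dQuant(3, 8, kernel_size=3, per_channel=True)
#     y = conv(Tensor(np.ones((1, 3, 32, 32), dtype=np.float32)))  # (1, 8, 32, 32)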


class DenseQuant(Cell):
    r"""
    The fully connected layer with a fake quant op on the weight.

    For a more detailed overview, see the Dense op.

    Args:
        in_channels (int): The dimension of the input space.
        out_channels (int): The dimension of the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.
        num_bits (int): Number of quantization bits; 4 and 8 bits are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
        per_channel (bool): Parameter passed to FakeQuantWithMinMax. Default: False.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out})`.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            weight_init='normal',
            bias_init='zeros',
            has_bias=True,
            activation=None,
            num_bits=8,
            quant_delay=0,
            per_channel=False,
            symmetric=False,
            narrow_range=False):
        super(DenseQuant, self).__init__()
        self.in_channels = check_int_positive(in_channels)
        self.out_channels = check_int_positive(out_channels)
        self.has_bias = check_bool(has_bias)
        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
                    weight_init.shape()[1] != in_channels:
                raise ValueError("weight_init shape error")
        self.weight = Parameter(initializer(
            weight_init, [out_channels, in_channels]), name="weight")
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
                    raise ValueError("bias_init shape error")
            self.bias = Parameter(initializer(
                bias_init, [out_channels]), name="bias")
        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()
        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None
        self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
                                                        max_init=6,
                                                        ema=False,
                                                        num_bits=num_bits,
                                                        quant_delay=quant_delay,
                                                        per_channel=per_channel,
                                                        channel_size=out_channels,
                                                        symmetric=symmetric,
                                                        narrow_range=narrow_range)

    def construct(self, x):
        """Use operators to construct the Dense layer."""
        output = self.fake_quant_weight(self.weight)
        output = self.matmul(x, output)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        if self.activation_flag:
            return self.activation(output)
        return output

    def extend_repr(self):
        """A pretty print for Dense layer."""
        str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format(
            self.in_channels, self.out_channels, self.weight, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias={}'.format(self.bias)
        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)
        return str_info
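
# Usage sketch (assumed shapes, for illustration only): a quantized fully
# connected layer with a fused activation:
#
#     dense = DenseQuant(16, 4, activation='relu')
#     y = dense(Tensor(np.ones((2, 16), dtype=np.float32)))  # (2, 4)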


class ReLUQuant(Cell):
    r"""
    ReLUQuant activation function. Adds a fake quant op after the ReLU op.

    For a more detailed overview, see the ReLU op.

    Args:
        num_bits (int): Number of quantization bits; 4 and 8 bits are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

    Inputs:
        - **x** (Tensor) - The input of ReLUQuant.

    Outputs:
        Tensor, with the same type and shape as the `x`.
    """
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False):
        super(ReLUQuant, self).__init__()
        self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
                                                     max_init=6,
                                                     num_bits=num_bits,
                                                     quant_delay=quant_delay,
                                                     ema=True,
                                                     symmetric=symmetric,
                                                     narrow_range=narrow_range)
        self.relu = P.ReLU()

    def construct(self, x):
        x = self.relu(x)
        x = self.fake_quant_act(x)
        return x
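
# Usage sketch (illustrative only): ReLU followed by an EMA-tracked fake quant,
# so the observer learns the post-activation range:
#
#     act = ReLUQuant(num_bits=8)
#     y = act(Tensor(np.array([-1.0, 0.5, 7.0], dtype=np.float32)))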


class ReLU6Quant(Cell):
    r"""
    ReLU6Quant activation function. Adds a fake quant op after the ReLU6 op.

    Using this cell is not recommended: the fake quant op already clamps the
    maximum range of the activation, so ReLU6 performs a redundant operation.

    For a more detailed overview, see the ReLU6 op.

    Args:
        num_bits (int): Number of quantization bits; 4 and 8 bits are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

    Inputs:
        - **x** (Tensor) - The input of ReLU6Quant.

    Outputs:
        Tensor, with the same type and shape as the `x`.
    """
    def __init__(self, num_bits=8, quant_delay=0, symmetric=False,
                 narrow_range=False):
        super(ReLU6Quant, self).__init__()
        self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
                                                     max_init=6,
                                                     num_bits=num_bits,
                                                     quant_delay=quant_delay,
                                                     ema=True,
                                                     symmetric=symmetric,
                                                     narrow_range=narrow_range)
        self.relu6 = P.ReLU6()

    def construct(self, x):
        x = self.relu6(x)
        x = self.fake_quant_act(x)
        return x
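
# Usage sketch (illustrative only); note the docstring's caveat that the fake
# quant op already clamps the range that ReLU6 would clamp:
#
#     act = ReLU6Quant(num_bits=8)
#     y = act(Tensor(np.array([-1.0, 3.0, 9.0], dtype=np.float32)))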


class HSwishQuant(Cell):
    r"""
    HSwishQuant activation function. Adds fake quant ops before and after the HSwish op.

    For a more detailed overview, see the HSwish op.

    Args:
        num_bits (int): Number of quantization bits; 4 and 8 bits are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

    Inputs:
        - **x** (Tensor) - The input of HSwishQuant.

    Outputs:
        Tensor, with the same type and shape as the `x`.
    """
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False):
        super(HSwishQuant, self).__init__()
        self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
                                                            max_init=6,
                                                            num_bits=num_bits,
                                                            quant_delay=quant_delay,
                                                            ema=True,
                                                            symmetric=symmetric,
                                                            narrow_range=narrow_range)
        self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
                                                           max_init=6,
                                                           num_bits=num_bits,
                                                           quant_delay=quant_delay,
                                                           ema=True,
                                                           symmetric=symmetric,
                                                           narrow_range=narrow_range)
        self.act = P.HSwish()

    def construct(self, x):
        x = self.fake_quant_act_before(x)
        x = self.act(x)
        x = self.fake_quant_act_after(x)
        return x
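
# Usage sketch (illustrative only): fake quant on both sides of HSwish, so the
# input and output ranges are observed independently:
#
#     act = HSwishQuant(num_bits=8)
#     y = act(Tensor(np.array([-4.0, 0.0, 4.0], dtype=np.float32)))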


class HSigmoidQuant(Cell):
    r"""
    HSigmoidQuant activation function. Adds fake quant ops before and after the HSigmoid op.

    For a more detailed overview, see the HSigmoid op.

    Args:
        num_bits (int): Number of quantization bits; 4 and 8 bits are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

    Inputs:
        - **x** (Tensor) - The input of HSigmoidQuant.

    Outputs:
        Tensor, with the same type and shape as the `x`.
    """
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False):
        super(HSigmoidQuant, self).__init__()
        self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
                                                            max_init=6,
                                                            num_bits=num_bits,
                                                            quant_delay=quant_delay,
                                                            ema=True,
                                                            symmetric=symmetric,
                                                            narrow_range=narrow_range)
        self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
                                                           max_init=6,
                                                           num_bits=num_bits,
                                                           quant_delay=quant_delay,
                                                           ema=True,
                                                           symmetric=symmetric,
                                                           narrow_range=narrow_range)
        self.act = P.HSigmoid()

    def construct(self, x):
        x = self.fake_quant_act_before(x)
        x = self.act(x)
        x = self.fake_quant_act_after(x)
        return x
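
# Usage sketch (illustrative only), mirroring HSwishQuant with HSigmoid as the
# wrapped activation:
#
#     act = HSigmoidQuant(num_bits=8)
#     y = act(Tensor(np.array([-4.0, 0.0, 4.0], dtype=np.float32)))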


class TensorAddQuant(Cell):
    r"""
    Adds a fake quant op after the TensorAdd op.

    For a more detailed overview, see the TensorAdd op.

    Args:
        num_bits (int): Number of quantization bits; 4 and 8 bits are supported. Default: 8.
        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.

    Inputs:
        - **x1** (Tensor) - The first input of TensorAddQuant.
        - **x2** (Tensor) - The second input of TensorAddQuant.

    Outputs:
        Tensor, with the same type and shape as the `x1`.
    """
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False):
        super(TensorAddQuant, self).__init__()
        self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=-6,
                                                     max_init=6,
                                                     num_bits=num_bits,
                                                     quant_delay=quant_delay,
                                                     ema=True,
                                                     symmetric=symmetric,
                                                     narrow_range=narrow_range)
        self.add = P.TensorAdd()

    def construct(self, x1, x2):
        x = self.add(x1, x2)
        x = self.fake_quant_act(x)
        return x
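
# Usage sketch (illustrative only): a quantized residual add, e.g. for a skip
# connection; both operands are assumed to share the same shape:
#
#     add = TensorAddQuant(num_bits=8)
#     x1 = Tensor(np.ones((2, 3), dtype=np.float32))
#     x2 = Tensor(np.full((2, 3), 0.5, dtype=np.float32))
#     y = add(x1, x2)  # elementwise sum, then fake quantized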