quant.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization aware training."""
from functools import partial
from collections import namedtuple
import numpy as np
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.ops.primitive import Primitive
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, twice
from mindspore.compression.common import QuantDtype
import mindspore.context as context
from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, ReLU, LeakyReLU
from ..cell import Cell
from ...ops.operations import _quant_ops as Q

__all__ = [
    'Conv2dBnAct',
    'DenseBnAct',
    'FakeQuantWithMinMaxObserver',
    'Conv2dBnFoldQuant',
    'Conv2dBnWithoutFoldQuant',
    'Conv2dQuant',
    'DenseQuant',
    'ActQuant',
    'LeakyReLUQuant',
    'HSwishQuant',
    'HSigmoidQuant',
    'TensorAddQuant',
    'MulQuant',
]


class Conv2dBnAct(Cell):
    r"""
    A combination of convolution, batch normalization, and activation layers.

    This part is a more detailed overview of Conv2d op.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height
            and width of the 2D convolution window. Single int means the value is for both the height and the
            width of the kernel. A tuple of 2 ints means the first value is for the height and the other is for
            the width of the kernel.
        stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride must be
            greater than or equal to 1 and lower than any one of the height and width of the input. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
            there will be :math:`k - 1` pixels skipped for each sampling location. Its value must be greater than
            or equal to 1 and lower than any one of the height and width of the input. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as constant 'One' and 'Zero' distributions are possible. Aliases 'xavier_uniform', 'he_uniform', 'ones'
            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.
        has_bn (bool): Specifies whether to use batch normalization. Default: False.
        momentum (float): Momentum for the moving average. The momentum value must be in [0, 1]. Default: 0.9.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
            Default: 1e-5.
        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
        alpha (float): Slope of the activation function at x < 0. Default: 0.2.
        after_fake (bool): Determine whether there must be a fake quantization operation after Conv2dBnAct.
            Default: True.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='relu')
        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
        >>> net(input).shape
        (1, 240, 1024, 640)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bn=False,
                 momentum=0.9,
                 eps=1e-5,
                 activation=None,
                 alpha=0.2,
                 after_fake=True):
        super(Conv2dBnAct, self).__init__()
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              pad_mode=pad_mode,
                              padding=padding,
                              dilation=dilation,
                              group=group,
                              has_bias=has_bias,
                              weight_init=weight_init,
                              bias_init=bias_init)
        self.has_bn = Validator.check_bool(has_bn, "has_bn")
        self.has_act = activation is not None
        self.after_fake = after_fake
        if has_bn:
            self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
        if activation == "leakyrelu":
            self.activation = LeakyReLU(alpha)
        else:
            self.activation = get_activation(activation) if isinstance(activation, str) else activation
            if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
                raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))

    def construct(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.batchnorm(x)
        if self.has_act:
            x = self.activation(x)
        return x
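

# Illustrative usage sketch (not part of the original file): Conv2dBnAct chains
# Conv2d -> BatchNorm2d -> activation inside one cell. The helper name
# `_example_conv2d_bn_act` is hypothetical and exists only for demonstration.
def _example_conv2d_bn_act():
    """Build a Conv2dBnAct cell and run a forward pass on dummy data."""
    net = Conv2dBnAct(3, 8, kernel_size=3, has_bn=True, activation='relu')
    x = Tensor(np.ones([1, 3, 16, 16]), mstype.float32)
    out = net(x)       # conv -> batchnorm -> relu
    return out.shape   # expected: (1, 8, 16, 16) with the default "same" padding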


class DenseBnAct(Cell):
    r"""
    A combination of Dense, batch normalization, and activation layers.

    This part is a more detailed overview of Dense op.

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        has_bn (bool): Specifies whether to use batch normalization. Default: False.
        activation (Union[str, Cell, Primitive]): Specifies activation type. The optional values are as following:
            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
        after_fake (bool): Determine whether there must be a fake quantization operation after DenseBnAct.
            Default: True.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = nn.DenseBnAct(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net(input)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 has_bn=False,
                 activation=None,
                 after_fake=True):
        super(DenseBnAct, self).__init__()
        self.dense = nn.Dense(
            in_channels,
            out_channels,
            weight_init,
            bias_init,
            has_bias)
        self.has_bn = Validator.check_bool(has_bn, "has_bn")
        self.has_act = activation is not None
        self.after_fake = after_fake
        if has_bn:
            self.batchnorm = BatchNorm1d(out_channels)
        self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))

    def construct(self, x):
        x = self.dense(x)
        if self.has_bn:
            x = self.batchnorm(x)
        if self.has_act:
            x = self.activation(x)
        return x


class BatchNormFoldCell(Cell):
    """
    Batch normalization folded.

    Args:
        momentum (float): Momentum value must be in [0, 1]. Default: 0.9.
        epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
            float32 else 1e-3. Default: 1e-5.
        freeze_bn (int): Delay in steps at which computation switches from regular batch
            norm to frozen mean and std. Default: 0.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        Tuple of 4 Tensors, the normalized input and the updated parameters.

        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
    """

    def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
        """Initialize batch norm fold layer"""
        super(BatchNormFoldCell, self).__init__()
        self.epsilon = epsilon
        self.is_gpu = context.get_context('device_target') == "GPU"
        if self.is_gpu:
            self.bn_train = Q.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
            self.bn_infer = Q.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn)
        else:
            self.bn_reduce = P.BNTrainingReduce()
            self.bn_update = Q.BatchNormFoldD(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)

    def construct(self, x, mean, variance, global_step):
        if self.is_gpu:
            if self.training:
                batch_mean, batch_std, running_mean, running_std = self.bn_train(x, mean, variance, global_step)
            else:
                batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step)
        else:
            if self.training:
                x_sum, x_square_sum = self.bn_reduce(x)
                _, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \
                    self.bn_update(x, x_sum, x_square_sum, mean, variance)
                P.Assign()(mean, mean_updated)
                P.Assign()(variance, variance_updated)
            else:
                batch_mean = P.ZerosLike()(variance)
                batch_std = P.OnesLike()(variance)
                running_mean = P.TensorAdd()(mean, 0.)
                running_std = P.Sqrt()(P.TensorAdd()(variance, self.epsilon))
        return batch_mean, batch_std, running_mean, running_std
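

# Illustrative sketch (not part of the original file): the folding identity that batch
# norm folding relies on. At inference, Conv followed by BatchNorm collapses into one
# conv with rescaled weight and a shifted bias. All names below are hypothetical; a
# 1x1 "conv" is modeled as a dot product for brevity.
def _example_bn_fold_identity():
    """Check numerically that BN(conv(x)) == conv_folded(x)."""
    rng = np.random.default_rng(0)
    w, gamma, beta = rng.normal(size=3), 1.5, 0.2
    mean, var, eps = 0.1, 0.8, 1e-5
    x = rng.normal(size=3)
    std = np.sqrt(var + eps)
    y_bn = gamma * (np.dot(w, x) - mean) / std + beta   # conv then batchnorm
    w_fold = w * gamma / std                            # folded weight (cf. CorrectionMul)
    b_fold = beta - gamma * mean / std                  # folded bias
    y_fold = np.dot(w_fold, x) + b_fold                 # single folded conv
    assert np.allclose(y_bn, y_fold)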


def _partial_init(cls_or_self, **kwargs):
    """
    Wrapper that allows creation of class factories.

    This can be useful when there is a need to create classes with the same
    constructor arguments, but different instances.

    Example::

        >>> Foo.partial_init = classmethod(_partial_init)
        >>> foo_builder = Foo.partial_init(a=3, b=4).partial_init(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> id(foo_instance1) == id(foo_instance2)
        False
    """

    class _PartialWrapper:
        r"""
        Wrapper class that allows creation of class factories.
        """

        def __init__(self, p):
            self.p = p

        def __call__(self, *args, **keywords):
            return self.p(*args, **keywords)

        def __repr__(self):
            return self.p.__repr__()

        partial_init = _partial_init

    r = _PartialWrapper(partial(cls_or_self, **kwargs))
    return r
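

# Illustrative sketch (not part of the original file): the factory pattern above,
# demonstrated on a tiny stand-in class. `Foo` and `_example_partial_init` are
# hypothetical names; the chaining mirrors the docstring example.
def _example_partial_init():
    """Pre-bind constructor arguments, then build fresh instances on each call."""
    class Foo:
        partial_init = classmethod(_partial_init)

        def __init__(self, a, b, answer=None):
            self.a, self.b, self.answer = a, b, answer

    foo_builder = Foo.partial_init(a=3, b=4).partial_init(answer=42)
    foo1, foo2 = foo_builder(), foo_builder()
    return foo1 is not foo2  # True: the factory builds a new instance per call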


class Observer(Cell):
    """
    Base class of Observer. Observer is used to calculate the statistics of a specific layer.

    Notes:
        This class is an abstract class.

    Args:
        quant_dtype (QuantDtype): The type of FakeQuant data.
    """

    def __init__(self, quant_dtype):
        super(Observer, self).__init__()
        self.quant_dtype = quant_dtype

    def extend_repr(self):
        s = f"quant_dtype={self.quant_dtype}"
        return s

    def construct(self):
        pass

    partial_init = classmethod(_partial_init)


class UniformQuantObserver(Observer):
    """
    The base class of Uniform Quantization Observer.

    Args:
        quant_dtype (QuantDtype): The type of FakeQuant data. Default: QuantDtype.INT8.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        num_channels (int): Declares the min and max channel size. Default: 1.

    Returns:
        Tensor.
    """

    min_max_map = {
        QuantDtype.INT2: (-2, 1),
        QuantDtype.INT3: (-4, 3),
        QuantDtype.INT4: (-8, 7),
        QuantDtype.INT5: (-16, 15),
        QuantDtype.INT6: (-32, 31),
        QuantDtype.INT7: (-64, 63),
        QuantDtype.INT8: (-128, 127),
        QuantDtype.UINT2: (0, 3),
        QuantDtype.UINT3: (0, 7),
        QuantDtype.UINT4: (0, 15),
        QuantDtype.UINT5: (0, 31),
        QuantDtype.UINT6: (0, 63),
        QuantDtype.UINT7: (0, 127),
        QuantDtype.UINT8: (0, 255)
    }

    def __init__(self, quant_dtype=QuantDtype.INT8, per_channel=False, symmetric=False, narrow_range=False,
                 num_channels=1):
        super(UniformQuantObserver, self).__init__(quant_dtype)
        self.per_channel = per_channel
        self.symmetric = symmetric
        self.narrow_range = narrow_range
        self.num_channels = num_channels


class FakeQuantWithMinMaxObserver(UniformQuantObserver):
    r"""
    Quantization aware op. This OP provides the fake quantization observer function on data with min and max.

    Args:
        min_init (int, float): The initialized min value. Default: -6.
        max_init (int, float): The initialized max value. Default: 6.
        ema (bool): Whether the exponential Moving Average algorithm updates min and max. Default: False.
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
        channel_axis (int): Quantization by channel axis. Default: 1.
        num_channels (int): Declares the min and max channel size. Default: 1.
        quant_dtype (QuantDtype): The datatype of quantization, supporting 4 and 8 bits. Default: QuantDtype.INT8.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.

    Inputs:
        - **x** (Tensor) - The input of FakeQuantWithMinMaxObserver.

    Outputs:
        Tensor, with the same type and shape as the `x`.

    Examples:
        >>> fake_quant = FakeQuantWithMinMaxObserver()
        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> result = fake_quant(input_x)
    """

    def __init__(self,
                 min_init=-6,
                 max_init=6,
                 ema=False,
                 ema_decay=0.999,
                 per_channel=False,
                 channel_axis=1,
                 num_channels=1,
                 quant_dtype=QuantDtype.INT8,
                 symmetric=False,
                 narrow_range=False,
                 quant_delay=0):
        """Initialize FakeQuantWithMinMaxObserver"""
        super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
                                                          symmetric=symmetric, narrow_range=narrow_range,
                                                          num_channels=num_channels)
        Validator.check_type("min_init", min_init, [int, float])
        Validator.check_type("max_init", max_init, [int, float])
        Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
        Validator.check_non_negative_int(quant_delay, 'quant_delay')
        self.min_init = min_init
        self.max_init = max_init
        self.quant_dtype = quant_dtype
        self.ema = ema
        self.ema_decay = ema_decay
        self.per_channel = per_channel
        self.num_channels = num_channels
        self.channel_axis = channel_axis
        self.quant_delay = quant_delay
        self.symmetric = symmetric
        self.narrow_range = narrow_range
        self.is_ascend = context.get_context('device_target') == "Ascend"

        # init tensor min and max for fake quant op
        if self.per_channel:
            min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
            max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
        else:
            min_array = np.array([self.min_init]).astype(np.float32)
            max_array = np.array([self.max_init]).astype(np.float32)
        self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
        self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)

        # init fake quant relative op
        if self.per_channel:
            quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
            ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
        else:
            quant_fun = Q.FakeQuantPerLayer
            ema_fun = Q.MinMaxUpdatePerLayer
        self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
        if self.is_ascend:
            self.fake_quant_train = quant_fun(num_bits=self.quant_dtype.num_bits,
                                              symmetric=self.symmetric,
                                              narrow_range=self.narrow_range,
                                              quant_delay=self.quant_delay)
            self.fake_quant_infer = self.fake_quant_train
        else:
            quant_fun = partial(quant_fun,
                                ema=self.ema,
                                ema_decay=ema_decay,
                                num_bits=self.quant_dtype.num_bits,
                                symmetric=self.symmetric,
                                narrow_range=self.narrow_range,
                                quant_delay=self.quant_delay)
            self.fake_quant_train = quant_fun(training=True)
            self.fake_quant_infer = quant_fun(training=False)

    def extend_repr(self):
        s = 'quant_dtype={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
            'quant_delay={}, min_init={}, max_init={}'.format(self.quant_dtype, self.symmetric, self.narrow_range,
                                                              self.ema, self.ema_decay, self.per_channel,
                                                              self.channel_axis, self.num_channels, self.quant_delay,
                                                              self.min_init, self.max_init)
        return s

    def construct(self, x):
        if self.training:
            min_up, max_up = self.ema_update(x, self.minq, self.maxq)
            P.Assign()(self.minq, min_up)
            P.Assign()(self.maxq, max_up)
            out = self.fake_quant_train(x, self.minq, self.maxq)
        else:
            out = self.fake_quant_infer(x, self.minq, self.maxq)
        return out
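

# Illustrative sketch (not part of the original file): the arithmetic a per-layer fake
# quantizer performs, written out in numpy. Assumes asymmetric, non-narrow-range INT8
# (quant_min=-128, quant_max=127), matching the defaults above; the helper name is
# hypothetical.
def _example_fake_quant_math():
    """Quantize-dequantize an array with a fixed [min, max] range."""
    x = np.array([1.0, 2.0, -2.0, 0.3], dtype=np.float32)
    min_q, max_q, quant_min, quant_max = -6.0, 6.0, -128, 127
    scale = (max_q - min_q) / (quant_max - quant_min)   # step size of the integer grid
    zero_point = np.round(quant_min - min_q / scale)    # integer that real 0 maps to
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    x_fake_quant = (q - zero_point) * scale             # back to float, snapped on-grid
    return x_fake_quant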


QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])

quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver)
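
# Illustrative sketch (not part of the original file): a custom QuantConfig built with
# `partial_init`, so the quant cells below would use per-channel symmetric weight
# observers while activations track their range with an EMA. The name
# `quant_config_custom` is hypothetical.
quant_config_custom = QuantConfig(
    weight=FakeQuantWithMinMaxObserver.partial_init(per_channel=True, symmetric=True),
    activation=FakeQuantWithMinMaxObserver.partial_init(ema=True))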


class Conv2dBnFoldQuant(Cell):
    r"""
    2D convolution with BatchNorm op folded construct.

    This part is a more detailed overview of Conv2d op.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        eps (float): Parameters for BatchNorm. Default: 1e-5.
        momentum (float): Parameters for BatchNorm op. Default: 0.997.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            convolution kernel. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            bias vector. Default: 'zeros'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            beta vector. Default: 'zeros'.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            gamma vector. Default: 'ones'.
        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            mean vector. Default: 'zeros'.
        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            variance vector. Default: 'ones'.
        fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
        freeze_bn (int): The quantization freeze of the BatchNorm op is based on the global step. Default: 100000.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> conv2d_bn = nn.Conv2dBnFoldQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid")
        >>> x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32)
        >>> y = conv2d_bn(x)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 eps=1e-5,
                 momentum=0.997,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 beta_init='zeros',
                 gamma_init='ones',
                 mean_init='zeros',
                 var_init='ones',
                 fake=True,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8,
                 freeze_bn=100000):
        """Initialize Conv2dBnFoldQuant layer"""
        super(Conv2dBnFoldQuant, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = twice(kernel_size)
        self.stride = twice(stride)
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = twice(dilation)
        self.group = group
        self.eps = eps
        self.momentum = momentum
        self.has_bias = has_bias
        self.freeze_bn = freeze_bn
        self.fake = fake
        self.quant_config = quant_config
        self.quant_dtype = quant_dtype
        self.is_gpu = context.get_context('device_target') == "GPU"

        # initialize convolution op and Parameter
        if context.get_context('device_target') == "Ascend" and group > 1:
            Validator.check_equal_int(group, in_channels, 'group')
            Validator.check_equal_int(group, out_channels, 'group')
            self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                kernel_size=self.kernel_size,
                                                pad_mode=pad_mode,
                                                pad=padding,
                                                stride=self.stride,
                                                dilation=self.dilation)
            weight_shape = [1, in_channels, *self.kernel_size]
            channel_axis = 1
        else:
            self.conv = P.Conv2D(out_channel=out_channels,
                                 kernel_size=self.kernel_size,
                                 pad_mode=pad_mode,
                                 pad=padding,
                                 stride=self.stride,
                                 dilation=self.dilation,
                                 group=group)
            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
            channel_axis = 0
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.bias_add = P.BiasAdd()
        if Validator.check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
        else:
            self.bias = None

        # initialize BatchNorm Parameter
        self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
        self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
        self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
        self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
                                         requires_grad=False)

        # initialize fake ops
        self.fake_quant_weight = quant_config.weight(min_init=-6,
                                                     max_init=6,
                                                     ema=False,
                                                     channel_axis=channel_axis,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)
        self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
        self.correct_mul = Q.CorrectionMul(channel_axis)
        if context.get_context('device_target') == "Ascend":
            self.batchnorm_fold2_train = Q.BatchNormFold2_D(freeze_bn=freeze_bn)
            self.batchnorm_fold2_infer = Q.BatchNormFold2_D(freeze_bn=0)
        elif context.get_context('device_target') == "GPU":
            self.batchnorm_fold2_train = Q.BatchNormFold2(freeze_bn=freeze_bn)
            self.batchnorm_fold2_infer = Q.BatchNormFold2(freeze_bn=0)
        else:
            raise ValueError("Unsupported platform: {}".format(context.get_context('device_target')))
        self.step = Parameter(initializer('normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
        self.one = Tensor(1, mstype.int32)
        self.assignadd = P.AssignAdd()

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
                                                                        self.kernel_size, self.stride,
                                                                        self.pad_mode, self.padding, self.dilation,
                                                                        self.group,
                                                                        self.fake, self.freeze_bn, self.momentum,
                                                                        self.fake_quant_weight.quant_delay)
        return s

    def construct(self, x):
        out_conv = self.conv(x, self.weight)
        if self.has_bias:
            out_conv = self.bias_add(out_conv, self.bias)
        # BN fold1
        batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
                                                                               self.moving_mean,
                                                                               self.moving_variance,
                                                                               self.step)
        # fake weight
        weight = self.correct_mul(self.weight, self.gamma, running_std)
        if self.fake:
            weight = self.fake_quant_weight(weight)
        out = self.conv(x, weight)
        if self.has_bias:
            out = self.bias_add(out, self.bias)
        # BN fold2
        if self.is_gpu:
            if self.training:
                out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
                                                 batch_std, batch_mean, running_std, running_mean, self.step)
                F.control_depend(out, self.assignadd(self.step, self.one))
            else:
                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                                 batch_std, batch_mean, running_std, running_mean, self.step)
        else:
            if self.training:
                out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
                F.control_depend(out, self.assignadd(self.step, self.one))
            else:
                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std)
        return out
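

# Illustrative sketch (not part of the original file): Conv2dBnFoldQuant in train vs.
# eval mode. In training, batch statistics drive the fold; after `freeze_bn` steps the
# frozen moving statistics are used instead. The helper name is hypothetical, and the
# cell requires a GPU or Ascend device target.
def _example_conv2d_bn_fold_quant():
    """Run one folded forward pass in training mode, then in inference mode."""
    net = Conv2dBnFoldQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
                            freeze_bn=100000)
    x = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mstype.float32)
    net.set_train(True)    # BN fold uses batch statistics; the step counter advances
    y_train = net(x)
    net.set_train(False)   # BN fold uses the frozen moving statistics
    y_eval = net(x)
    return y_train.shape, y_eval.shape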


class Conv2dBnWithoutFoldQuant(Cell):
    r"""
    2D convolution and batchnorm without fold, with fake quant construct.

    This part is a more detailed overview of Conv2d op.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        eps (float): Parameters for BatchNorm. Default: 1e-5.
        momentum (float): Parameters for BatchNorm op. Default: 0.997.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> conv2d_quant = nn.Conv2dBnWithoutFoldQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid")
        >>> x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mstype.float32)
        >>> y = conv2d_quant(x)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 eps=1e-5,
                 momentum=0.997,
                 weight_init='normal',
                 bias_init='zeros',
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(Conv2dBnWithoutFoldQuant, self).__init__()
        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size, kernel_size)
        else:
            self.kernel_size = kernel_size
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = has_bias
        self.stride = twice(stride)
        self.dilation = twice(dilation)
        self.pad_mode = pad_mode
        self.padding = padding
        self.group = group
        self.bias_add = P.BiasAdd()
        if Validator.check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
        else:
            self.bias = None

        # initialize convolution op and Parameter
        if context.get_context('device_target') == "Ascend" and group > 1:
            Validator.check_equal_int(group, in_channels, 'group')
            Validator.check_equal_int(group, out_channels, 'group')
            self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                kernel_size=self.kernel_size,
                                                pad_mode=pad_mode,
                                                pad=padding,
                                                stride=self.stride,
                                                dilation=self.dilation)
            weight_shape = [1, in_channels, *self.kernel_size]
            channel_axis = 1
        else:
            self.conv = P.Conv2D(out_channel=self.out_channels,
                                 kernel_size=self.kernel_size,
                                 mode=1,
                                 pad_mode=self.pad_mode,
                                 pad=self.padding,
                                 stride=self.stride,
                                 dilation=self.dilation,
                                 group=self.group)
            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
            channel_axis = 0
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.fake_quant_weight = quant_config.weight(min_init=-6,
                                                     max_init=6,
                                                     ema=False,
                                                     channel_axis=channel_axis,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)
        self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)

    def construct(self, x):
        weight = self.fake_quant_weight(self.weight)
        out = self.conv(x, weight)
        if self.has_bias:
            out = self.bias_add(out, self.bias)
        out = self.batchnorm(out)
        return out

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
                                                 self.pad_mode, self.padding, self.dilation, self.group,
                                                 self.has_bias, self.fake_quant_weight.quant_delay)
        return s


class Conv2dQuant(Cell):
    r"""
    2D convolution with fake quant op layer.

    This part is a more detailed overview of Conv2d op.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> conv2d_quant = nn.Conv2dQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid")
        >>> x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32)
        >>> y = conv2d_quant(x)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(Conv2dQuant, self).__init__()
        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size, kernel_size)
        else:
            self.kernel_size = kernel_size
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = has_bias
        self.stride = twice(stride)
        self.dilation = twice(dilation)
        self.pad_mode = pad_mode
        self.padding = padding
        self.group = group

        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.bias_add = P.BiasAdd()
        if Validator.check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
        else:
            self.bias = None
        self.conv = P.Conv2D(out_channel=self.out_channels,
                             kernel_size=self.kernel_size,
                             mode=1,
                             pad_mode=self.pad_mode,
                             pad=self.padding,
                             stride=self.stride,
                             dilation=self.dilation,
                             group=self.group)
        self.fake_quant_weight = quant_config.weight(min_init=-6,
                                                     max_init=6,
                                                     ema=False,
                                                     channel_axis=0,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)

    def construct(self, x):
        weight = self.fake_quant_weight(self.weight)
        out = self.conv(x, weight)
        if self.has_bias:
            return self.bias_add(out, self.bias)
        return out

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
                                                 self.pad_mode, self.padding, self.dilation, self.group,
                                                 self.has_bias, self.fake_quant_weight.quant_delay)
        return s
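

# Illustrative sketch (not part of the original file): Conv2dQuant fake-quantizes only
# the weight (the observer is built with channel_axis=0, so a per-channel observer from
# a custom quant_config would quantize each output channel separately); the input passes
# through untouched. The helper name is hypothetical.
def _example_conv2d_quant():
    """Forward pass through a conv layer with weight fake quantization."""
    conv = Conv2dQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid")
    x = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mstype.float32)
    return conv(x).shape   # expected: (2, 6, 2, 2)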


class DenseQuant(Cell):
    r"""
    The fully connected layer with fake quant op.

    This part is a more detailed overview of Dense op.

    Args:
        in_channels (int): The dimension of the input space.
        out_channels (int): The dimension of the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (Union[str, Cell, Primitive]): The regularization function applied to the output of the layer,
            eg. 'relu'. Default: None.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> dense_quant = nn.DenseQuant(3, 6)
        >>> input_x = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> result = dense_quant(input_x)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(DenseQuant, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = Validator.check_bool(has_bias)
        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape error")
        self.weight = Parameter(initializer(
            weight_init, [out_channels, in_channels]), name="weight")
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape error")
            self.bias = Parameter(initializer(
                bias_init, [out_channels]), name="bias")
        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()
        self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
        self.activation_flag = self.activation is not None
        self.fake_quant_weight = quant_config.weight(min_init=-6,
                                                     max_init=6,
                                                     ema=False,
                                                     channel_axis=0,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)

    def construct(self, x):
        """Use operators to construct the Dense layer."""
        output = self.fake_quant_weight(self.weight)
        output = self.matmul(x, output)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        if self.activation_flag:
            return self.activation(output)
        return output

    def extend_repr(self):
        """A pretty print for Dense layer."""
        str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format(
            self.in_channels, self.out_channels, self.weight, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias={}'.format(self.bias)
        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)
        return str_info
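

# Illustrative sketch (not part of the original file): DenseQuant fake-quantizes the
# weight before the matmul, then applies the bias and the optional activation. The
# helper name is hypothetical.
def _example_dense_quant():
    """Forward pass through a fully connected layer with weight fake quantization."""
    dense = DenseQuant(3, 6, activation='relu')
    x = Tensor(np.random.randint(-2, 2, (2, 3)), mstype.float32)
    return dense(x).shape  # expected: (2, 6)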


class _QuantActivation(Cell):
    r"""
    Base class for quantization aware training activation functions. Adds a fake quant OP after the activation OP.
    """

    def get_origin(self):
        raise NotImplementedError


class ActQuant(_QuantActivation):
    r"""
    Quantization aware training activation function.

    Adds the fake quant op to the end of the activation op, by which the output of the activation op will be
    truncated. Please check `FakeQuantWithMinMaxObserver` or other observers for more details.

    Args:
        activation (Cell): Activation cell class.
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - The input of ActQuant.

    Outputs:
        Tensor, with the same type and shape as the `x`.

    Examples:
        >>> act_quant = nn.ActQuant(nn.ReLU())
        >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
        >>> result = act_quant(input_x)
    """

    def __init__(self,
                 activation,
                 ema_decay=0.999,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(ActQuant, self).__init__()
        self.fake_quant_act = quant_config.activation(min_init=-6,
                                                      max_init=6,
                                                      ema=False,
                                                      ema_decay=ema_decay,
                                                      quant_dtype=quant_dtype)
        self.act = activation

    def construct(self, x):
        x = self.act(x)
        x = self.fake_quant_act(x)
        return x

    def get_origin(self):
        return self.act
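

# Illustrative sketch (not part of the original file): ActQuant appends a fake quant op
# after the wrapped activation, and `get_origin` recovers the bare activation cell,
# which conversion tooling can use when exporting. The helper name is hypothetical.
def _example_act_quant():
    """Wrap ReLU with a fake quantizer and recover the original cell."""
    act_quant = ActQuant(nn.ReLU())
    x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mstype.float32)
    y = act_quant(x)               # relu, then fake quantization
    relu = act_quant.get_origin()  # the underlying nn.ReLU cell
    return y, relu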


class LeakyReLUQuant(_QuantActivation):
    r"""
    LeakyReLUQuant activation function. Adds fake quant OPs before and after the LeakyReLU OP.

    This part is a more detailed overview of the LeakyReLU op.

    Args:
        activation (Cell): Activation cell class.
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - The input of LeakyReLUQuant.

    Outputs:
        Tensor, with the same type and shape as the `x`.

    Examples:
        >>> activation = nn.LeakyReLUQuant(nn.LeakyReLU())
        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> result = activation(input)
    """

    def __init__(self,
                 activation,
                 ema_decay=0.999,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(LeakyReLUQuant, self).__init__()
        self.fake_quant_act_before = quant_config.activation(min_init=-6,
                                                             max_init=6,
                                                             ema=True,
                                                             ema_decay=ema_decay,
                                                             quant_dtype=quant_dtype)
        self.fake_quant_act_after = quant_config.activation(min_init=-6,
                                                            max_init=6,
                                                            ema=True,
                                                            ema_decay=ema_decay,
                                                            quant_dtype=quant_dtype)
        if issubclass(activation.__class__, nn.LeakyReLU):
            self.act = activation
        else:
            raise ValueError("Activation should be `nn.LeakyReLU`")

    def construct(self, x):
        x = self.fake_quant_act_before(x)
        x = self.act(x)
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        return self.act


class HSwishQuant(_QuantActivation):
    r"""
    HSwishQuant activation function. Adds fake quant OPs before and after the HSwish OP.

    This part is a more detailed overview of the HSwish op.

    Args:
        activation (Cell): Activation cell class.
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - The input of HSwishQuant.

    Outputs:
        Tensor, with the same type and shape as the `x`.

    Examples:
        >>> activation = nn.HSwishQuant(nn.HSwish())
        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> result = activation(input)
    """

    def __init__(self,
                 activation,
                 ema_decay=0.999,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(HSwishQuant, self).__init__()
        self.fake_quant_act_before = quant_config.activation(min_init=-6,
                                                             max_init=6,
                                                             ema=True,
                                                             ema_decay=ema_decay,
                                                             quant_dtype=quant_dtype)
        self.fake_quant_act_after = quant_config.activation(min_init=-6,
                                                            max_init=6,
                                                            ema=True,
                                                            ema_decay=ema_decay,
                                                            quant_dtype=quant_dtype)
        if issubclass(activation.__class__, nn.HSwish):
            self.act = activation
        else:
            raise ValueError("Activation should be `nn.HSwish`")

    def construct(self, x):
        x = self.fake_quant_act_before(x)
        x = self.act(x)
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        return self.act


class HSigmoidQuant(_QuantActivation):
    r"""
    HSigmoidQuant activation function. Adds fake quant OPs before and after the HSigmoid OP.

    This part is a more detailed overview of the HSigmoid op.

    Args:
        activation (Cell): Activation cell class.
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - The input of HSigmoidQuant.

    Outputs:
        Tensor, with the same type and shape as the `x`.

    Examples:
        >>> activation = nn.HSigmoidQuant(nn.HSigmoid())
        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> result = activation(input)
    """

    def __init__(self,
                 activation,
                 ema_decay=0.999,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(HSigmoidQuant, self).__init__()
        self.fake_quant_act_before = quant_config.activation(min_init=-6,
                                                             max_init=6,
                                                             ema=True,
                                                             ema_decay=ema_decay,
                                                             quant_dtype=quant_dtype)
        self.fake_quant_act_after = quant_config.activation(min_init=-6,
                                                            max_init=6,
                                                            ema=True,
                                                            ema_decay=ema_decay,
                                                            quant_dtype=quant_dtype)
        if issubclass(activation.__class__, nn.HSigmoid):
            self.act = activation
        else:
            raise ValueError("Activation should be `nn.HSigmoid`")

    def construct(self, x):
        x = self.fake_quant_act_before(x)
        x = self.act(x)
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        return self.act


class TensorAddQuant(Cell):
    r"""
    Adds a fake quant OP after the TensorAdd OP.

    This part is a more detailed overview of the TensorAdd op.

    Args:
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x1** (Tensor) - The first input of TensorAddQuant.
        - **x2** (Tensor) - The second input of TensorAddQuant.

    Outputs:
        Tensor, with the same type and shape as the `x1`.

    Examples:
        >>> add_quant = nn.TensorAddQuant()
        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> input_y = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> result = add_quant(input_x, input_y)
    """

    def __init__(self,
                 ema_decay=0.999,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(TensorAddQuant, self).__init__()
        self.fake_quant_act = quant_config.activation(min_init=-6,
                                                      max_init=6,
                                                      ema=True,
                                                      ema_decay=ema_decay,
                                                      quant_dtype=quant_dtype)
        self.add = P.TensorAdd()

    def construct(self, x1, x2):
        x = self.add(x1, x2)
        x = self.fake_quant_act(x)
        return x


class MulQuant(Cell):
    r"""
    Adds a fake quant OP after the Mul OP.

    This part is a more detailed overview of the Mul op.

    Args:
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        quant_config (QuantConfig): Configures the observer types of weight and activation. Default:
            quant_config_default.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x1** (Tensor) - The first input of MulQuant.
        - **x2** (Tensor) - The second input of MulQuant.

    Outputs:
        Tensor, with the same type and shape as the `x1`.

    Examples:
        >>> mul_quant = nn.MulQuant()
        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> input_y = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> result = mul_quant(input_x, input_y)
    """

    def __init__(self,
                 ema_decay=0.999,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(MulQuant, self).__init__()
        self.fake_quant_act = quant_config.activation(min_init=-6,
                                                      max_init=6,
                                                      ema=True,
                                                      ema_decay=ema_decay,
                                                      quant_dtype=quant_dtype)
        self.mul = P.Mul()

    def construct(self, x1, x2):
        x = self.mul(x1, x2)
        x = self.fake_quant_act(x)
        return x
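

# Illustrative sketch (not part of the original file): quantization-aware element-wise
# ops. For example, the residual add in a quantized ResNet block needs its own output
# fake quantizer, which TensorAddQuant provides. The helper name is hypothetical.
def _example_elementwise_quant():
    """Fake-quantized add and mul on two small tensors."""
    add_quant, mul_quant = TensorAddQuant(), MulQuant()
    x1 = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mstype.float32)
    x2 = Tensor(np.ones((2, 3)), mstype.float32)
    summed = add_quant(x1, x2)    # add, then fake quantization of the result
    product = mul_quant(x1, x2)   # mul, then fake quantization of the result
    return summed, product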


class QuantBlock(Cell):
    r"""
    A quant block of Conv/Dense, activation layer for Ascend deploy.

    Calculates Conv or Dense in Int8, with Quant and DeQuant.

    Notes:
        This block is only for deploy, and not trainable.

    Args:
        core_op (Primitive): The computation op (Conv or Dense) to be wrapped.
        weight (Tensor): The weight of the wrapped op.
        quant_op (Primitive): The quantization op applied to the input.
        dequant_op (Primitive): The dequantization op applied to the output.
        dequant_scale (Tensor): The scale used by the dequantization op.
        bias (Tensor): The bias of the wrapped op. Default: None.
        activation (Cell): The activation applied to the dequantized output. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.
    """

    def __init__(self,
                 core_op,
                 weight,
                 quant_op,
                 dequant_op,
                 dequant_scale,
                 bias=None,
                 activation=None):
        super(QuantBlock, self).__init__()
        self.core_op = core_op
        self.weight = weight
        self.quant = quant_op
        self.dequant = dequant_op
        self.dequant_scale = dequant_scale
        self.bias = bias
        self.has_bias = bias is not None
        self.activation = activation
        self.has_act = activation is not None
        self.bias_add = P.BiasAdd()

    def construct(self, x):
        x = self.quant(x)
        if self.has_bias:
            x = self.core_op(x, self.weight)
            x = self.bias_add(x, self.bias)
        else:
            x = self.core_op(x, self.weight)
        x = self.dequant(x, self.dequant_scale)
        x = F.cast(x, mstype.float32)
        if self.has_act:
            x = self.activation(x)
        return x

    def extend_repr(self):
        str_info = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
        if self.has_bias:
            str_info = str_info + f', bias=shape[{self.bias.shape}]'
        if self.has_act:
            str_info = str_info + f', activation={self.activation}'
        str_info = str_info + f', dequant={self.dequant}'
        return str_info


class QuantMindirBlock(Cell):
    """A quant block of Conv/Dense, activation layer for export MINDIR model.

    Args:
        core_op (Cell): The operation cell.
        weight (Tensor): The weight of the cell.
        bias (Tensor): The bias of the cell. Default: None.
        activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
        param_dict (dict): The information of the cell. Default: None.
    """

    def __init__(self,
                 core_op,
                 weight,
                 bias=None,
                 activation=None,
                 param_dict=None):
        super(QuantMindirBlock, self).__init__()
        self.core_op = core_op
        if activation is not None:
            self.core_op.add_prim_attr("activation_name", activation.__class__.__name__)
        self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"]))
        self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"]))
        if param_dict["output_maxq"] is not None:
            self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"]))
            self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"]))
        self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"]))
        if hasattr(core_op, 'pad_mode'):
            self.core_op.add_prim_attr("pad_mode", core_op.pad_mode)
        self.core_op.add_prim_attr("num_bits", Tensor(8))
        self.core_op.add_prim_attr("narrow_range", Tensor(False))
        if param_dict["input_maxq"] == 'None':
            self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"]))
            self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"]))
        elif param_dict["input_maxq"] is not None:
            self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
            self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
        self.weight = weight
        self.bias = bias
        self.has_bias = bias is not None
        self.activation = activation
        self.has_act = activation is not None
        self.bias_add = P.BiasAdd()
        if isinstance(activation, ReLU):
            self.activation = None
            self.has_act = False

    def construct(self, x):
        if self.has_bias:
            x = self.core_op(x, self.weight)
            x = self.bias_add(x, self.bias)
        else:
            x = self.core_op(x, self.weight)
        return x

    def extend_repr(self):
        str_info = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
        if self.has_bias:
            str_info = str_info + f', bias=shape[{self.bias.shape}]'
        if self.has_act:
            str_info = str_info + f', activation={self.activation}'
        return str_info