
quant.py 73 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
6611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization aware training."""

from functools import partial
from collections import namedtuple
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.ops.primitive import Primitive
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, twice
from mindspore.compression.common import QuantDtype
import mindspore.context as context
from .normalization import BatchNorm2d
from .activation import get_activation, ReLU
from ..cell import Cell
from ... import nn
from ...ops.operations import _quant_ops as Q

__all__ = [
    'FakeQuantWithMinMaxObserver',
    'Conv2dBnFoldQuantOneConv',
    'Conv2dBnFoldQuant',
    'Conv2dBnWithoutFoldQuant',
    'Conv2dQuant',
    'DenseQuant',
    'ActQuant',
    'TensorAddQuant',
    'MulQuant',
]

class BatchNormFoldCell(Cell):
    """
    Batch Normalization folded.

    Args:
        momentum (float): Momentum value must be [0, 1]. Default: 0.9.
        epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
            float32 else 1e-3. Default: 1e-5.
        freeze_bn (int): Delay in steps at which computation switches from regular batch
            norm to frozen mean and std. Default: 0.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        Tuple of 4 Tensor, the normalized input and the updated parameters.

        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
    """

    def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
        """Initialize batch norm fold layer"""
        super(BatchNormFoldCell, self).__init__()
        self.epsilon = epsilon
        self.is_gpu = context.get_context('device_target') == "GPU"
        if self.is_gpu:
            self.bn_train = Q.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
            self.bn_infer = Q.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn)
        else:
            self.bn_reduce = P.BNTrainingReduce()
            self.bn_update = Q.BatchNormFoldD(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)

    def construct(self, x, mean, variance, global_step):
        if self.is_gpu:
            if self.training:
                batch_mean, batch_std, running_mean, running_std = self.bn_train(x, mean, variance, global_step)
            else:
                batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step)
        else:
            if self.training:
                x_sum, x_square_sum = self.bn_reduce(x)
                _, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \
                    self.bn_update(x, x_sum, x_square_sum, mean, variance)
                P.Assign()(mean, mean_updated)
                P.Assign()(variance, variance_updated)
            else:
                batch_mean = P.ZerosLike()(variance)
                batch_std = P.OnesLike()(variance)
                running_mean = P.Add()(mean, 0.)
                running_std = P.Sqrt()(P.Add()(variance, self.epsilon))
        return batch_mean, batch_std, running_mean, running_std

def _partial_init(cls_or_self, **kwargs):
    """
    Wrapper that allows creation of class factories.

    This can be useful when there is a need to create classes with the same
    constructor arguments, but different instances.

    Examples:
        >>> class Foo:
        ...     def __init__(self, a, b, answer):
        ...         pass
        >>> Foo.partial_init = classmethod(_partial_init)
        >>> foo_builder = Foo.partial_init(a=3, b=4).partial_init(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> result = (id(foo_instance1) == id(foo_instance2))
        >>> print(result)
        False
    """

    class _PartialWrapper:
        r"""
        Class of wrapper that allows creation of class factories.
        """

        def __init__(self, p):
            self.p = p

        def __call__(self, *args, **keywords):
            return self.p(*args, **keywords)

        def __repr__(self):
            return self.p.__repr__()

        partial_init = _partial_init

    r = _PartialWrapper(partial(cls_or_self, **kwargs))
    return r

class _Observer(Cell):
    """
    Base class of Observer. An Observer is used to calculate the statistics of a specific layer.

    Notes:
        This class is an abstract class.

    Args:
        quant_dtype (QuantDtype): The type of FakeQuant data.
    """

    def __init__(self, quant_dtype):
        super(_Observer, self).__init__()
        self.quant_dtype = quant_dtype

    def extend_repr(self):
        s = f"quant_dtype={self.quant_dtype}"
        return s

    def construct(self):
        pass

    partial_init = classmethod(_partial_init)

class UniformQuantObserver(_Observer):
    """
    The base class of Uniform Quantization Observer.

    Args:
        quant_dtype (QuantDtype): The type of FakeQuant data. Default: QuantDtype.INT8.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        num_channels (int): Declares the min and max channel size. Default: 1.

    Returns:
        Tensor.
    """

    min_max_map = {
        QuantDtype.INT2: (-2, 1),
        QuantDtype.INT3: (-4, 3),
        QuantDtype.INT4: (-8, 7),
        QuantDtype.INT5: (-16, 15),
        QuantDtype.INT6: (-32, 31),
        QuantDtype.INT7: (-64, 63),
        QuantDtype.INT8: (-128, 127),
        QuantDtype.UINT2: (0, 3),
        QuantDtype.UINT3: (0, 7),
        QuantDtype.UINT4: (0, 15),
        QuantDtype.UINT5: (0, 31),
        QuantDtype.UINT6: (0, 63),
        QuantDtype.UINT7: (0, 127),
        QuantDtype.UINT8: (0, 255)
    }

    def __init__(self, quant_dtype=QuantDtype.INT8, per_channel=False, symmetric=False, narrow_range=False,
                 num_channels=1):
        super(UniformQuantObserver, self).__init__(quant_dtype)
        self.per_channel = per_channel
        self.symmetric = symmetric
        self.narrow_range = narrow_range
        self.num_channels = num_channels

class FakeQuantWithMinMaxObserver(UniformQuantObserver):
    r"""
    Quantization aware operation which provides the fake quantization observer function on data with min and max.

    The details of the quantization mode `DEFAULT` are described below.

    The running min/max :math:`x_{min}` and :math:`x_{max}` are computed as:

    .. math::

        \begin{array}{ll} \\
            x_{min} =
            \begin{cases}
                \min(\min(X), 0)
                  & \text{ if } ema = \text{False} \\
                \min((1 - c) \min(X) + \text{c } x_{min}, 0)
                  & \text{ otherwise }
            \end{cases}\\
            x_{max} =
            \begin{cases}
                \max(\max(X), 0)
                  & \text{ if } ema = \text{False} \\
                \max((1 - c) \max(X) + \text{c } x_{max}, 0)
                  & \text{ otherwise }
            \end{cases}
        \end{array}

    where :math:`X` is the input tensor and :math:`c` is the `ema_decay`.

    The scale and zero point zp are computed as:

    .. math::

        \begin{array}{ll} \\
            scale =
            \begin{cases}
                \frac{x_{max} - x_{min}}{Q_{max} - Q_{min}}
                  & \text{ if } symmetric = \text{False} \\
                \frac{2\max(x_{max}, \left | x_{min} \right |) }{Q_{max} - Q_{min}}
                  & \text{ otherwise }
            \end{cases}\\
            zp\_min = Q_{min} - \frac{x_{min}}{scale} \\
            zp = \left \lfloor \min(Q_{max}, \max(Q_{min}, zp\_min)) + 0.5 \right \rfloor
        \end{array}

    where :math:`Q_{max}` and :math:`Q_{min}` are decided by quant_dtype; for example, if quant_dtype=INT8,
    then :math:`Q_{max} = 127` and :math:`Q_{min} = -128`.

    The fake quant output is computed as:

    .. math::

        \begin{array}{ll} \\
            u_{min} = (Q_{min} - zp) * scale \\
            u_{max} = (Q_{max} - zp) * scale \\
            u_X = \left \lfloor \frac{\min(u_{max}, \max(u_{min}, X)) - u_{min}}{scale}
            + 0.5 \right \rfloor \\
            output = u_X * scale + u_{min}
        \end{array}

    The details of the quantization mode `LEARNED_SCALE` are described below.

    The fake quant output is computed as:

    .. math::

        \bar{X}=\left\{\begin{matrix}
        clip\left ( \frac{X}{maxq},0,1\right ) \qquad \quad if\quad neg\_trunc\\
        clip\left ( \frac{X}{maxq},-1,1\right )\qquad \ if\quad otherwise
        \end{matrix}\right. \\

        output=\frac{floor\left ( \bar{X}\ast Q_{max}+0.5 \right ) \ast scale }{Q_{max}}

    where :math:`X` is the input tensor, and :math:`Q_{max}` (quant_max) is decided by quant_dtype and neg_trunc;
    for example, if quant_dtype=INT8 and neg_trunc works, :math:`Q_{max} = 256`, otherwise :math:`Q_{max} = 127`.

    The maxq is updated by training, and its gradient is calculated as follows:

    .. math::

        \frac{\partial \ output}{\partial \ maxq} = \left\{\begin{matrix}
        -\frac{X}{maxq}+\left \lfloor \frac{X}{maxq} \right \rceil \qquad if\quad bound_{lower}< \frac{X}{maxq}< 1\\
        -1 \qquad \quad \qquad \quad if\quad \frac{X}{maxq}\le bound_{lower}\\
        1 \qquad \quad \qquad \quad if\quad \frac{X}{maxq}\ge 1 \qquad \quad
        \end{matrix}\right. \\

        bound_{lower}=\left\{\begin{matrix}
        0\qquad \quad if\quad neg\_trunc\\
        -1\qquad if\quad otherwise
        \end{matrix}\right.

    Then minq is computed as:

    .. math::

        minq=\left\{\begin{matrix}
        0 \qquad \qquad \quad if\quad neg\_trunc\\
        -maxq\qquad if\quad otherwise
        \end{matrix}\right.

    When exporting, the scale and zero point zp are computed as:

    .. math::

        scale=\frac{maxq}{quant\_max} ,\quad zp=0 \\

    zp is always equal to 0, due to the symmetric nature of `LEARNED_SCALE`.

    Args:
        min_init (int, float, list): The initialized min value. Default: -6.
        max_init (int, float, list): The initialized max value. Default: 6.
        ema (bool): The exponential Moving Average algorithm updates min and max. Default: False.
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
        channel_axis (int): Quantization by channel axis. Default: 1.
        num_channels (int): Declares the min and max channel size. Default: 1.
        quant_dtype (QuantDtype): The datatype of quantization, supporting 4 and 8 bits. Default: QuantDtype.INT8.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
        neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: False.
        mode (string): Optional quantization mode, currently only `DEFAULT` (QAT) and `LEARNED_SCALE` are supported.
            Default: "DEFAULT".

    Inputs:
        - **input** (Tensor) - The input of FakeQuantWithMinMaxObserver.

    Outputs:
        Tensor, with the same type and shape as the `input`.

    Raises:
        TypeError: If `min_init` or `max_init` is not int, float or list.
        TypeError: If `quant_delay` is not an int.
        ValueError: If `min_init` is not less than `max_init`.
        ValueError: If `quant_delay` is not greater than or equal to 0.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> fake_quant = nn.FakeQuantWithMinMaxObserver()
        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> output = fake_quant(input)
        >>> print(output)
        [[ 0.9882355  1.9764705  0.9882355]
         [-1.9764705  0.        -0.9882355]]
    """

    def __init__(self,
                 min_init=-6,
                 max_init=6,
                 ema=False,
                 ema_decay=0.999,
                 per_channel=False,
                 channel_axis=1,
                 num_channels=1,
                 quant_dtype=QuantDtype.INT8,
                 symmetric=False,
                 narrow_range=False,
                 quant_delay=0,
                 neg_trunc=False,
                 mode="DEFAULT"):
        """Initialize FakeQuantWithMinMaxObserver"""
        super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
                                                          symmetric=symmetric, narrow_range=narrow_range,
                                                          num_channels=num_channels)
        Validator.check_value_type("min_init", min_init, [int, float, list], type(self).__name__)
        Validator.check_value_type("max_init", max_init, [int, float, list], type(self).__name__)
        if isinstance(max_init, (int, float)) and isinstance(min_init, (int, float)):
            Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
        elif not np.greater(max_init, min_init).all():
            raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.")
        Validator.check_non_negative_int(quant_delay, 'quant_delay')
        self.min_init = min_init
        self.max_init = max_init
        self.quant_dtype = quant_dtype
        self.num_bits = quant_dtype.num_bits
        self.ema = ema
        self.ema_decay = ema_decay
        self.per_channel = per_channel
        self.num_channels = num_channels
        self.channel_axis = channel_axis
        self.quant_delay = quant_delay
        self.symmetric = symmetric
        self.narrow_range = narrow_range
        self.neg_trunc = neg_trunc
        self.mode = mode
        self.is_ascend = context.get_context('device_target') == "Ascend"
        self.Neg = P.Neg()

        min_array = self._get_init_array(self.min_init)
        max_array = self._get_init_array(self.max_init)
        if self.mode == "DEFAULT":
            # init tensor min and max for fake quantized operation
            self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
            self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)

            # init fake quant relative op
            if self.per_channel:
                quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
                ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
            else:
                quant_fun = Q.FakeQuantPerLayer
                ema_fun = Q.MinMaxUpdatePerLayer

            self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
            if self.is_ascend:
                self.fake_quant_train = quant_fun(num_bits=self.quant_dtype.num_bits,
                                                  symmetric=self.symmetric,
                                                  narrow_range=self.narrow_range,
                                                  quant_delay=self.quant_delay)
                self.fake_quant_infer = self.fake_quant_train
            else:
                quant_fun = partial(quant_fun,
                                    ema=self.ema,
                                    ema_decay=ema_decay,
                                    num_bits=self.quant_dtype.num_bits,
                                    symmetric=self.symmetric,
                                    narrow_range=self.narrow_range,
                                    quant_delay=self.quant_delay)
                self.fake_quant_train = quant_fun(training=True)
                self.fake_quant_infer = quant_fun(training=False)
        elif self.mode == "LEARNED_SCALE":
            if not self.symmetric:
                raise ValueError("The 'LEARNED_SCALE' mode only supports symmetric quantization, "
                                 "please set symmetric to True.")
            if self.neg_trunc:
                min_array = self._get_init_array(0)
                self.narrow_range = False
            elif not self.narrow_range:
                raise ValueError("The 'LEARNED_SCALE' mode only supports the narrow_range=True config, "
                                 "except for the neg_trunc=True scenario.")
            self._calculate_quant_max()

            self.minq = Parameter(Tensor(min_array), name='minq')
            self.maxq = Parameter(Tensor(max_array), name='maxq')
            self.quant_max = Parameter(Tensor(np.array([self._quant_max]).astype(np.float32)),
                                       name="quant_max", requires_grad=False)
            # init fake quant relative op
            if self.per_channel:
                quant_fun = partial(Q.FakeLearnedScaleQuantPerChannel, channel_axis=self.channel_axis)
            else:
                quant_fun = Q.FakeLearnedScaleQuantPerLayer

            quant_fun = partial(quant_fun,
                                quant_delay=self.quant_delay,
                                neg_trunc=self.neg_trunc)
            self.fake_quant_train = quant_fun(training=True)
            self.fake_quant_infer = quant_fun(training=False)
        else:
            raise ValueError("Invalid mode, currently only the `DEFAULT` and `LEARNED_SCALE` modes are valid.")

    def reset(self, quant_dtype=QuantDtype.INT8, min_init=-6, max_init=6):
        r"""
        Reset the quant max parameter (e.g. 256) and the initial values of the minq and maxq parameters.
        This function is currently only valid for `LEARNED_SCALE` mode.
        """
        if self.mode == "LEARNED_SCALE":
            self.quant_dtype = quant_dtype
            self.num_bits = quant_dtype.num_bits
            self._calculate_quant_max()
            if self.neg_trunc:
                min_init = 0

            self.min_init = min_init
            self.max_init = max_init
            min_array = self._get_init_array(self.min_init)
            max_array = self._get_init_array(self.max_init)
            self.minq.set_data(Tensor(min_array))
            self.maxq.set_data(Tensor(max_array))
            self.quant_max.set_data(Tensor(np.array([self._quant_max]).astype(np.float32)))
        else:
            raise ValueError("The `reset` function is currently only valid for `LEARNED_SCALE` mode.")

    def _get_init_array(self, init_date):
        """
        Convert the initial value to an array.
        """
        if isinstance(init_date, list) and self.per_channel and len(init_date) != self.num_channels:
            raise ValueError("The length of the min_init/max_init list should be equal to num_channels for "
                             "the per-channel quant scenario, but got {}".format(len(init_date)))
        if isinstance(init_date, list) and not self.per_channel and len(init_date) != 1:
            raise ValueError("The length of the min_init/max_init list should be 1 for the per-layer quant "
                             "scenario, but got {}".format(len(init_date)))

        if isinstance(init_date, list):
            min_max_array = np.array(init_date).astype(np.float32)
        elif self.per_channel and not isinstance(init_date, list):
            min_max_array = np.array([init_date] * self.num_channels).astype(np.float32)
        else:
            min_max_array = np.array([init_date]).astype(np.float32)
        return min_max_array

    def _calculate_quant_max(self):
        """
        The quantization range is calculated according to num_bits.
        """
        if not self.neg_trunc:
            self._quant_max = (1 << (self.num_bits - 1)) - 1
        else:
            self._quant_max = (1 << self.num_bits) - 1

    def extend_repr(self):
        s = 'quant_dtype={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
            'quant_delay={}, min_init={}, max_init={}'.format(self.quant_dtype, self.symmetric, self.narrow_range,
                                                              self.ema, self.ema_decay, self.per_channel,
                                                              self.channel_axis, self.num_channels, self.quant_delay,
                                                              self.min_init, self.max_init)
        return s

    def construct(self, x):
        if self.mode == "LEARNED_SCALE":
            if self.training:
                out = self.fake_quant_train(x, self.maxq, self.quant_max)
                if not self.neg_trunc:
                    self.minq = self.Neg(self.maxq)
            else:
                out = self.fake_quant_infer(x, self.maxq, self.quant_max)
        else:
            if self.training:
                min_up, max_up = self.ema_update(x, self.minq, self.maxq)
                self.minq = min_up
                self.maxq = max_up
                out = self.fake_quant_train(x, self.minq, self.maxq)
            else:
                out = self.fake_quant_infer(x, self.minq, self.maxq)
        return out
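
# Editorial reference sketch (not part of the MindSpore API): a plain NumPy restatement of the
# `DEFAULT`-mode scale/zero-point and fake-quant formulas from the FakeQuantWithMinMaxObserver
# docstring above, assuming asymmetric per-layer INT8 (Q_min=-128, Q_max=127). Illustrative only.
def _fake_quant_numpy_sketch(x, x_min, x_max, q_min=-128, q_max=127):
    """Illustrative only: simulate per-layer fake quantization of `x` with NumPy."""
    scale = (x_max - x_min) / (q_max - q_min)
    zp = np.floor(np.clip(q_min - x_min / scale, q_min, q_max) + 0.5)
    u_min = (q_min - zp) * scale
    u_max = (q_max - zp) * scale
    u_x = np.floor((np.clip(x, u_min, u_max) - u_min) / scale + 0.5)
    return u_x * scale + u_min
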

QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])

quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver)
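
# Editorial sketch (hypothetical, not part of this module's public API): observers are typically
# paired into a QuantConfig through `partial_init`, in the same spirit as the
# compression.quant.create_quant_config helper mentioned in the docstrings below, so that each
# quantized layer can build its own observer instances from shared settings.
_example_quant_config = QuantConfig(
    weight=FakeQuantWithMinMaxObserver.partial_init(per_channel=True, symmetric=True, narrow_range=True),
    activation=FakeQuantWithMinMaxObserver.partial_init(ema=True, ema_decay=0.999))
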

class Conv2dBnFoldQuantOneConv(Cell):
    r"""
    2D convolution which uses the convolution layer statistics once to calculate the Batch Normalization
    operation folded construct.

    This part is a more detailed overview of the Conv2d operation. For more details about quantization,
    please refer to :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies stride for all spatial dimensions with the same value.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        eps (float): Parameters for Batch Normalization. Default: 1e-5.
        momentum (float): Parameters for Batch Normalization op. Default: 0.997.
        dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            convolution kernel. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            bias vector. Default: 'zeros'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            beta vector. Default: 'zeros'.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            gamma vector. Default: 'ones'.
        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            mean vector. Default: 'zeros'.
        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            variance vector. Default: 'ones'.
        fake (bool): Whether the Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
        quant_config (QuantConfig): Configures the observer types and quant settings of weight and activation. Can be
            generated by the compression.quant.create_quant_config method.
            Default: both set to default FakeQuantWithMinMaxObserver.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is not an int.
        TypeError: If `has_bias` is not a bool.
        ValueError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is less than 1.
        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> qconfig = compression.quant.create_quant_config()
        >>> conv2d_bnfold = nn.Conv2dBnFoldQuantOneConv(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
        ...                                             quant_config=qconfig)
        >>> input = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mindspore.float32)
        >>> result = conv2d_bnfold(input)
        >>> output = result.shape
        >>> print(output)
        (2, 6, 2, 2)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 eps=1e-5,
                 momentum=0.997,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 beta_init='zeros',
                 gamma_init='ones',
                 mean_init='zeros',
                 var_init='ones',
                 fake=True,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        """Initialize Conv2dBnFoldQuant layer"""
        super(Conv2dBnFoldQuantOneConv, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.kernel_size = twice(kernel_size)
        self.stride = twice(stride)
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = twice(dilation)
        self.group = group
        self.eps = eps
        self.momentum = 1 - momentum
        self.has_bias = has_bias
        self.fake = fake
        self.quant_config = quant_config
        self.quant_dtype = quant_dtype
        data_format = 'NCHW'
        self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
        self._target = context.get_context("device_target")
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False
        self.enable_default_train = self.is_graph_mode and \
                                    (self.is_ge_backend or self._target == "Ascend")

        # initialize convolution op and Parameter
        self.conv = P.Conv2D(out_channel=out_channels,
                             kernel_size=self.kernel_size,
                             pad_mode=pad_mode,
                             pad=padding,
                             stride=self.stride,
                             dilation=self.dilation,
                             group=group)
        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
        channel_axis = 0
        self.channel_axis = channel_axis
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.bias_add = P.BiasAdd()
        if Validator.check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
        else:
            self.bias = None

        # initialize BatchNorm Parameter
        self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
        self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
        self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
        self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
                                         requires_grad=False)

        # initialize fake ops
        self.fake_quant_weight = quant_config.weight(ema=False,
                                                     channel_axis=channel_axis,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)
        self.freeze_bn = False
        if self.fake_quant_weight.mode == "LEARNED_SCALE":
            self.freeze_bn = True
        self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps,
                                    momentum=self.momentum, data_format=self.format)
        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
        self.sub_mean = P.Sub()
        self.sub_var = P.Sub()
        self.mul_mean = P.Mul()
        self.mul_var = P.Mul()
        self.assign_sub_mean = P.AssignSub()
        self.assign_sub_var = P.AssignSub()
        self.reshape = P.Reshape()

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'fake={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
                                                          self.kernel_size, self.stride,
                                                          self.pad_mode, self.padding, self.dilation,
                                                          self.group,
                                                          self.fake, self.momentum,
                                                          self.fake_quant_weight.quant_delay)
        return s

    def construct(self, x):
        running_std = P.Sqrt()(P.Add()(self.moving_variance, self.eps))
        scale_factor = self.gamma / running_std
        if self.channel_axis:
            scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
        else:
            scale_factor = self.reshape(scale_factor, (-1, 1, 1, 1))
        weight = self.weight * scale_factor
        if self.fake:
            weight = self.fake_quant_weight(weight)
        conv = self.conv(x, weight)

        if self.freeze_bn:
            return conv + self.reshape((self.beta - self.gamma * self.moving_mean / running_std), (1, -1, 1, 1))
        scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
        if self.enable_default_train:
            scale_factor = P.Reciprocal()(scale_factor)
            conv_orig = conv * scale_factor
        else:
            conv_orig = conv / scale_factor
        if self.training:
            return self.bn_train(conv_orig,
                                 self.gamma,
                                 self.beta,
                                 self.moving_mean,
                                 self.moving_variance)[0]

        return self.bn_infer(conv_orig,
                             self.gamma,
                             self.beta,
                             self.moving_mean,
                             self.moving_variance)[0]
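
# Editorial reference sketch (not part of the MindSpore API): the folding performed by
# Conv2dBnFoldQuantOneConv above scales the convolution weight by gamma / running_std before fake
# quantization, and once BN is frozen the bias is absorbed the same way. In plain NumPy, for a
# weight of shape (C_out, C_in, kH, kW) and per-channel BN statistics of shape (C_out,):
def _fold_bn_into_conv_sketch(weight, gamma, beta, moving_mean, moving_variance, eps=1e-5):
    """Illustrative only: return the BN-folded convolution weight and bias."""
    running_std = np.sqrt(moving_variance + eps)
    folded_weight = weight * (gamma / running_std).reshape(-1, 1, 1, 1)
    folded_bias = beta - gamma * moving_mean / running_std
    return folded_weight, folded_bias
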

class Conv2dBnFoldQuant(Cell):
    r"""
    2D convolution with the Batch Normalization operation folded construct.

    This part is a more detailed overview of the Conv2d operation. For more details about quantization,
    please refer to :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies stride for all spatial dimensions with the same value.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        eps (float): Parameters for Batch Normalization. Default: 1e-5.
        momentum (float): Parameters for Batch Normalization op. Default: 0.997.
        dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            convolution kernel. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            bias vector. Default: 'zeros'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            beta vector. Default: 'zeros'.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            gamma vector. Default: 'ones'.
        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            mean vector. Default: 'zeros'.
        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            variance vector. Default: 'ones'.
        fake (bool): Whether the Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
        quant_config (QuantConfig): Configures the observer types and quant settings of weight and activation. Can be
            generated by the compression.quant.create_quant_config method.
            Default: both set to default FakeQuantWithMinMaxObserver.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
        freeze_bn (int): The quantization freeze Batch Normalization op is according to the global step.
            Default: 100000.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is not an int.
        TypeError: If `has_bias` is not a bool.
        ValueError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is less than 1.
        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> qconfig = compression.quant.create_quant_config()
        >>> conv2d_bnfold = nn.Conv2dBnFoldQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
        ...                                      quant_config=qconfig)
        >>> input = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mindspore.float32)
        >>> output = conv2d_bnfold(input)
        >>> print(output.shape)
        (2, 6, 2, 2)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 eps=1e-5,
                 momentum=0.997,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 beta_init='zeros',
                 gamma_init='ones',
                 mean_init='zeros',
                 var_init='ones',
                 fake=True,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8,
                 freeze_bn=100000):
        """Initialize Conv2dBnFoldQuant layer"""
        super(Conv2dBnFoldQuant, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.kernel_size = twice(kernel_size)
        self.stride = twice(stride)
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = twice(dilation)
        self.group = group
        self.eps = eps
        self.momentum = momentum
        self.has_bias = has_bias
        self.freeze_bn = freeze_bn
        self.fake = fake
        self.quant_config = quant_config
        self.quant_dtype = quant_dtype
        self.is_gpu = context.get_context('device_target') == "GPU"

        # initialize convolution op and Parameter
        self.conv = P.Conv2D(out_channel=out_channels,
                             kernel_size=self.kernel_size,
                             pad_mode=pad_mode,
                             pad=padding,
                             stride=self.stride,
                             dilation=self.dilation,
                             group=group)
        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
        channel_axis = 0
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.bias_add = P.BiasAdd()
        if Validator.check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
        else:
            self.bias = None

        # initialize BatchNorm Parameter
        self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
        self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
        self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
        self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
                                         requires_grad=False)

        # initialize fake ops
        self.fake_quant_weight = quant_config.weight(ema=False,
                                                     channel_axis=channel_axis,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)
        self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
        self.correct_mul = Q.CorrectionMul(channel_axis)
        if context.get_context('device_target') == "Ascend":
            self.batchnorm_fold2_train = Q.BatchNormFold2_D(freeze_bn=freeze_bn)
            self.batchnorm_fold2_infer = Q.BatchNormFold2_D(freeze_bn=0)
        elif context.get_context('device_target') == "GPU":
            self.batchnorm_fold2_train = Q.BatchNormFold2(freeze_bn=freeze_bn)
            self.batchnorm_fold2_infer = Q.BatchNormFold2(freeze_bn=0)
        else:
            raise ValueError("Unsupported platform: {}".format(context.get_context('device_target')))
        self.step = Parameter(initializer('normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
        self.one = Tensor(1, mstype.int32)
        self.assignadd = P.AssignAdd()

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
                                                                        self.kernel_size, self.stride,
                                                                        self.pad_mode, self.padding, self.dilation,
                                                                        self.group,
                                                                        self.fake, self.freeze_bn, self.momentum,
                                                                        self.fake_quant_weight.quant_delay)
        return s

    def construct(self, x):
        out_conv = self.conv(x, self.weight)
        if self.has_bias:
            out_conv = self.bias_add(out_conv, self.bias)
        # BN fold1
        batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
                                                                               self.moving_mean,
                                                                               self.moving_variance,
                                                                               self.step)
        # fake weight
        weight = self.correct_mul(self.weight, self.gamma, running_std)
        if self.fake:
            weight = self.fake_quant_weight(weight)
        out = self.conv(x, weight)
        if self.has_bias:
            out = self.bias_add(out, self.bias)
        # BN fold2
        if self.is_gpu:
            if self.training:
                out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
                                                 batch_std, batch_mean, running_std, running_mean, self.step)
                self.assignadd(self.step, self.one)
            else:
                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                                 batch_std, batch_mean, running_std, running_mean, self.step)
        else:
            if self.training:
                out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
                self.assignadd(self.step, self.one)
            else:
                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std)
        return out

class Conv2dBnWithoutFoldQuant(Cell):
    r"""
    2D convolution and batchnorm without fold with fake quantized construct.

    This part is a more detailed overview of the Conv2d operation. For more details about quantization,
    please refer to :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        eps (float): Parameters for Batch Normalization. Default: 1e-5.
        momentum (float): Parameters for Batch Normalization op. Default: 0.997.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
        quant_config (QuantConfig): Configures the observer types and quant settings of weight and activation. Can be
            generated by the compression.quant.create_quant_config method.
            Default: both set to default FakeQuantWithMinMaxObserver.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> qconfig = compression.quant.create_quant_config()
        >>> conv2d_no_bnfold = nn.Conv2dBnWithoutFoldQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
        ...                                                quant_config=qconfig)
        >>> input = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mstype.float32)
        >>> output = conv2d_no_bnfold(input)
        >>> print(output.shape)
        (2, 6, 2, 2)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 eps=1e-5,
                 momentum=0.997,
                 weight_init='normal',
                 bias_init='zeros',
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(Conv2dBnWithoutFoldQuant, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = has_bias
        self.kernel_size = twice(kernel_size)
        self.stride = twice(stride)
        self.dilation = twice(dilation)
        self.pad_mode = pad_mode
        self.padding = padding
        self.group = group
        self.bias_add = P.BiasAdd()
        if Validator.check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
        else:
            self.bias = None

        # initialize convolution op and Parameter
        self.conv = P.Conv2D(out_channel=self.out_channels,
                             kernel_size=self.kernel_size,
                             mode=1,
                             pad_mode=self.pad_mode,
                             pad=self.padding,
                             stride=self.stride,
                             dilation=self.dilation,
                             group=self.group)
        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
        channel_axis = 0
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.fake_quant_weight = quant_config.weight(ema=False,
                                                     channel_axis=channel_axis,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)
        self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)

    def construct(self, x):
        weight = self.fake_quant_weight(self.weight)
        out = self.conv(x, weight)
        if self.has_bias:
            out = self.bias_add(out, self.bias)
        out = self.batchnorm(out)
        return out

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
                                                 self.pad_mode, self.padding, self.dilation, self.group,
                                                 self.has_bias, self.fake_quant_weight.quant_delay)
        return s

class Conv2dQuant(Cell):
    r"""
    2D convolution with fake quantized operation layer.

    This part is a more detailed overview of the Conv2d operation. For more details about quantization,
    please refer to :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
        stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
        quant_config (QuantConfig): Configures the observer types and quant settings of weight and activation. Can be
            generated by the compression.quant.create_quant_config method.
            Default: both set to default FakeQuantWithMinMaxObserver.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is not an int.
        TypeError: If `has_bias` is not a bool.
        ValueError: If `in_channels`, `out_channels`, `stride`, `padding` or `dilation` is less than 1.
        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> qconfig = compression.quant.create_quant_config()
        >>> conv2d_quant = nn.Conv2dQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
        ...                               quant_config=qconfig)
        >>> input = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mindspore.float32)
        >>> output = conv2d_quant(input)
        >>> print(output.shape)
        (2, 6, 2, 2)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(Conv2dQuant, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = has_bias
        self.kernel_size = twice(kernel_size)
        self.stride = twice(stride)
        self.dilation = twice(dilation)
        self.pad_mode = pad_mode
        self.padding = padding
        self.group = group

        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.bias_add = P.BiasAdd()
        if Validator.check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
        else:
            self.bias = None

        self.conv = P.Conv2D(out_channel=self.out_channels,
                             kernel_size=self.kernel_size,
                             mode=1,
                             pad_mode=self.pad_mode,
                             pad=self.padding,
                             stride=self.stride,
                             dilation=self.dilation,
                             group=self.group)
        channel_axis = 0
        self.fake_quant_weight = quant_config.weight(ema=False,
                                                     channel_axis=channel_axis,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)

    def construct(self, x):
        weight = self.fake_quant_weight(self.weight)
        out = self.conv(x, weight)
        if self.has_bias:
            return self.bias_add(out, self.bias)
        return out

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
                                                 self.pad_mode, self.padding, self.dilation, self.group,
                                                 self.has_bias, self.fake_quant_weight.quant_delay)
        return s

class DenseQuant(Cell):
    r"""
    The fully connected layer with fake quantized operation.

    This part is a more detailed overview of the Dense operation. For more details about quantization,
    please refer to :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    Args:
        in_channels (int): The dimension of the input space.
        out_channels (int): The dimension of the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (Union[str, Cell, Primitive]): The activation function applied to the output of the layer,
            e.g. 'relu'. Default: None.
        quant_config (QuantConfig): Configures the observer types and quant settings of weight and activation. Can be
            generated by the compression.quant.create_quant_config method.
            Default: both set to default FakeQuantWithMinMaxObserver.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out})`.

    Raises:
        TypeError: If `in_channels` or `out_channels` is not an int.
        TypeError: If `has_bias` is not a bool.
        ValueError: If `in_channels` or `out_channels` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> qconfig = compression.quant.create_quant_config()
        >>> dense_quant = nn.DenseQuant(3, 6, quant_config=qconfig)
        >>> input = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> result = dense_quant(input)
        >>> output = result.shape
        >>> print(output)
        (2, 6)
    """
  1064. def __init__(self,
  1065. in_channels,
  1066. out_channels,
  1067. weight_init='normal',
  1068. bias_init='zeros',
  1069. has_bias=True,
  1070. activation=None,
  1071. quant_config=quant_config_default,
  1072. quant_dtype=QuantDtype.INT8):
  1073. super(DenseQuant, self).__init__()
  1074. self.in_channels = Validator.check_positive_int(in_channels)
  1075. self.out_channels = Validator.check_positive_int(out_channels)
  1076. self.has_bias = Validator.check_bool(has_bias)
  1077. if isinstance(weight_init, Tensor):
  1078. if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
  1079. weight_init.shape[1] != in_channels:
  1080. raise ValueError("weight_init shape error")
  1081. self.weight = Parameter(initializer(
  1082. weight_init, [out_channels, in_channels]), name="weight")
  1083. if self.has_bias:
  1084. if isinstance(bias_init, Tensor):
  1085. if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
  1086. raise ValueError("bias_init shape error")
  1087. self.bias = Parameter(initializer(
  1088. bias_init, [out_channels]), name="bias")
  1089. self.matmul = P.MatMul(transpose_b=True)
  1090. self.bias_add = P.BiasAdd()
  1091. self.activation = get_activation(activation) if isinstance(activation, str) else activation
  1092. if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
  1093. raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
  1094. self.activation_flag = self.activation is not None
  1095. self.fake_quant_weight = quant_config.weight(ema=False,
  1096. channel_axis=0,
  1097. num_channels=out_channels,
  1098. quant_dtype=quant_dtype)
  1099. def construct(self, x):
  1100. """Use operators to construct the Dense layer."""
  1101. output = self.fake_quant_weight(self.weight)
  1102. output = self.matmul(x, output)
  1103. if self.has_bias:
  1104. output = self.bias_add(output, self.bias)
  1105. if self.activation_flag:
  1106. return self.activation(output)
  1107. return output
  1108. def extend_repr(self):
  1109. """A pretty print for Dense layer."""
  1110. s = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format(
  1111. self.in_channels, self.out_channels, self.weight, self.has_bias)
  1112. if self.has_bias:
  1113. s += ', bias={}'.format(self.bias)
  1114. if self.activation_flag:
  1115. s += ', activation={}'.format(self.activation)
  1116. return s
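
# Note (illustrative): DenseQuant stores the weight as (out_channels, in_channels) and multiplies
# with transpose_b=True, so the float computation it simulates is y = x @ W.T + b, with W
# fake-quantized per output channel (channel_axis=0). A rough numpy sketch of that float path,
# under that assumption:
#
#     >>> import numpy as np
#     >>> x = np.random.rand(2, 3).astype(np.float32)
#     >>> w = np.random.rand(6, 3).astype(np.float32)   # (out_channels, in_channels)
#     >>> b = np.zeros(6, dtype=np.float32)
#     >>> print((x @ w.T + b).shape)
#     (2, 6)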


class _QuantActivation(Cell):
    r"""
    Base class for quantization aware training activation functions. Adds fake quantized operation
    after the activation operation.
    """

    def get_origin(self):
        raise NotImplementedError
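
# Note (illustrative): subclasses are expected to return the wrapped float activation cell from
# get_origin(), so conversion/export utilities can recover the original activation when they
# replace or strip the fake-quant wrapper; ActQuant below does this by returning self.act.
# A minimal subclass sketch (hypothetical, for illustration only):
#
#     class _MyActQuant(_QuantActivation):
#         def __init__(self, activation):
#             super(_MyActQuant, self).__init__()
#             self.act = activation
#         def construct(self, x):
#             return self.act(x)
#         def get_origin(self):
#             return self.act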


class ActQuant(_QuantActivation):
    r"""
    Quantization aware training activation function.

    Adds the fake quantized operation to the end of the activation operation, by which the output of the
    activation operation will be truncated. For more details about quantization,
    please refer to :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    Args:
        activation (Cell): Activation cell.
        ema (bool): Whether the Exponential Moving Average algorithm is used to update min and max. Default: False.
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        fake_before (bool): Whether to add the fake quantized operation before the activation. Default: False.
        quant_config (QuantConfig): Configures the observer types and quant settings of weight and activation. Can be
            generated by compression.quant.create_quant_config method.
            Default: both set to default FakeQuantWithMinMaxObserver.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **input** (Tensor) - The input of ActQuant.

    Outputs:
        Tensor, with the same type and shape as the `input`.

    Raises:
        TypeError: If `activation` is not an instance of Cell.
        TypeError: If `fake_before` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> qconfig = compression.quant.create_quant_config()
        >>> act_quant = nn.ActQuant(nn.ReLU(), quant_config=qconfig)
        >>> input = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
        >>> output = act_quant(input)
        >>> print(output)
        [[0.9882355 1.9764705 0. ]
         [0. 0. 0. ]]
    """

    def __init__(self,
                 activation,
                 ema=False,
                 ema_decay=0.999,
                 fake_before=False,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(ActQuant, self).__init__()
        act_class = activation.__class__
        act_list = [nn.ReLU, nn.ReLU6]
        self.act = Validator.check_isinstance("activation", activation, Cell)
        self.fake_before = Validator.check_bool(fake_before, "fake_before")
        if self.fake_before:
            self.fake_quant_act_before = quant_config.activation(min_init=-6,
                                                                 max_init=6,
                                                                 ema=ema,
                                                                 ema_decay=ema_decay,
                                                                 quant_dtype=quant_dtype)
        neg_trunc = bool(act_class in act_list)
        self.fake_quant_act = quant_config.activation(min_init=-6,
                                                      max_init=6,
                                                      ema=ema,
                                                      ema_decay=ema_decay,
                                                      quant_dtype=quant_dtype,
                                                      neg_trunc=neg_trunc)

    def construct(self, x):
        if self.fake_before:
            x = self.fake_quant_act_before(x)
        x = self.act(x)
        x = self.fake_quant_act(x)
        return x

    def get_origin(self):
        return self.act
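
# Note (illustrative): the neg_trunc flag computed in __init__ is True only when the wrapped
# activation is nn.ReLU or nn.ReLU6, whose outputs are non-negative, so the output observer
# can use a one-sided [0, max] range instead of a symmetric one. A usage sketch with the
# pre-activation observer also enabled, reusing the qconfig from the docstring example:
#
#     >>> act_quant = nn.ActQuant(nn.ReLU6(), fake_before=True, quant_config=qconfig)
#     >>> y = act_quant(Tensor(np.array([[7.0, -1.0, 3.0]]), mindspore.float32))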


class TensorAddQuant(Cell):
    r"""
    Adds fake quantized operation after TensorAdd operation.

    This part is a more detailed overview of the TensorAdd operation. For more details about quantization,
    please refer to :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    Args:
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        quant_config (QuantConfig): Configures the observer types and quant settings of weight and activation. Can be
            generated by compression.quant.create_quant_config method.
            Default: both set to default FakeQuantWithMinMaxObserver.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **input_x1** (Tensor) - The first tensor of TensorAddQuant.
        - **input_x2** (Tensor) - The second tensor of TensorAddQuant.

    Outputs:
        Tensor, with the same type and shape as the `input_x1`.

    Raises:
        TypeError: If `ema_decay` is not a float.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> qconfig = compression.quant.create_quant_config()
        >>> add_quant = nn.TensorAddQuant(quant_config=qconfig)
        >>> input_x1 = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones((2, 3)), mindspore.float32)
        >>> output = add_quant(input_x1, input_x2)
        >>> print(output)
        [[ 1.9764705 3.011765 1.9764705]
         [-0.9882355 0.9882355 0. ]]
    """

    def __init__(self,
                 ema_decay=0.999,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(TensorAddQuant, self).__init__()
        self.fake_quant_act = quant_config.activation(min_init=-6,
                                                      max_init=6,
                                                      ema=True,
                                                      ema_decay=ema_decay,
                                                      quant_dtype=quant_dtype)
        self.add = P.Add()

    def construct(self, x1, x2):
        x = self.add(x1, x2)
        x = self.fake_quant_act(x)
        return x
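
# Note (illustrative): TensorAddQuant exists so that element-wise additions inside a
# quantization-aware network (for example the skip connection of a residual block) also get an
# output observer, giving the converted model a quantization range for the sum. A minimal
# sketch of that use; the operand names are placeholders, not identifiers from this file:
#
#     >>> qconfig = compression.quant.create_quant_config()
#     >>> residual_add = nn.TensorAddQuant(quant_config=qconfig)
#     >>> out = residual_add(branch_output, shortcut)   # both operands share one output observer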


class MulQuant(Cell):
    r"""
    Adds fake quantized operation after `Mul` operation.

    This part is a more detailed overview of the `Mul` operation. For more details about quantization,
    please refer to :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    Args:
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        quant_config (QuantConfig): Configures the observer types and quant settings of weight and activation. Can be
            generated by compression.quant.create_quant_config method.
            Default: both set to default :class:`FakeQuantWithMinMaxObserver`.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **input_x1** (Tensor) - The first tensor of MulQuant.
        - **input_x2** (Tensor) - The second tensor of MulQuant.

    Outputs:
        Tensor, with the same type and shape as the `input_x1`.

    Raises:
        TypeError: If `ema_decay` is not a float.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> qconfig = compression.quant.create_quant_config()
        >>> mul_quant = nn.MulQuant(quant_config=qconfig)
        >>> input_x1 = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones((2, 3)) * 2, mindspore.float32)
        >>> output = mul_quant(input_x1, input_x2)
        >>> print(output)
        [[ 1.9764705 4.0000005 1.9764705]
         [-4. 0. -1.9764705]]
    """

    def __init__(self,
                 ema_decay=0.999,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
        super(MulQuant, self).__init__()
        self.fake_quant_act = quant_config.activation(min_init=-6,
                                                      max_init=6,
                                                      ema=True,
                                                      ema_decay=ema_decay,
                                                      quant_dtype=quant_dtype)
        self.mul = P.Mul()

    def construct(self, x1, x2):
        x = self.mul(x1, x2)
        x = self.fake_quant_act(x)
        return x
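
# Note (illustrative): MulQuant mirrors TensorAddQuant for element-wise multiplications, so any
# Mul that remains in the quantization-aware graph (for example the x * relu6(x + 3) / 6
# decomposition sometimes used for h-swish) can carry its own output quantization range.
# Usage is the same as TensorAddQuant except for the operator; names below are placeholders:
#
#     >>> mul_quant = nn.MulQuant(quant_config=qconfig)
#     >>> out = mul_quant(x, scale_tensor)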


class QuantBlock(Cell):
    r"""
    A quant block of Conv/Dense, activation layer for Ascend deploy.

    Calculate Conv or Dense in Int8, with Quant and DeQuant.

    Note:
        This block is only for deploy, and not trainable.

    Args:
        core_op (Cell or Primitive): The int8 Conv or MatMul operation that does the main computation.
        weight (Tensor): The quantized weight of the layer.
        quant_op (Cell or Primitive): The operation that quantizes the float input before `core_op`.
        dequant_op (Cell or Primitive): The operation that dequantizes the output of `core_op`.
        dequant_scale (Tensor): The scale used by `dequant_op`.
        bias (Tensor): The bias of the layer. Default: None.
        activation (Cell): The activation applied to the dequantized output. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.
    """

    def __init__(self,
                 core_op,
                 weight,
                 quant_op,
                 dequant_op,
                 dequant_scale,
                 bias=None,
                 activation=None):
        super(QuantBlock, self).__init__()
        self.core_op = core_op
        self.weight = weight
        self.quant = quant_op
        self.dequant = dequant_op
        self.dequant_scale = dequant_scale
        self.bias = bias
        self.has_bias = bias is not None
        self.activation = activation
        self.has_act = activation is not None
        self.bias_add = P.BiasAdd()
        self.sub = P.Sub()
        self.weight_offset = Parameter(np.zeros(shape=weight.shape, dtype=np.int8), name='weight_offset')

    def construct(self, x):
        x = self.quant(x)
        if self.has_bias:
            weight = self.sub(self.weight, self.weight_offset)
            x = self.core_op(x, weight)
            x = self.bias_add(x, self.bias)
        else:
            x = self.core_op(x, self.weight)
        x = self.dequant(x, self.dequant_scale)
        x = F.cast(x, mstype.float32)
        if self.has_act:
            x = self.activation(x)
        return x

    def extend_repr(self):
        s = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
        if self.has_bias:
            s += f', bias=shape[{self.bias.shape}]'
        if self.has_act:
            s += f', activation={self.activation}'
        s += f', dequant={self.dequant}'
        return s
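
# Note (illustrative): the deploy-time dataflow in QuantBlock.construct is
#   float input -> quant_op (to int8) -> core_op (int8 Conv/MatMul; when a bias is present the
#   weight first has its weight_offset subtracted) -> bias_add -> dequant_op(dequant_scale)
#   -> cast to float32 -> optional activation.
# These blocks are normally produced by the export/conversion utilities rather than built by
# hand, so the snippet below only shows the shape of the call, not a working recipe:
#
#     >>> block = QuantBlock(core_op, weight, quant_op, dequant_op, dequant_scale,
#     ...                    bias=bias, activation=None)   # all arguments supplied by the converter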


class QuantMindirBlock(Cell):
    """A quant binary block of Conv/Dense, activation layer for exporting a MINDIR model.

    Args:
        core_op (Cell): The operation cell.
        weight (Tensor): The weight of the cell.
        bias (Tensor): The bias of the cell. Default: None.
        activation (Cell): The activation cell applied to the output of the layer. Default: None.
        param_dict (dict): The information of the cell.
    """

    def __init__(self,
                 core_op,
                 weight,
                 bias=None,
                 activation=None,
                 param_dict=None):
        super(QuantMindirBlock, self).__init__()
        self.core_op = core_op
        if activation is not None:
            self.core_op.add_prim_attr("activation_name", activation.__class__.__name__)
        self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"]))
        self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"]))
        if param_dict["output_maxq"] is not None:
            self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"]))
            self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"]))
        self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"]))
        if hasattr(core_op, 'pad_mode'):
            self.core_op.add_prim_attr("pad_mode", core_op.pad_mode)
        self.core_op.add_prim_attr("num_bits", Tensor(8))
        self.core_op.add_prim_attr("narrow_range", Tensor(False))
        if param_dict["input_maxq"] == 'None':
            self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"]))
            self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"]))
        elif param_dict["input_maxq"] is not None:
            self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
            self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
        self.weight = weight
        self.bias = bias
        self.has_bias = bias is not None
        self.activation = activation
        self.has_act = activation is not None
        self.bias_add = P.BiasAdd()
        if isinstance(activation, ReLU):
            self.activation = None
            self.has_act = False

    def construct(self, x):
        if self.has_bias:
            x = self.core_op(x, self.weight)
            x = self.bias_add(x, self.bias)
        else:
            x = self.core_op(x, self.weight)
        return x

    def extend_repr(self):
        s = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
        if self.has_bias:
            s += f', bias=shape[{self.bias.shape}]'
        if self.has_act:
            s += f', activation={self.activation}'
        return s
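
# Note (illustrative): QuantMindirBlock does not run quant/dequant itself; it attaches the
# learned ranges (filter_minq/maxq, optional input/output ranges or mean/std_dev) and the quant
# settings (num_bits, narrow_range, symmetric) to the core primitive as attributes, so a MINDIR
# consumer can rebuild the int8 computation later. At run time construct only performs the plain
# float core op plus the optional bias add. A sketch of the param_dict layout it expects,
# inferred from the attribute writes above (keys only; values come from the conversion utilities):
#
#     param_dict = {
#         "filter_maxq": ..., "filter_minq": ...,
#         "output_maxq": ..., "output_minq": ...,
#         "input_maxq": ..., "input_minq": ...,
#         "mean": ..., "std_dev": ...,
#         "symmetric": ...,
#     }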