You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_quant_ops.py 90 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0(the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http: // www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Operators for quantization."""
  16. from functools import partial
  17. import mindspore.context as context
  18. from ..._checkparam import Validator as validator
  19. from ..._checkparam import Rel
  20. from ..primitive import PrimitiveWithInfer, prim_attr_register
  21. from ...common import dtype as mstype
  22. if context.get_context('device_target') == "Ascend":
  23. import mindspore.ops._op_impl._custom_op
  24. __all__ = ["MinMaxUpdatePerLayer",
  25. "MinMaxUpdatePerChannel",
  26. "FakeLearnedScaleQuantPerLayer",
  27. "FakeLearnedScaleQuantPerLayerGrad",
  28. "FakeLearnedScaleQuantPerLayerGradD",
  29. "FakeLearnedScaleQuantPerLayerGradDReduce",
  30. "FakeLearnedScaleQuantPerChannel",
  31. "FakeLearnedScaleQuantPerChannelGrad",
  32. "FakeLearnedScaleQuantPerChannelGradD",
  33. "FakeLearnedScaleQuantPerChannelGradDReduce",
  34. "FakeQuantWithMinMaxVars",
  35. "FakeQuantWithMinMaxVarsGradient",
  36. "FakeQuantWithMinMaxVarsPerChannel",
  37. "FakeQuantWithMinMaxVarsPerChannelGradient",
  38. "FakeQuantPerLayer",
  39. "FakeQuantPerLayerGrad",
  40. "FakeQuantPerChannel",
  41. "FakeQuantPerChannelGrad",
  42. "BatchNormFold",
  43. "BatchNormFoldGrad",
  44. "CorrectionMul",
  45. "CorrectionMulGrad",
  46. "CorrectionMulGradReduce",
  47. "BatchNormFold2",
  48. "BatchNormFold2Grad",
  49. "BatchNormFoldD",
  50. "BatchNormFoldGradD",
  51. "BatchNormFold2_D",
  52. "BatchNormFold2GradD",
  53. "BatchNormFold2GradReduce",
  54. "IFMR",
  55. "ActsULQ",
  56. "ActsULQInputGrad",
  57. "ActULQClampMinGrad",
  58. "ActULQClampMaxGrad",
  59. "WtsARQ"
  60. ]
  61. class MinMaxUpdatePerLayer(PrimitiveWithInfer):
  62. r"""
  63. Updates min and max per layer.
  64. Args:
  65. ema (bool): Uses EMA algorithm update value min and max. Default: False.
  66. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
  67. Inputs:
  68. - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
  69. - **min** (Tensor) : Value of the min range of the input data x.
  70. - **max** (Tensor) : Value of the max range of the input data x.
  71. Outputs:
  72. - Tensor: Simulates quantize tensor of x.
  73. Examples:
  74. >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  75. >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
  76. >>> max_tensor = Tensor(np.array([6]), mstype.float32)
  77. >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
  78. """
  79. support_quant_bit = [4, 7, 8]
  80. @prim_attr_register
  81. def __init__(self, ema=False, ema_decay=0.999):
  82. """Initialize FakeQuantMinMaxPerLayerUpdate OP"""
  83. if context.get_context('device_target') == "Ascend":
  84. from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
  85. if ema and not ema_decay:
  86. raise ValueError(
  87. f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
  88. self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
  89. self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
  90. self.init_prim_io_names(inputs=['x', 'min', 'max'],
  91. outputs=['min_up', 'max_up'])
  92. def infer_shape(self, x_shape, min_shape, max_shape):
  93. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  94. validator.check("min shape", min_shape, "max shape",
  95. max_shape, Rel.EQ, self.name)
  96. validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
  97. return min_shape, max_shape
  98. def infer_dtype(self, x_type, min_type, max_type):
  99. tuple(map(partial(validator.check_tensor_dtype_valid,
  100. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  101. ("x", "min", "max"),
  102. (x_type, min_type, max_type)))
  103. return min_type, max_type
  104. class MinMaxUpdatePerChannel(PrimitiveWithInfer):
  105. r"""
  106. Updates min and max per channel.
  107. Args:
  108. ema (bool): Uses EMA algorithm update value min and max. Default: False.
  109. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
  110. channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
  111. Inputs:
  112. - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
  113. - **min** (Tensor) : Value of the min range of the input data x.
  114. - **max** (Tensor) : Value of the max range of the input data x.
  115. Outputs:
  116. - Tensor: Simulates quantize tensor of x.
  117. Examples:
  118. >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  119. >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
  120. >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
  121. >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
  122. """
  123. support_quant_bit = [4, 7, 8]
  124. ascend_support_x_rank = [2, 4]
  125. @prim_attr_register
  126. def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
  127. """Initialize FakeQuantPerChannelUpdate OP for Ascend"""
  128. self.is_ascend = context.get_context('device_target') == "Ascend"
  129. if self.is_ascend:
  130. from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
  131. if ema and not ema_decay:
  132. raise ValueError(
  133. f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
  134. self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
  135. self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
  136. if self.is_ascend:
  137. self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
  138. else:
  139. self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
  140. self.init_prim_io_names(
  141. inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
  142. def infer_shape(self, x_shape, min_shape, max_shape):
  143. if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
  144. raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
  145. if not self.is_ascend:
  146. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  147. validator.check("min shape", min_shape, "max shape",
  148. max_shape, Rel.EQ, self.name)
  149. validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
  150. return min_shape, max_shape
  151. def infer_dtype(self, x_type, min_type, max_type):
  152. tuple(map(partial(validator.check_tensor_dtype_valid,
  153. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  154. ("x", "min", "max"),
  155. (x_type, min_type, max_type)))
  156. return min_type, max_type
  157. class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer):
  158. r"""
  159. Simulates the quantize and dequantize operations of the fake learned scale quant per-layer case in training time.
  160. Args:
  161. quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
  162. simulate quantization aware function. After delay step in training time begin simulate the aware
  163. quantize function. Default: 0.
  164. neg_trunc (bool): Whether the quantization algorithm uses nagetive truncation or not. Default: False.
  165. training (bool): Training the network or not. Default: True.
  166. Inputs:
  167. - **input_x** (Tensor) : Input tensor that needs to be quantified.
  168. - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`.
  169. - **quant_max** (Tensor) : Value of the quantization range.
  170. Outputs:
  171. - Tensor: Simulates quantize tensor of `input_x`,with the same type and shape as the `input_x`.
  172. Examples:
  173. >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  174. >>> alpha_tensor = Tensor(np.array([6]), mstype.float32)
  175. >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32)
  176. >>> output_tensor = FakeLearnedScaleQuantPerLayer()(input_tensor, alpha_tensor, quant_max_tensor)
  177. """
  178. @prim_attr_register
  179. def __init__(self,
  180. quant_delay=0,
  181. neg_trunc=False,
  182. training=True):
  183. """init FakeLearnedScaleQuantPerLayer OP"""
  184. if context.get_context('device_target') == "Ascend":
  185. from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer
  186. self.quant_delay = validator.check_non_negative_int(
  187. quant_delay, 'quant_delay', self.name)
  188. self.neg_trunc = validator.check_value_type(
  189. 'neg_trunc', neg_trunc, (bool,), self.name)
  190. self.training = validator.check_value_type(
  191. 'training', training, (bool,), self.name)
  192. self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
  193. outputs=['out'])
  194. def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
  195. validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
  196. validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
  197. validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
  198. return input_x_shape
  199. def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
  200. if context.get_context('device_target') == "GPU":
  201. valid_dtypes = (mstype.float32,)
  202. else:
  203. valid_dtypes = (mstype.float16, mstype.float32)
  204. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  205. ("input_x", "alpha", "quant_max"),
  206. (input_x_type, alpha_type, quant_max_type)))
  207. return input_x_type
  208. class FakeLearnedScaleQuantPerLayerGrad(PrimitiveWithInfer):
  209. r"""
  210. Performs grad of FakeLearnedScaleQuantPerLayer operation.
  211. Examples:
  212. >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerLayerGrad()
  213. >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
  214. >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
  215. >>> _alpha = Tensor(np.array([6]), mindspore.float32)
  216. >>> _quant_max = Tensor(np.array([127]), mindspore.float32)
  217. >>> result = fake_learned_scale_grad(dout, input_x, _min, _max)
  218. """
  219. @prim_attr_register
  220. def __init__(self,
  221. quant_delay=0,
  222. neg_trunc=False):
  223. self.quant_delay = validator.check_non_negative_int(
  224. quant_delay, 'quant_delay', self.name)
  225. self.neg_trunc = validator.check_value_type(
  226. 'neg_trunc', neg_trunc, (bool,), self.name)
  227. self.init_prim_io_names(
  228. inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
  229. def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
  230. validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
  231. validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
  232. validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
  233. return dout_shape, alpha_shape
  234. def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
  235. if context.get_context('device_target') == "GPU":
  236. valid_dtypes = (mstype.float32,)
  237. else:
  238. valid_dtypes = (mstype.float16, mstype.float32)
  239. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  240. ("dout", "x", "alpha", "quant_max"),
  241. (dout_type, x_type, alpha_type, quant_max_type)))
  242. return dout_type, alpha_type
  243. class FakeLearnedScaleQuantPerLayerGradD(PrimitiveWithInfer):
  244. r"""
  245. Performs input grad of FakeLearnedScaleQuantPerLayer operation.
  246. """
  247. @prim_attr_register
  248. def __init__(self,
  249. neg_trunc=False):
  250. from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad
  251. self.neg_trunc = validator.check_value_type(
  252. 'neg_trunc', neg_trunc, (bool,), self.name)
  253. self.init_prim_io_names(
  254. inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
  255. def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
  256. validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
  257. validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
  258. validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
  259. return dout_shape, dout_shape
  260. def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
  261. valid_dtypes = (mstype.float16, mstype.float32)
  262. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  263. ("dout", "x", "alpha", "quant_max"),
  264. (dout_type, x_type, alpha_type, quant_max_type)))
  265. return dout_type, dout_type
  266. class FakeLearnedScaleQuantPerLayerGradDReduce(PrimitiveWithInfer):
  267. r"""
  268. Performs alpha grad reduce of FakeLearnedScaleQuantPerLayer operation.
  269. """
  270. @prim_attr_register
  271. def __init__(self):
  272. from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad_reduce
  273. self.init_prim_io_names(
  274. inputs=['dout_alpha'], outputs=['dalpha'])
  275. def infer_shape(self, dout_alpha_shape):
  276. return (1,)
  277. def infer_dtype(self, dout_alpha_type):
  278. valid_dtypes = (mstype.float16, mstype.float32)
  279. validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name)
  280. return dout_alpha_type
  281. class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
  282. r"""
  283. Simulates the quantize and dequantize operations of the fake learned scale quant per-chnnel case in training time.
  284. Args:
  285. quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
  286. simulate quantization aware function. After delay step in training time begin simulate the aware
  287. quantize function. Default: 0.
  288. neg_trunc (bool): Whether the quantization algorithm uses nagetive truncation or not. Default: False.
  289. training (bool): Training the network or not. Default: True.
  290. channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
  291. Inputs:
  292. - **input_x** (Tensor) : Input tensor that needs to be quantified.
  293. - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`.
  294. - **quant_max** (Tensor) : Value of the quantization range.
  295. Outputs:
  296. - Tensor: Simulates quantize tensor of `input_x`,with the same type and shape as the `input_x`.
  297. Examples:
  298. >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  299. >>> alpha_tensor = Tensor(np.array([6]*3), mstype.float32)
  300. >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32)
  301. >>> output_tensor = FakeLearnedScaleQuantPerChannel()(input_tensor, alpha_tensor, quant_max_tensor)
  302. """
  303. ascend_support_x_rank = [2, 4]
  304. @prim_attr_register
  305. def __init__(self,
  306. quant_delay=0,
  307. neg_trunc=False,
  308. training=True,
  309. channel_axis=1):
  310. """init FakeLearnedScaleQuantPerChannel OP"""
  311. if context.get_context('device_target') == "Ascend":
  312. from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel
  313. self.is_ascend = context.get_context('device_target') == "Ascend"
  314. self.quant_delay = validator.check_non_negative_int(
  315. quant_delay, 'quant_delay', self.name)
  316. self.neg_trunc = validator.check_value_type(
  317. 'neg_trunc', neg_trunc, (bool,), self.name)
  318. self.training = validator.check_value_type(
  319. 'training', training, (bool,), self.name)
  320. if self.is_ascend:
  321. self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
  322. else:
  323. self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
  324. self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
  325. outputs=['out'])
  326. def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
  327. if self.is_ascend and len(input_x_shape) not in self.ascend_support_x_rank:
  328. raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
  329. if not self.is_ascend:
  330. validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
  331. if len(input_x_shape) == 1:
  332. self.channel_axis = 0
  333. validator.check_equal_int(alpha_shape[0], input_x_shape[self.channel_axis], "alpha rank", self.name)
  334. validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
  335. return input_x_shape
  336. def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
  337. if context.get_context('device_target') == "GPU":
  338. valid_dtypes = (mstype.float32,)
  339. else:
  340. valid_dtypes = (mstype.float16, mstype.float32)
  341. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  342. ("input_x", "alpha", "quant_max"),
  343. (input_x_type, alpha_type, quant_max_type)))
  344. return input_x_type
  345. class FakeLearnedScaleQuantPerChannelGrad(PrimitiveWithInfer):
  346. r"""
  347. Performs grad of FakeLearnedScaleQuantPerChannel operation.
  348. Examples:
  349. >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerChannelGrad()
  350. >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
  351. >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
  352. >>> _alpha = Tensor(np.array([6]*2), mindspore.float32)
  353. >>> _quant_max = Tensor(np.array([127]), mindspore.float32)
  354. >>> result = fake_learned_scale_grad(dout, input_x, _min, _max)
  355. """
  356. @prim_attr_register
  357. def __init__(self,
  358. quant_delay=0,
  359. neg_trunc=False,
  360. channel_axis=1):
  361. self.quant_delay = validator.check_non_negative_int(
  362. quant_delay, 'quant_delay', self.name)
  363. self.neg_trunc = validator.check_value_type(
  364. 'neg_trunc', neg_trunc, (bool,), self.name)
  365. self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
  366. self.init_prim_io_names(
  367. inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
  368. def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
  369. validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
  370. return dout_shape, alpha_shape
  371. def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
  372. if context.get_context('device_target') == "GPU":
  373. valid_dtypes = (mstype.float32,)
  374. else:
  375. valid_dtypes = (mstype.float16, mstype.float32)
  376. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  377. ("dout", "x", "alpha", "quant_max"),
  378. (dout_type, x_type, alpha_type, quant_max_type)))
  379. return dout_type, alpha_type
  380. class FakeLearnedScaleQuantPerChannelGradD(PrimitiveWithInfer):
  381. r"""
  382. Performs input grad of FakeLearnedScaleQuantPerChannel operation.
  383. """
  384. @prim_attr_register
  385. def __init__(self,
  386. neg_trunc=False,
  387. channel_axis=1):
  388. from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad
  389. self.neg_trunc = validator.check_value_type(
  390. 'neg_trunc', neg_trunc, (bool,), self.name)
  391. self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
  392. self.init_prim_io_names(
  393. inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
  394. def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
  395. validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
  396. validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
  397. validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
  398. return dout_shape, dout_shape
  399. def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
  400. valid_dtypes = (mstype.float16, mstype.float32)
  401. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  402. ("dout", "x", "alpha", "quant_max"),
  403. (dout_type, x_type, alpha_type, quant_max_type)))
  404. return dout_type, dout_type
  405. class FakeLearnedScaleQuantPerChannelGradDReduce(PrimitiveWithInfer):
  406. r"""
  407. Performs alpha grad reduce of FakeLearnedScaleQuantPerChannel operation.
  408. """
  409. @prim_attr_register
  410. def __init__(self, channel_axis=1):
  411. from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad_reduce
  412. self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
  413. self.init_prim_io_names(
  414. inputs=['dout_alpha'], outputs=['dalpha'])
  415. def infer_shape(self, dout_alpha_shape):
  416. return (dout_alpha_shape[self.channel_axis],)
  417. def infer_dtype(self, dout_alpha_type):
  418. valid_dtypes = (mstype.float16, mstype.float32)
  419. validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name)
  420. return dout_alpha_type
  421. class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
  422. r"""
  423. Fake-quantize the input by min and max.
  424. Args:
  425. num_bits (int): Quantization bitwidth; between 2 and 16. Default: 8.
  426. narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
  427. if True, the quantization range is [0, 2^num_bits-1]. Otherwise, the quantization
  428. range is [1, 2^num_bits-1]. Default: False.
  429. Inputs:
  430. - **x** (Tensor) - Float32 tensor representing the shape of the output tensor.
  431. - **min** (Tensor) - Value of the min range of the input data x.
  432. - **max** (Tensor) - Value of the max range of the input data x.
  433. Outputs:
  434. - Tensor, the data type and shape of output tensor is the same as input x.
  435. Examples:
  436. >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  437. >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
  438. >>> max_tensor = Tensor(np.array([6]), mstype.float32)
  439. >>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
  440. ... input_tensor, min_tensor, max_tensor)
  441. >>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32
  442. """
  443. @prim_attr_register
  444. def __init__(self,
  445. num_bits=8,
  446. narrow_range=False):
  447. self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
  448. self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
  449. self.narrow_range = validator.check_value_type(
  450. 'narrow_range', narrow_range, (bool,), self.name)
  451. def check_broadcast(self, min_shape, input_shape):
  452. shape_val = 1
  453. for shape in input_shape:
  454. shape_val = shape_val * shape
  455. if min_shape[0] > 1 and min_shape[0] != shape_val:
  456. raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
  457. def infer_shape(self, x_shape, min_shape, max_shape):
  458. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  459. validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
  460. validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
  461. self.check_broadcast(min_shape, x_shape)
  462. return x_shape
  463. def infer_dtype(self, x_type, min_type, max_type):
  464. tuple(map(partial(validator.check_tensor_dtype_valid,
  465. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  466. ("x", "min", "max"),
  467. (x_type, min_type, max_type)))
  468. return x_type
class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantWithMinMaxVars operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVars.
        - **x** (Tensor) - Float32 tensor representing the shape of the output tensor.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8,narrow_range=False)
        ... (gradients, input_tensor, min_tensor, max_tensor)
        >>> x_gradient shape: (3, 16, 5, 5) data type: mstype.float32
        >>> min_gradient shape: (1,) data type: mstype.float32
        >>> max_gradient shape: (1,) data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        # num_bits must be a positive int within [2, 16].
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def check_broadcast(self, min_shape, input_shape):
        # min/max may hold a single value (broadcast) or one value per element of x.
        shape_val = 1
        for shape in input_shape:
            shape_val = shape_val * shape
        if min_shape[0] > 1 and min_shape[0] != shape_val:
            raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        # dout must match x; min/max must be identical 1-D shapes broadcastable to x.
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        self.check_broadcast(min_shape, x_shape)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        # Every input must be a float16/float32 tensor.
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ('dout', "x", "min", "max"),
                  (dout_type, x_type, min_type, max_type)))
        return x_type, min_type, max_type
class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
    r"""
    Fake-quantizes the input, one of shape: [d], [b, d], [b, h, w, d], by per-channel min and max.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **x** (Tensor) - Float32 tensor representing the shape of the output tensor.
        - **min** (Tensor) - Per-channel min values of the input data x; one value per last-axis channel.
        - **max** (Tensor) - Per-channel max values of the input data x; one value per last-axis channel.

    Outputs:
        - Tensor, the data type and shape of output tensor is the same as input x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> output_tensor = FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False)(
        ...                 input_tensor, min_tensor, max_tensor)
        >>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        # num_bits must be a positive int within [2, 16].
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, x_shape, min_shape, max_shape):
        # min/max must be identical 1-D shapes with one entry per channel of
        # x's LAST axis (per-channel quantization over the last dimension).
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        # Every input must be a float16/float32 tensor.
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return x_type
class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantWithMinMaxVarsPerChannel operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVars.
        - **x** (Tensor) - Float32 tensor representing the shape of the output tensor.
        - **min** (Tensor) - Per-channel min values of the input data x.
        - **max** (Tensor) - Per-channel max values of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient(
        ...                                          num_bits=8, narrow_range=False)(
        ...                                          gradients, input_tensor, min_tensor, max_tensor)
        >>> x_gradient shape: (3, 16, 3, 4) data type: mstype.float32
        >>> min_gradient shape: (4,) data type: mstype.float32
        >>> max_gradient shape: (4,) data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        # num_bits must be a positive int within [2, 16].
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        # dout must match x; min/max must be identical 1-D shapes with one
        # entry per channel of x's last axis.
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        # Every input must be a float16/float32 tensor.
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("dout", "x", "min", "max"),
                  (dout_type, x_type, min_type, max_type)))
        return x_type, min_type, max_type
class FakeQuantPerLayer(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations in training time.

    Args:
        num_bits (int) : Number bits for quantization aware. Default: 8.
        ema (bool): Uses EMA algorithm update value min and max. Default: False.
        ema_decay (float) : EMA algorithm decay parameter, within [0, 1]. Default: 0.999.
        quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
            simulate quantization aware function. After delay step in training time begin simulate the aware
            quantize function. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        training (bool): Training the network or not. Default: True.

    Inputs:
        - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
        - **min** (Tensor) : Value of the min range of the input data x.
        - **max** (Tensor) : Value of the max range of the input data x.

    Outputs:
        - Tensor: Simulates quantize tensor of x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
    """
    # Bit widths accepted by this primitive.
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 training=True):
        """Initialize FakeQuantPerLayer OP"""
        if context.get_context('device_target') == "Ascend":
            # Side-effect import: registers the Ascend custom-op implementation.
            from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not support.")
        if ema and not ema_decay:
            # EMA updating is meaningless without a (non-zero) decay factor.
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type('training', training, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        # min/max must be matching 1-D shapes; the output keeps x's shape.
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        # The GPU kernel only supports float32; other targets also allow float16.
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return x_type
  685. class FakeQuantPerLayerGrad(PrimitiveWithInfer):
  686. r"""
  687. Performs grad of FakeQuantPerLayerGrad operation.
  688. Examples:
  689. >>> fake_min_max_grad = FakeQuantPerLayerGrad()
  690. >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
  691. >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
  692. >>> _min = Tensor(np.array([-4]), mindspore.float32)
  693. >>> _max = Tensor(np.array([2]), mindspore.float32)
  694. >>> result = fake_min_max_grad(dout, input_x, _min, _max)
  695. """
  696. support_quant_bit = [4, 7, 8]
  697. @prim_attr_register
  698. def __init__(self,
  699. num_bits=8,
  700. quant_delay=0,
  701. symmetric=False,
  702. narrow_range=False):
  703. if context.get_context('device_target') == "Ascend":
  704. from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
  705. if num_bits not in self.support_quant_bit:
  706. raise ValueError(
  707. f"For '{self.name}' attr \'num_bits\' is not support.")
  708. self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
  709. self.quant_delay = validator.check_value_type(
  710. 'quant_delay', quant_delay, (int,), self.name)
  711. self.symmetric = validator.check_value_type(
  712. 'symmetric', symmetric, (bool,), self.name)
  713. self.narrow_range = validator.check_value_type(
  714. 'narrow_range', narrow_range, (bool,), self.name)
  715. self.init_prim_io_names(
  716. inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
  717. def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
  718. validator.check("dout shape", dout_shape, "x shape",
  719. x_shape, Rel.EQ, self.name)
  720. validator.check("min shape", min_shape, "max shape",
  721. max_shape, Rel.EQ, self.name)
  722. validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
  723. return dout_shape
  724. def infer_dtype(self, dout_type, x_type, min_type, max_type):
  725. if context.get_context('device_target') == "GPU":
  726. valid_dtypes = (mstype.float32,)
  727. else:
  728. valid_dtypes = (mstype.float16, mstype.float32)
  729. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  730. ("dout", "x", "min", "max"),
  731. (dout_type, x_type, min_type, max_type)))
  732. return dout_type
class FakeQuantPerChannel(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations in training time base on per channel.

    Args:
        num_bits (int) : Number bits to quantilization. Default: 8.
        ema (bool): Uses EMA algorithm update tensor min and tensor max. Default: False.
        ema_decay (float) : EMA algorithm decay parameter, within [0, 1]. Default: 0.999.
        quant_delay (int): Quantilization delay parameter. Before delay step in training time not
            update the weight data to simulate quantize operation. After delay step in training time
            begin simulate the quantize operation. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        training (bool): Training the network or not. Default: True.
        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

    Inputs:
        - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
        - **min** (Tensor) : Per-channel min values of the input data.
        - **max** (Tensor) : Per-channel max values of the input data.

    Outputs:
        - Tensor, has the same type as input.

    Examples:
        >>> fake_quant = FakeQuantPerChannel()
        >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
        >>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
        >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
        >>> result = fake_quant(input_x, _min, _max)
    """
    # Bit widths accepted by this primitive.
    support_quant_bit = [4, 7, 8]
    # The Ascend kernel only accepts 2-D or 4-D inputs.
    ascend_support_x_rank = [2, 4]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 training=True,
                 channel_axis=1):
        """Initialize FakeQuantPerChannel OP"""
        self.is_ascend = context.get_context('device_target') == "Ascend"
        if self.is_ascend:
            # Side-effect import: registers the Ascend custom-op implementation.
            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' Attr \'num_bits\' is not support.")
        if ema and not ema_decay:
            # EMA updating is meaningless without a (non-zero) decay factor.
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
        if self.is_ascend:
            # The Ascend kernel only implements channel_axis 0 or 1.
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        if len(x_shape) == 1:
            # NOTE: mutates the primitive's channel_axis so that 1-D inputs
            # are always quantized along axis 0.
            self.channel_axis = 0
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
        validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        # The GPU kernel only supports float32; other targets also allow float16.
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return x_type
  817. class FakeQuantPerChannelGrad(PrimitiveWithInfer):
  818. r"""
  819. Performs grad of FakeQuantPerChannelGrad operation.
  820. Examples:
  821. >>> fqmmpc_grad = FakeQuantPerChannelGrad()
  822. >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
  823. >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
  824. >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
  825. >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
  826. >>> result = fqmmpc_grad(dout, input_x, _min, _max)
  827. """
  828. support_quant_bit = [4, 7, 8]
  829. @prim_attr_register
  830. def __init__(self,
  831. num_bits=8,
  832. quant_delay=0,
  833. symmetric=False,
  834. narrow_range=False,
  835. channel_axis=1):
  836. """Initialize FakeQuantPerChannelGrad Fill"""
  837. if context.get_context('device_target') == "Ascend":
  838. from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
  839. if num_bits not in self.support_quant_bit:
  840. raise ValueError(
  841. f"For '{self.name}' attr \'num_bits\' is not support.")
  842. self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
  843. self.quant_delay = validator.check_value_type(
  844. 'quant_delay', quant_delay, (int,), self.name)
  845. self.symmetric = validator.check_value_type(
  846. 'symmetric', symmetric, (bool,), self.name)
  847. self.narrow_range = validator.check_value_type(
  848. 'narrow_range', narrow_range, (bool,), self.name)
  849. self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
  850. self.init_prim_io_names(
  851. inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
  852. def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
  853. validator.check("dout shape", dout_shape, "x shape", x_shape)
  854. validator.check("min shape", min_shape, "max shape", max_shape)
  855. return dout_shape
  856. def infer_dtype(self, dout_type, x_type, min_type, max_type):
  857. if context.get_context('device_target') == "GPU":
  858. valid_dtypes = (mstype.float32,)
  859. else:
  860. valid_dtypes = (mstype.float16, mstype.float32)
  861. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  862. ("dout", "x", "min", "max"),
  863. (dout_type, x_type, min_type, max_type)))
  864. return dout_type
  865. class BatchNormFold(PrimitiveWithInfer):
  866. """
  867. Batch Normalization folded.
  868. Args:
  869. momentum (float): Momentum value must be [0, 1]. Default: 0.9.
  870. epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
  871. float32 else 1e-3. Default: 1e-5.
  872. is_training (bool): In training mode set True, else set False. Default: True.
  873. freeze_bn (int): Delay in steps at which computation switches from regular batch
  874. norm to frozen mean and std. Default: 0.
  875. Inputs:
  876. - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
  877. - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
  878. - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
  879. - **global_step** (Tensor) - Tensor to record current global step.
  880. Outputs:
  881. Tuple of 4 Tensor, the normalized input and the updated parameters.
  882. - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  883. - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
  884. - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  885. - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
  886. Examples:
  887. >>> batch_norm_fold = P.BatchNormFold()
  888. >>> input_x = Tensor(np.array([1, 2, -1, -2, -2, 1]).reshape(2, 3), mindspore.float32)
  889. >>> mean = Tensor(np.array([0.5, -1, 1,]), mindspore.float32)
  890. >>> variance = Tensor(np.array([0.36, 0.4, 0.49]), mindspore.float32)
  891. >>> global_step = Tensor(np.arange(6), mindspore.int32)
  892. >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
  893. """
  894. channel_axis = 1
  895. @prim_attr_register
  896. def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
  897. """Initialize batch norm fold layer"""
  898. self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
  899. self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
  900. self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
  901. self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
  902. self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
  903. outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])
  904. def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
  905. validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
  906. validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
  907. validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
  908. return mean_shape, mean_shape, mean_shape, mean_shape
  909. def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
  910. validator.check("input type", x_type, "mean type", mean_type)
  911. validator.check("input type", x_type, "variance type", variance_type)
  912. args = {"x": x_type, "mean": mean_type, "variance": variance_type}
  913. validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
  914. validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
  915. return x_type, x_type, x_type, x_type
class BatchNormFoldGrad(PrimitiveWithInfer):
    r"""
    Performs grad of BatchNormFold operation.

    Examples:
        >>> batch_norm_fold_grad = ops.BatchNormFoldGrad()
        >>> d_batch_mean = Tensor(np.random.randint(-2., 2., (1, 2, 2, 3)), mindspore.float32)
        >>> d_batch_std = Tensor(np.random.randn(1, 2, 2, 3), mindspore.float32)
        >>> input_x = Tensor(np.random.randint(0, 256, (4, 1, 4, 6)), mindspore.float32)
        >>> batch_mean = Tensor(np.random.randint(-8., 8., (1, 2, 2, 3)), mindspore.float32)
        >>> batch_std = Tensor(np.random.randint(0, 12, (1, 2, 2, 3)), mindspore.float32)
        >>> global_step = Tensor([2], mindspore.int32)
        >>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
    """
    # Channel dimension of the (N, C, ...) input.
    channel_axis = 1

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize BatchNormGrad layer"""
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
                    global_step_shape):
        # All per-channel statistics (and their gradients) must share one
        # shape whose leading size matches x's channel dimension.
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_std shape", batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
                        "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
                    global_step_type):
        # x and every statistic must share a float16/float32 dtype;
        # global_step must be int32. dx inherits x's dtype.
        args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
                "batch_mean": batch_mean_type, "batch_std": batch_std_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return x_type
  957. class CorrectionMul(PrimitiveWithInfer):
  958. """
  959. Scales the weights with a correction factor to the long term statistics
  960. prior to quantization. This ensures that there is no jitter in the quantized weights
  961. due to batch to batch variation.
  962. Inputs:
  963. - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
  964. - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
  965. - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
  966. Outputs:
  967. - **out** (Tensor) - Tensor has the same shape as x.
  968. Examples:
  969. >>> correction_mul = ops.CorrectionMul()
  970. >>> input_x = Tensor(np.random.randint(-8, 12, (3, 4)), mindspore.float32)
  971. >>> batch_std = Tensor(np.array([1.5, 3, 2]), mindspore.float32)
  972. >>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
  973. >>> out = correction_mul(input_x, batch_std, running_std)
  974. """
  975. @prim_attr_register
  976. def __init__(self, channel_axis=0):
  977. """Initialize correction mul layer"""
  978. if context.get_context('device_target') == "Ascend":
  979. from mindspore.ops._op_impl._custom_op import correction_mul
  980. self.channel_axis = channel_axis
  981. self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
  982. outputs=['out'])
  983. def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
  984. validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
  985. validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
  986. Rel.EQ, self.name)
  987. return x_shape
  988. def infer_dtype(self, x_type, batch_std_type, running_std_type):
  989. args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
  990. validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
  991. return x_type
class CorrectionMulGrad(PrimitiveWithInfer):
    r"""
    Performs grad of CorrectionMul operation.

    Examples:
        >>> correction_mul_grad = ops.CorrectionMulGrad()
        >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
        >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
        >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
        >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
        >>> result = correction_mul_grad(dout, input_x, gamma, running_std)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """Initialize correction mul layer"""
        if context.get_context('device_target') == "Ascend":
            # Side-effect import: registers the Ascend custom-op implementation.
            from mindspore.ops._op_impl._custom_op import correction_mul_grad
        # NOTE(review): channel_axis is stored without validation here.
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
                                outputs=['dx', 'mul_dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
        # dout must match x; gamma and running_std must each have one value
        # per channel of dout.
        validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
        validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check("running_std_shape[0]", running_std_shape[0],
                        "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
        # On Ascend the second output (mul_dx) keeps x's full shape —
        # presumably reduced to gamma's shape later by CorrectionMulGradReduce;
        # on other targets it is already gamma-shaped. TODO confirm.
        if context.get_context('device_target') == "Ascend":
            return x_shape, x_shape
        return x_shape, gamma_shape

    def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
        # All inputs must share a float16/float32 dtype.
        args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        if context.get_context('device_target') == "Ascend":
            return x_type, x_type
        return x_type, gamma_type
  1026. class CorrectionMulGradReduce(PrimitiveWithInfer):
  1027. r"""
  1028. Performs grad reduce of CorrectionMul operation.
  1029. Examples:
  1030. >>> correction_mul_grad_rd = ops.CorrectionMulGradReduce()
  1031. >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
  1032. >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
  1033. >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
  1034. >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
  1035. >>> result = correction_mul_grad_rd(dout, input_x, gamma, running_std)
  1036. """
  1037. @prim_attr_register
  1038. def __init__(self, channel_axis=0):
  1039. """Initialize correction mul reduce layer"""
  1040. if context.get_context('device_target') == "Ascend":
  1041. from mindspore.ops._op_impl._custom_op import correction_mul_grad
  1042. self.channel_axis = channel_axis
  1043. self.init_prim_io_names(inputs=['mul_dx'],
  1044. outputs=['d_gamma'])
  1045. def infer_shape(self, mul_dx_shape):
  1046. return [mul_dx_shape[self.channel_axis]]
  1047. def infer_dtype(self, mul_dx_type):
  1048. return mul_dx_type
  1049. class BatchNormFold2(PrimitiveWithInfer):
  1050. """
  1051. Scales the bias with a correction factor to the long term statistics
  1052. prior to quantization. This ensures that there is no jitter in the quantized bias
  1053. due to batch to batch variation.
  1054. Inputs:
  1055. - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
  1056. - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
  1057. - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
  1058. - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
  1059. - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  1060. - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
  1061. - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  1062. - **global_step** (Tensor) - Tensor to record current global step.
  1063. Outputs:
  1064. - **y** (Tensor) - Tensor has the same shape as x.
  1065. Examples:
  1066. >>> batch_norm_fold2 = ops.BatchNormFold2()
  1067. >>> input_x = Tensor(np.random.randint(-6, 6, (4, 3)), mindspore.float32)
  1068. >>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
  1069. >>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
  1070. >>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
  1071. >>> batch_mean = Tensor(np.array([0, 0.05, 0.2]), mindspore.float32)
  1072. >>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
  1073. >>> running_mean = Tensor(np.array([-0.1, 0, -0.1]), mindspore.float32)
  1074. >>> global_step = Tensor(np.random.randint(1, 8, (8, )), mindspore.int32)
  1075. >>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
  1076. >>> running_std, running_mean, global_step)
  1077. """
  1078. channel_axis = 1
  1079. @prim_attr_register
  1080. def __init__(self, freeze_bn=0):
  1081. """Initialize conv2d fold layer"""
  1082. self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
  1083. self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
  1084. 'running_std', 'running_mean', 'global_step'],
  1085. outputs=['y'])
  1086. def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
  1087. running_mean_shape, global_step_shape):
  1088. validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
  1089. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
  1090. validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
  1091. validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
  1092. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
  1093. validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
  1094. Rel.EQ, self.name)
  1095. validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
  1096. return x_shape
  1097. def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
  1098. running_mean_type, global_step_type):
  1099. args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
  1100. "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
  1101. validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
  1102. validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
  1103. return x_type
class BatchNormFold2Grad(PrimitiveWithInfer):
    r"""
    Performs grad of CorrectionAddGrad operation.

    Examples:
        >>> bnf2_grad = ops.BatchNormFold2Grad()
        >>> input_x = Tensor(np.arange(3*3*12*12).reshape(6, 3, 6, 12), mindspore.float32)
        >>> dout = Tensor(np.random.randint(-32, 32, (6, 3, 6, 12)), mindspore.float32)
        >>> gamma = Tensor(np.random.randint(-4, 4, (3, 1, 1, 2)), mindspore.float32)
        >>> batch_std = Tensor(np.random.randint(0, 8, (3, 1, 1, 2)), mindspore.float32)
        >>> batch_mean = Tensor(np.random.randint(-6, 6, (3, 1, 1, 2)), mindspore.float32)
        >>> running_std = Tensor(np.linspace(0, 2, 6).reshape(3, 1, 1, 2), mindspore.float32)
        >>> running_mean = Tensor(np.random.randint(-3, 3, (3, 1, 1, 2)), mindspore.float32)
        >>> global_step = Tensor(np.array([-2]), mindspore.int32)
        >>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
    """
    # Channel dimension index of the incoming gradient (NCHW convention in this file).
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """Initialize MulFold layer"""
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma',
                                        'batch_std', 'batch_mean',
                                        'running_std', 'running_mean', 'global_step'],
                                outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape,
                    batch_std_shape, batch_mean_shape,
                    running_std_shape, running_mean_shape, global_step_shape):
        """All per-channel statistics must share one shape matching dout's channel dim."""
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        # Outputs: d_batch_std, d_batch_mean, d_beta, d_gamma (per-channel), dx (x-shaped).
        return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

    def infer_dtype(self, dout_type, x_type, gamma_type,
                    batch_std_type, batch_mean_type,
                    running_std_type, running_mean_type, global_step_type):
        """All float inputs must share one float16/float32 dtype; global_step must be int32."""
        # The pairwise checks below omit the Rel and prim-name arguments, so
        # validator.check's defaults apply.
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type)
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type)
        validator.check("batch_std type", batch_std_type,
                        "running_mean type", running_mean_type)
        validator.check("batch_std_type", batch_std_type,
                        "dout type", dout_type)
        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
                "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
  1157. class BatchNormFoldD(PrimitiveWithInfer):
  1158. """Performs grad of _BatchNormFold operation."""
  1159. @prim_attr_register
  1160. def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
  1161. """Initialize _BatchNormFold layer"""
  1162. from mindspore.ops._op_impl._custom_op import batchnorm_fold
  1163. self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
  1164. self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
  1165. self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
  1166. self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
  1167. self.data_format = "NCHW"
  1168. self.init_prim_io_names(inputs=['x', 'x_sum', 'x_square_sum', 'mean', 'variance'],
  1169. outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std',
  1170. 'mean_updated', 'variance_updated'])
  1171. def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
  1172. validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
  1173. validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
  1174. return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape
  1175. def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
  1176. validator.check("input type", x_type, "mean type", mean_type)
  1177. validator.check("input type", x_type, "variance type", variance_type)
  1178. args = {"x": x_type, "mean": mean_type, "variance": variance_type}
  1179. validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
  1180. return x_type, x_type, x_type, x_type, x_type, x_type, x_type
class BatchNormFoldGradD(PrimitiveWithInfer):
    """Performs grad of _BatchNormFoldGrad operation."""

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize _BatchNormFoldGrad layer"""
        # Importing the module registers the Ascend custom-op implementation.
        from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape):
        """dx matches x; per-channel gradients/statistics must all share one shape."""
        # These checks omit the Rel and prim-name arguments, so validator.check's
        # defaults apply.
        validator.check("d_batch_mean shape", d_batch_mean_shape, "d_batch_std shape", d_batch_std_shape)
        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_mean shape", batch_mean_shape)
        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_std shape", batch_std_shape)
        # Channel count must match x's channel dimension (NCHW axis 1).
        validator.check("x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[1])
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type):
        """dx keeps x's dtype; all inputs must share it and be float16/float32."""
        validator.check("input type", x_type, "d_batch_mean type", d_batch_mean_type)
        validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
        validator.check("input type", x_type, "batch_mean type", batch_mean_type)
        validator.check("input type", x_type, "batch_std type", batch_std_type)
        validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name)
        return x_type
  1205. class BatchNormFold2_D(PrimitiveWithInfer):
  1206. """
  1207. Scales the bias with a correction factor to the long term statistics
  1208. prior to quantization. This ensures that there is no jitter in the quantized bias
  1209. due to batch to batch variation.
  1210. Inputs:
  1211. - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
  1212. - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
  1213. - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
  1214. - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
  1215. - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  1216. - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
  1217. - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  1218. - **global_step** (Tensor) - Tensor to record current global step.
  1219. Outputs:
  1220. - **y** (Tensor) - Tensor has the same shape as x.
  1221. """
  1222. channel_axis = 1
  1223. @prim_attr_register
  1224. def __init__(self, freeze_bn=0):
  1225. """Initialize conv2d fold layer"""
  1226. from mindspore.ops._op_impl._custom_op import batchnorm_fold2
  1227. self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
  1228. outputs=['y'])
  1229. def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
  1230. validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
  1231. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
  1232. validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
  1233. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
  1234. validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
  1235. Rel.EQ, self.name)
  1236. return x_shape
  1237. def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
  1238. args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
  1239. "beta": beta_type, "gamma": gamma_type, "x": x_type}
  1240. validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
  1241. return x_type
class BatchNormFold2GradD(PrimitiveWithInfer):
    """Performs grad of CorrectionAddGrad operation."""
    # Channel dimension index of the incoming gradient (NCHW convention in this file).
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=False):
        """Initialize MulFold layer"""
        # Importing the module registers the Ascend custom-op implementation.
        from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(
            inputs=['dout', 'dout_reduce', 'dout_x_reduce', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
            outputs=['d_batch_std', 'd_batch_mean', 'd_gamma', 'dx'])

    def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
                    batch_mean_shape, running_std_shape):
        """Per-channel statistics must share one shape matching dout's channel dim."""
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        # Outputs: d_batch_std, d_batch_mean, d_gamma (per-channel), dx (dout-shaped).
        return gamma_shape, gamma_shape, gamma_shape, dout_shape

    def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
                    batch_mean_type, running_std_type):
        """All float inputs must share one float16/float32 dtype."""
        # The pairwise checks below omit the Rel and prim-name arguments, so
        # validator.check's defaults apply.
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type)
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type)
        validator.check("batch_std_type", batch_std_type,
                        "dout type", dout_type)
        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
                "running_std": running_std_type, "dout": dout_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        return gamma_type, gamma_type, gamma_type, gamma_type
  1275. class BatchNormFold2GradReduce(PrimitiveWithInfer):
  1276. """Performs grad of CorrectionAddGrad operation."""
  1277. channel_axis = 1
  1278. @prim_attr_register
  1279. def __init__(self, freeze_bn=False):
  1280. """Initialize MulFold layer"""
  1281. from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad_reduce
  1282. self.freeze_bn = freeze_bn
  1283. self.init_prim_io_names(inputs=['dout', 'x'],
  1284. outputs=['dout_reduce', 'dout_x_reduce'])
  1285. def infer_shape(self, dout_shape, x_shape):
  1286. validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
  1287. return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)
  1288. def infer_dtype(self, dout_type, x_type):
  1289. validator.check("dout type", dout_type, "x type", x_type)
  1290. return dout_type, dout_type
class ActsULQ(PrimitiveWithInfer):
    """
    The ActsULQ(Activation universal learnable quantization).

    Args:
        fixed_min (bool): whether fix clamp min to zero.
        num_bits (int): The bits num used for quantize.

    Inputs:
        - **x** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
        - **clamp_min** (Tensor) - A Tensor of clamp min with the same type as x.
        - **clamp_max** (Tensor) - A Tensor of clamp max with the same type as x.

    Outputs:
        - **y** (Tensor) - A tensor of fake quant of feature map with the same type as `w`.
        - **clamp_min** (Tensor) - A tensor of boolean masks if data in feature map >= clamp_min.
        - **clamp_max** (Tensor) - A tensor of boolean masks if data in feature map <= clamp_max.
        - **x_clamped_loss** (Tensor) - A tensor of clamped loss.

    Examples:
        >>> data_type = np.float32
        >>> x= np.random.uniform(-10, 10, (32, 120)).astype(data_type)
        >>> clamp_max = 0.7 * np.max(x)
        >>> clamp_min = 0.7 * np.min(x)
        >>> clamp_max = np.array([clamp_max], dtype=data_type)
        >>> clamp_min = np.array([clamp_min], dtype=data_type)
        >>> acts_ulq = Q.ActsULQ(fixed_min=True, num_bits=8)
        >>> quant_x, clamp_min_mask, clamp_max_mask, x_clamped_loss = acts_ulq(Tensor(x), Tensor( clamp_min),
        Tensor(clamp_max))
    """

    @prim_attr_register
    def __init__(self, fixed_min=False, num_bits=8):
        validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
        validator.check_value_type("num_bits", num_bits, [int], self.name)
        # Only 8-bit quantization is supported by this primitive.
        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)

    def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
        """infer shape of primitive"""
        # clamp_min/clamp_max must have x's rank with every dimension equal to 1
        # (a scalar bound broadcastable over the feature map).
        validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name)
        validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name)
        x_shape_len = len(x_shape)
        for i in range(x_shape_len):
            validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name)
            validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name)
        # Outputs: y, clamp_min_mask, clamp_max_mask, x_clamped_loss — all x-shaped.
        return x_shape, x_shape, x_shape, x_shape

    def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype):
        """infer dtype of primitive"""
        valid_types = [mstype.float32, mstype.float16]
        validator.check_tensor_dtype_valid("x", x_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("clamp_min", clamp_min_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("clamp_max", clamp_max_dtype, valid_types, self.name)
        # The two mask outputs are boolean.
        return x_dtype, mstype.bool_, mstype.bool_, x_dtype
  1338. class ActsULQInputGrad(PrimitiveWithInfer):
  1339. """
  1340. The ActsULQInputGrad(grad of ActsULQ).
  1341. Inputs:
  1342. - **y_grad** (Tensor) - A Tensor of grad. With float16 or float32 data type.
  1343. Outputs:
  1344. - **x_grad** (Tensor) - A tensor of data grad with the same type as `y_grad`.
  1345. """
  1346. @prim_attr_register
  1347. def __init__(self):
  1348. pass
  1349. def infer_shape(self, y_grad_shape, clamp_min_mask_shape, clamp_max_mask_shape):
  1350. return y_grad_shape
  1351. def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type):
  1352. valid_types = [mstype.float32, mstype.float16]
  1353. validator.check_tensor_dtype_valid("y_grad", y_grad_type, valid_types, self.name)
  1354. return y_grad_type
  1355. class ActULQClampMinGrad(PrimitiveWithInfer):
  1356. """
  1357. The ActULQClampMinGrad(Activation Universal Linear Quantization on Clamp Minimum Gradient)
  1358. Inputs:
  1359. - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
  1360. - **clamp_min_mask** - A tensor of mask, only support int8 type.
  1361. - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
  1362. Outputs:
  1363. - **clamp_min_grad** - A tensor of clamp minimum gradient, with the same type as "y_grad".
  1364. The length of tensor is 1.
  1365. Examples:
  1366. >>> data_type = np.float32
  1367. >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
  1368. >>> clamp_min_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
  1369. >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
  1370. >>> act_ulq_clamp_min_grad = Q.ActULQClampMinGrad()
  1371. >>> clamp_min_grad = act_ulq_clamp_min_grad(Tensor(y_grad), Tensor(clamp_min_mask, mindspore.bool_),
  1372. Tensor(x_clamped_loss))
  1373. """
  1374. @prim_attr_register
  1375. def __init__(self):
  1376. pass
  1377. def infer_shape(self, input_x, input_y, input_z):
  1378. input_x_len = len(input_x)
  1379. output_shape = []
  1380. for _ in range(input_x_len):
  1381. output_shape.append(1)
  1382. return tuple(output_shape)
  1383. def infer_dtype(self, input_x, input_y, input_z):
  1384. return mstype.float32
  1385. class ActULQClampMaxGrad(PrimitiveWithInfer):
  1386. """
  1387. The ActULQClampMaxGrad(Activation Universal Linear Quantization on Clamp Maximum Gradient)
  1388. Inputs:
  1389. - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
  1390. - **clamp_max_mask** - A tensor of mask, only support int8 type.
  1391. - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
  1392. Outputs:
  1393. - **clamp_max_grad** - A tensor of clamp maximum gradient, with the same type as "y_grad".
  1394. The length of tensor is 1.
  1395. Examples:
  1396. >>> data_type = np.float32
  1397. >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
  1398. >>> clamp_max_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
  1399. >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
  1400. >>> act_ulq_clamp_max_grad = Q.ActULQClampMaxGrad()
  1401. >>> clamp_max_grad = act_ulq_clamp_max_grad(Tensor(y_grad), Tensor(clamp_max_mask, mindspore.bool_),
  1402. Tensor(x_clamped_loss))
  1403. """
  1404. @prim_attr_register
  1405. def __init__(self):
  1406. pass
  1407. def infer_shape(self, input_x, input_y, input_z):
  1408. input_x_len = len(input_x)
  1409. output_shape = []
  1410. for _ in range(input_x_len):
  1411. output_shape.append(1)
  1412. return tuple(output_shape)
  1413. def infer_dtype(self, input_x, input_y, input_z):
  1414. return mstype.float32
class WtsARQ(PrimitiveWithInfer):
    """
    The WtsARQ(Weights Adaptive Range Quantization).

    Args:
        num_bits (int): The bits num used for quantize.
        offset_flag (bool): Whether use offset for quantize.

    Inputs:
        - **w** (Tensor) - A Tensor of weights. With float16 or float32 data type.
        - **w_min** (Tensor) - A Tensor of weight minimums, same rank as `w`.
        - **w_max** (Tensor) - A Tensor of weight maximums, same rank as `w`.

    Outputs:
        - **y** (Tensor) - A tensor of fakequant weights, has the same type and shape as `w`.

    .. note::
        NOTE(review): earlier documentation for this op described an `axes` argument and
        separate `scale`/`offset` outputs; the implementation below accepts only
        `num_bits`/`offset_flag` and infers a single `w`-shaped output — confirm the
        intended contract against the Ascend op registration.

    Examples:
        >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
        >>> w_min = Tensor(np.zeros((1, 1, 1, 1)).astype(np.float32))
        >>> w_max = Tensor(np.ones((1, 1, 1, 1)).astype(np.float32))
        >>> wts_arq = Q.WtsARQ(num_bits=8, offset_flag=False)
        >>> y = wts_arq(data, w_min, w_max)
    """

    @prim_attr_register
    def __init__(self, num_bits, offset_flag):
        validator.check_value_type("num_bits", num_bits, [int], self.name)
        # Only 8-bit quantization is supported by this primitive.
        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
        validator.check_value_type("offset_flag", offset_flag, [bool], self.name)

    def infer_shape(self, w_shape, w_min_shape, w_max_shape):
        # w_min/w_max must have the same rank as w (per-dimension sizes are not checked here).
        validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name)
        validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name)
        return w_shape

    def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
        valid_types = [mstype.float32, mstype.float16]
        validator.check_tensor_dtype_valid("w", w_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("w_min", w_min_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("w_max", w_max_dtype, valid_types, self.name)
        return w_dtype
class IFMR(PrimitiveWithInfer):
    """
    The IFMR(Input Feature Map Reconstruction).

    Args:
        min_percentile (float): Min init percentile. Default: 0.999999.
        max_percentile (float): Max init percentile. Default: 0.999999.
        search_range Union[list(float), tuple(float)]: Range of searching. Default: [0.7, 1.3].
        search_step (float): Step size of searching. Default: 0.01.
        with_offset (bool): Whether using offset. Default: True.

    Inputs:
        - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
        - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`.
          With float16 or float32 data type.
        - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`.
          With float16 or float32 data type.
        - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type.

    Outputs:
        - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32.
        - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32.

    Examples:
        >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
        >>> data_min = Tensor([0.1], mindspore.float32)
        >>> data_max = Tensor([0.5], mindspore.float32)
        >>> cumsum = Tensor(np.random.rand(4).astype(np.int32))
        >>> ifmr = Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
        ...               search_step=1.0, with_offset=False)
        >>> output = ifmr(data, data_min, data_max, cumsum)
        >>> print(output)
        (Tensor(shape=[1], dtype=Float32, value= [7.87401572e-03]),
        Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01,
                 with_offset=True):
        validator.check_value_type("min_percentile", min_percentile, [float], self.name)
        validator.check_value_type("max_percentile", max_percentile, [float], self.name)
        validator.check_value_type("search_range", search_range, [list, tuple], self.name)
        # Both search-range endpoints must be positive, and the range must be non-empty.
        for item in search_range:
            validator.check_positive_float(item, "item of search_range", self.name)
        validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
        validator.check_value_type("search_step", search_step, [float], self.name)
        validator.check_value_type("offset_flag", with_offset, [bool], self.name)

    def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape):
        """data_min/data_max must be single-element 1-D tensors; cumsum must be 1-D."""
        validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name)
        validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name)
        validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name)
        validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name)
        validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name)
        # Outputs: scale and offset, each a single-element tensor.
        return (1,), (1,)

    def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
        # Apply the same float16/float32 dtype check to data, data_min and data_max;
        # tuple() forces the lazy map to run.
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("input_value", "input_min", "input_max"),
                  (data_dtype, data_min_dtype, data_max_dtype)))
        validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
        # Both outputs are float32 regardless of input dtype.
        return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)