You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

_quant_ops.py 74 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0(the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http: // www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Operators for quantization."""
  16. from functools import partial
  17. import mindspore.context as context
  18. from ..._checkparam import Validator as validator
  19. from ..._checkparam import Rel
  20. from ..primitive import PrimitiveWithInfer, prim_attr_register
  21. from ...common import dtype as mstype
  22. __all__ = ["MinMaxUpdatePerLayer",
  23. "MinMaxUpdatePerChannel",
  24. "FakeQuantWithMinMaxVars",
  25. "FakeQuantWithMinMaxVarsGradient",
  26. "FakeQuantWithMinMaxVarsPerChannel",
  27. "FakeQuantWithMinMaxVarsPerChannelGradient",
  28. "FakeQuantPerLayer",
  29. "FakeQuantPerLayerGrad",
  30. "FakeQuantPerChannel",
  31. "FakeQuantPerChannelGrad",
  32. "BatchNormFold",
  33. "BatchNormFoldGrad",
  34. "CorrectionMul",
  35. "CorrectionMulGrad",
  36. "CorrectionMulGradReduce",
  37. "BatchNormFold2",
  38. "BatchNormFold2Grad",
  39. "BatchNormFoldD",
  40. "BatchNormFoldGradD",
  41. "BatchNormFold2_D",
  42. "BatchNormFold2GradD",
  43. "BatchNormFold2GradReduce",
  44. "IFMR",
  45. "ActsULQ",
  46. "ActsULQInputGrad",
  47. "ActULQClampMinGrad",
  48. "ActULQClampMaxGrad",
  49. "WtsARQ"
  50. ]
  51. class MinMaxUpdatePerLayer(PrimitiveWithInfer):
  52. r"""
  53. Updates min and max per layer.
  54. Args:
  55. ema (bool): Uses EMA algorithm update value min and max. Default: False.
  56. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
  57. Inputs:
  58. - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
  59. - **min** (Tensor) : Value of the min range of the input data x.
  60. - **max** (Tensor) : Value of the max range of the input data x.
  61. Outputs:
  62. - Tensor: Simulates quantize tensor of x.
  63. Examples:
  64. >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  65. >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
  66. >>> max_tensor = Tensor(np.array([6]), mstype.float32)
  67. >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
  68. """
  69. support_quant_bit = [4, 7, 8]
  70. @prim_attr_register
  71. def __init__(self, ema=False, ema_decay=0.999):
  72. """Initialize FakeQuantMinMaxPerLayerUpdate OP"""
  73. if context.get_context('device_target') == "Ascend":
  74. from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
  75. if ema and not ema_decay:
  76. raise ValueError(
  77. f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
  78. self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
  79. self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
  80. self.init_prim_io_names(inputs=['x', 'min', 'max'],
  81. outputs=['min_up', 'max_up'])
  82. def infer_shape(self, x_shape, min_shape, max_shape):
  83. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  84. validator.check("min shape", min_shape, "max shape",
  85. max_shape, Rel.EQ, self.name)
  86. validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
  87. return min_shape, max_shape
  88. def infer_dtype(self, x_type, min_type, max_type):
  89. tuple(map(partial(validator.check_tensor_dtype_valid,
  90. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  91. ("x", "min", "max"),
  92. (x_type, min_type, max_type)))
  93. return min_type, max_type
  94. class MinMaxUpdatePerChannel(PrimitiveWithInfer):
  95. r"""
  96. Updates min and max per channel.
  97. Args:
  98. ema (bool): Uses EMA algorithm update value min and max. Default: False.
  99. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
  100. channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
  101. Inputs:
  102. - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
  103. - **min** (Tensor) : Value of the min range of the input data x.
  104. - **max** (Tensor) : Value of the max range of the input data x.
  105. Outputs:
  106. - Tensor: Simulates quantize tensor of x.
  107. Examples:
  108. >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  109. >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
  110. >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
  111. >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
  112. """
  113. support_quant_bit = [4, 7, 8]
  114. ascend_support_x_rank = [2, 4]
  115. @prim_attr_register
  116. def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
  117. """Initialize FakeQuantPerChannelUpdate OP for Ascend"""
  118. self.is_ascend = context.get_context('device_target') == "Ascend"
  119. if self.is_ascend:
  120. from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
  121. if ema and not ema_decay:
  122. raise ValueError(
  123. f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
  124. self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
  125. self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
  126. if self.is_ascend:
  127. self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
  128. else:
  129. self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
  130. self.init_prim_io_names(
  131. inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
  132. def infer_shape(self, x_shape, min_shape, max_shape):
  133. if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
  134. raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
  135. if not self.is_ascend:
  136. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  137. validator.check("min shape", min_shape, "max shape",
  138. max_shape, Rel.EQ, self.name)
  139. validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
  140. return min_shape, max_shape
  141. def infer_dtype(self, x_type, min_type, max_type):
  142. tuple(map(partial(validator.check_tensor_dtype_valid,
  143. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  144. ("x", "min", "max"),
  145. (x_type, min_type, max_type)))
  146. return min_type, max_type
  147. class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
  148. r"""
  149. Fake-quantize the input by min and max.
  150. Args:
  151. num_bits (int): Quantization bitwidth; between 2 and 16. Default: 8.
  152. narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
  153. if True, the quantization range is [0, 2^num_bits-1]. Otherwise, the quantization
  154. range is [1, 2^num_bits-1]. Default: False.
  155. Inputs:
  156. - **x** (Tensor) - Float32 tensor representing the shape of the output tensor.
  157. - **min** (Tensor) - Value of the min range of the input data x.
  158. - **max** (Tensor) - Value of the max range of the input data x.
  159. Outputs:
  160. - Tensor, the data type and shape of output tensor is the same as input x.
  161. Examples:
  162. >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  163. >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
  164. >>> max_tensor = Tensor(np.array([6]), mstype.float32)
  165. >>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
  166. ... input_tensor, min_tensor, max_tensor)
  167. >>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32
  168. """
  169. @prim_attr_register
  170. def __init__(self,
  171. num_bits=8,
  172. narrow_range=False):
  173. self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
  174. self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
  175. self.narrow_range = validator.check_value_type(
  176. 'narrow_range', narrow_range, (bool,), self.name)
  177. def check_broadcast(self, min_shape, input_shape):
  178. shape_val = 1
  179. for shape in input_shape:
  180. shape_val = shape_val * shape
  181. if min_shape[0] > 1 and min_shape[0] != shape_val:
  182. raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
  183. def infer_shape(self, x_shape, min_shape, max_shape):
  184. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  185. validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
  186. validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
  187. self.check_broadcast(min_shape, x_shape)
  188. return x_shape
  189. def infer_dtype(self, x_type, min_type, max_type):
  190. tuple(map(partial(validator.check_tensor_dtype_valid,
  191. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  192. ("x", "min", "max"),
  193. (x_type, min_type, max_type)))
  194. return x_type
  195. class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
  196. r"""
  197. Performs grad of FakeQuantWithMinMaxVars operation.
  198. Args:
  199. num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
  200. narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
  201. if True, the quantization range is [0, 2^num_bits-1]. Otherwise, the quantization
  202. range is [1, 2^num_bits-1]. Default: False.
  203. Inputs:
  204. - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVars.
  205. - **x** (Tensor) - Float32 tensor representing the shape of the output tensor.
  206. - **min** (Tensor) - Value of the min range of the input data x.
  207. - **max** (Tensor) - Value of the max range of the input data x.
  208. Outputs:
  209. - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and date type as input x.
  210. - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and date type as input min.
  211. - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and date type as input max.
  212. Examples:
  213. >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  214. >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  215. >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
  216. >>> max_tensor = Tensor(np.array([6]), mstype.float32)
  217. >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8,narrow_range=False)
  218. ... (gradients, input_tensor, min_tensor, max_tensor)
  219. >>> x_gradient shape: (3, 16, 5, 5) data type: mstype.float32
  220. >>> min_gradient shape: (1,) data type: mstype.float32
  221. >>> max_gradient shape: (1,) data type: mstype.float32
  222. """
  223. @prim_attr_register
  224. def __init__(self,
  225. num_bits=8,
  226. narrow_range=False):
  227. self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
  228. self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
  229. self.narrow_range = validator.check_value_type(
  230. 'narrow_range', narrow_range, (bool,), self.name)
  231. def check_broadcast(self, min_shape, input_shape):
  232. shape_val = 1
  233. for shape in input_shape:
  234. shape_val = shape_val * shape
  235. if min_shape[0] > 1 and min_shape[0] != shape_val:
  236. raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
  237. def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
  238. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  239. validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
  240. validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
  241. validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
  242. self.check_broadcast(min_shape, x_shape)
  243. return x_shape, min_shape, max_shape
  244. def infer_dtype(self, dout_type, x_type, min_type, max_type):
  245. tuple(map(partial(validator.check_tensor_dtype_valid,
  246. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  247. ('dout', "x", "min", "max"),
  248. (dout_type, x_type, min_type, max_type)))
  249. return x_type, min_type, max_type
  250. class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
  251. r"""
  252. Fake-quantize the input and one of shape: [d], [b, d], [b, h, w, d] by per-channel min and max
  253. Args:
  254. num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
  255. narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
  256. if True, the quantization range is [0, 2^num_bits-1]. Otherwise, the quantization
  257. range is [1, 2^num_bits-1]. Default: False.
  258. Inputs:
  259. - **x** (Tensor) - Float32 tensor representing the shape of the output tensor.
  260. - **min** (Tensor) - Value of the min range of the input data x.
  261. - **max** (Tensor) - Value of the max range of the input data x.
  262. Outputs:
  263. - Tensor, the data type and shape of output tensor is the same as input x.
  264. Examples:
  265. >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
  266. >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
  267. >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
  268. >>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
  269. ... input_tensor, min_tensor, max_tensor)
  270. >>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32
  271. """
  272. @prim_attr_register
  273. def __init__(self,
  274. num_bits=8,
  275. narrow_range=False):
  276. self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
  277. self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
  278. self.narrow_range = validator.check_value_type(
  279. 'narrow_range', narrow_range, (bool,), self.name)
  280. def infer_shape(self, x_shape, min_shape, max_shape):
  281. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  282. validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
  283. validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
  284. validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
  285. return x_shape
  286. def infer_dtype(self, x_type, min_type, max_type):
  287. tuple(map(partial(validator.check_tensor_dtype_valid,
  288. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  289. ("x", "min", "max"),
  290. (x_type, min_type, max_type)))
  291. return x_type
  292. class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
  293. r"""
  294. Performs grad of FakeQuantWithMinMaxVars operation.
  295. Args:
  296. num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
  297. narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
  298. if True, the quantization range is [0, 2^num_bits-1]. Otherwise, the quantization
  299. range is [1, 2^num_bits-1]. Default: False.
  300. Inputs:
  301. - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVars.
  302. - **x** (Tensor) - Float32 tensor representing the shape of the output tensor.
  303. - **min** (Tensor) - Value of the min range of the input data x.
  304. - **max** (Tensor) - Value of the max range of the input data x.
  305. Outputs:
  306. - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and date type as input x.
  307. - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and date type as input min.
  308. - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and date type as input max.
  309. Examples:
  310. >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
  311. >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
  312. >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
  313. >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
  314. >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient(
  315. ... num_bits=8, narrow_range=False)(
  316. ... gradients, input_tensor, min_tensor, max_tensor)
  317. >>> x_gradient shape: (3, 16, 3, 4) data type: mstype.float32
  318. >>> min_gradient shape: (4,) data type: mstype.float32
  319. >>> max_gradient shape: (4,) data type: mstype.float32
  320. """
  321. @prim_attr_register
  322. def __init__(self,
  323. num_bits=8,
  324. narrow_range=False):
  325. self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
  326. self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
  327. self.narrow_range = validator.check_value_type(
  328. 'narrow_range', narrow_range, (bool,), self.name)
  329. def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
  330. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  331. validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
  332. validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
  333. validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
  334. validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
  335. return x_shape, min_shape, max_shape
  336. def infer_dtype(self, dout_type, x_type, min_type, max_type):
  337. tuple(map(partial(validator.check_tensor_dtype_valid,
  338. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  339. ("dout", "x", "min", "max"),
  340. (dout_type, x_type, min_type, max_type)))
  341. return x_type, min_type, max_type
  342. class FakeQuantPerLayer(PrimitiveWithInfer):
  343. r"""
  344. Simulates the quantize and dequantize operations in training time.
  345. Args:
  346. num_bits (int) : Number bits for quantization aware. Default: 8.
  347. ema (bool): Uses EMA algorithm update value min and max. Default: False.
  348. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
  349. quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
  350. simulate quantization aware funcion. After delay step in training time begin simulate the aware
  351. quantize funcion. Default: 0.
  352. symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
  353. narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
  354. training (bool): Training the network or not. Default: True.
  355. Inputs:
  356. - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
  357. - **min** (Tensor) : Value of the min range of the input data x.
  358. - **max** (Tensor) : Value of the max range of the input data x.
  359. Outputs:
  360. - Tensor: Simulates quantize tensor of x.
  361. Examples:
  362. >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
  363. >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
  364. >>> max_tensor = Tensor(np.array([6]), mstype.float32)
  365. >>> output_tensor = FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
  366. """
  367. support_quant_bit = [4, 7, 8]
  368. @prim_attr_register
  369. def __init__(self,
  370. num_bits=8,
  371. ema=False,
  372. ema_decay=0.999,
  373. quant_delay=0,
  374. symmetric=False,
  375. narrow_range=False,
  376. training=True):
  377. """Initialize FakeQuantPerLayer OP"""
  378. if context.get_context('device_target') == "Ascend":
  379. from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
  380. if num_bits not in self.support_quant_bit:
  381. raise ValueError(
  382. f"For '{self.name}' attr \'num_bits\' is not support.")
  383. if ema and not ema_decay:
  384. raise ValueError(
  385. f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
  386. self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
  387. self.symmetric = validator.check_value_type(
  388. 'symmetric', symmetric, (bool,), self.name)
  389. self.narrow_range = validator.check_value_type(
  390. 'narrow_range', narrow_range, (bool,), self.name)
  391. self.training = validator.check_value_type('training', training, (bool,), self.name)
  392. self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
  393. self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
  394. self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
  395. self.init_prim_io_names(inputs=['x', 'min', 'max'],
  396. outputs=['out'])
  397. def infer_shape(self, x_shape, min_shape, max_shape):
  398. validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
  399. validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
  400. validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
  401. return x_shape
  402. def infer_dtype(self, x_type, min_type, max_type):
  403. if context.get_context('device_target') == "GPU":
  404. valid_dtypes = (mstype.float32,)
  405. else:
  406. valid_dtypes = (mstype.float16, mstype.float32)
  407. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  408. ("x", "min", "max"),
  409. (x_type, min_type, max_type)))
  410. return x_type
  411. class FakeQuantPerLayerGrad(PrimitiveWithInfer):
  412. r"""
  413. Performs grad of FakeQuantPerLayerGrad operation.
  414. Examples:
  415. >>> fake_min_max_grad = FakeQuantPerLayerGrad()
  416. >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
  417. >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
  418. >>> _min = Tensor(np.array([-4]), mindspore.float32)
  419. >>> _max = Tensor(np.array([2]), mindspore.float32)
  420. >>> result = fake_min_max_grad(dout, input_x, _min, _max)
  421. """
  422. support_quant_bit = [4, 7, 8]
  423. @prim_attr_register
  424. def __init__(self,
  425. num_bits=8,
  426. quant_delay=0,
  427. symmetric=False,
  428. narrow_range=False):
  429. if context.get_context('device_target') == "Ascend":
  430. from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
  431. if num_bits not in self.support_quant_bit:
  432. raise ValueError(
  433. f"For '{self.name}' attr \'num_bits\' is not support.")
  434. self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
  435. self.quant_delay = validator.check_value_type(
  436. 'quant_delay', quant_delay, (int,), self.name)
  437. self.symmetric = validator.check_value_type(
  438. 'symmetric', symmetric, (bool,), self.name)
  439. self.narrow_range = validator.check_value_type(
  440. 'narrow_range', narrow_range, (bool,), self.name)
  441. self.init_prim_io_names(
  442. inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
  443. def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
  444. validator.check("dout shape", dout_shape, "x shape",
  445. x_shape, Rel.EQ, self.name)
  446. validator.check("min shape", min_shape, "max shape",
  447. max_shape, Rel.EQ, self.name)
  448. validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
  449. return dout_shape
  450. def infer_dtype(self, dout_type, x_type, min_type, max_type):
  451. if context.get_context('device_target') == "GPU":
  452. valid_dtypes = (mstype.float32,)
  453. else:
  454. valid_dtypes = (mstype.float16, mstype.float32)
  455. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  456. ("dout", "x", "min", "max"),
  457. (dout_type, x_type, min_type, max_type)))
  458. return dout_type
  459. class FakeQuantPerChannel(PrimitiveWithInfer):
  460. r"""
  461. Simulates the quantize and dequantize operations in training time base on per channel.
  462. Args:
  463. num_bits (int) : Number bits to quantilization. Default: 8.
  464. ema (bool): Uses EMA algorithm update tensor min and tensor max. Default: False.
  465. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
  466. quant_delay (int): Quantilization delay parameter. Before delay step in training time not
  467. update the weight data to simulate quantize operation. After delay step in training time
  468. begin simulate the quantize operation. Default: 0.
  469. symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
  470. narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
  471. training (bool): Training the network or not. Default: True.
  472. channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
  473. Inputs:
  474. - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
  475. - **min** (int, float) : Value of the min range of the input data.
  476. - **max** (int, float) : Value of the max range of the input data.
  477. Outputs:
  478. - Tensor, has the same type as input.
  479. Examples:
  480. >>> fake_quant = FakeQuantPerChannel()
  481. >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
  482. >>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
  483. >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
  484. >>> result = fake_quant(input_x, _min, _max)
  485. """
  486. support_quant_bit = [4, 7, 8]
  487. ascend_support_x_rank = [2, 4]
    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 training=True,
                 channel_axis=1):
        """Initialize FakeQuantPerChannel OP: validate attributes and register I/O names."""
        self.is_ascend = context.get_context('device_target') == "Ascend"
        if self.is_ascend:
            # Imported for its side effect: registers the Ascend custom (TBE) kernel.
            # The name itself is intentionally unused.
            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' Attr \'num_bits\' is not support.")
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
        # Ascend kernels only support quantizing along axis 0 or 1; other backends
        # accept any non-negative axis.
        if self.is_ascend:
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
    def infer_shape(self, x_shape, min_shape, max_shape):
        """Infer output shape; output has the same shape as `x`.

        Also checks that `min`/`max` are per-channel vectors whose length matches
        the channel dimension of `x`.
        """
        # Ascend custom kernel only handles rank-2 and rank-4 inputs.
        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
            # For a 1-D input the only possible channel axis is 0; note this
            # mutates the primitive's attribute as a side effect of shape inference.
            if len(x_shape) == 1:
                self.channel_axis = 0
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
        validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
        return x_shape
  534. def infer_dtype(self, x_type, min_type, max_type):
  535. if context.get_context('device_target') == "GPU":
  536. valid_dtypes = (mstype.float32,)
  537. else:
  538. valid_dtypes = (mstype.float16, mstype.float32)
  539. tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
  540. ("x", "min", "max"),
  541. (x_type, min_type, max_type)))
  542. return x_type
class FakeQuantPerChannelGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantPerChannelGrad operation.

    Examples:
        >>> fqmmpc_grad = FakeQuantPerChannelGrad()
        >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
        >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
        >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
        >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
        >>> result = fqmmpc_grad(dout, input_x, _min, _max)
    """
    # Bit widths the quantization kernels support.
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 channel_axis=1):
        """Initialize FakeQuantPerChannelGrad: validate attributes and register I/O names."""
        if context.get_context('device_target') == "Ascend":
            # Imported for its side effect: registers the Ascend custom kernel.
            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not support.")
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_value_type(
            'quant_delay', quant_delay, (int,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        """Infer output shape; `dx` has the same shape as `dout`."""
        # NOTE(review): unlike sibling classes, no explicit Rel/prim_name is passed
        # here — this relies on validator.check's default comparison; confirm it is EQ.
        validator.check("dout shape", dout_shape, "x shape", x_shape)
        validator.check("min shape", min_shape, "max shape", max_shape)
        return dout_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        """Validate input dtypes; `dx` has the same dtype as `dout`."""
        # GPU accepts only float32; other targets also accept float16.
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("dout", "x", "min", "max"),
                  (dout_type, x_type, min_type, max_type)))
        return dout_type
class BatchNormFold(PrimitiveWithInfer):
    """
    Batch normalization folded.

    Args:
        momentum (float): Momentum value must be [0, 1]. Default: 0.9.
        epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
            float32 else 1e-3. Default: 1e-5.
        is_training (bool): In training mode set True, else set False. Default: True.
        freeze_bn (int): Delay in steps at which computation switches from regular batch
            norm to frozen mean and std. Default: 0.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        Tuple of 4 Tensor, the normalized input and the updated parameters.

        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.

    Examples:
        >>> batch_norm_fold = P.BatchNormFold()
        >>> input_x = Tensor(np.array([1, 2, -1, -2, -2, 1]).reshape(2, 3), mindspore.float32)
        >>> mean = Tensor(np.array([0.5, -1, 1,]), mindspore.float32)
        >>> variance = Tensor(np.array([0.36, 0.4, 0.49]), mindspore.float32)
        >>> global_step = Tensor(np.arange(6), mindspore.int32)
        >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
    """
    # Channel dimension index of `x` (NC... layout).
    channel_axis = 1

    @prim_attr_register
    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize batch norm fold layer"""
        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
                                outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])

    def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
        """All four outputs are per-channel vectors with the same shape as `mean`."""
        validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return mean_shape, mean_shape, mean_shape, mean_shape

    def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
        """Inputs must share a float16/float32 dtype; `global_step` must be int32."""
        validator.check("input type", x_type, "mean type", mean_type)
        validator.check("input type", x_type, "variance type", variance_type)
        args = {"x": x_type, "mean": mean_type, "variance": variance_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return x_type, x_type, x_type, x_type
class BatchNormFoldGrad(PrimitiveWithInfer):
    r"""
    Performs grad of BatchNormFold operation.

    Examples:
        >>> batch_norm_fold_grad = ops.BatchNormFoldGrad()
        >>> d_batch_mean = Tensor(np.random.randint(-2., 2., (1, 2, 2, 3)), mindspore.float32)
        >>> d_batch_std = Tensor(np.random.randn(1, 2, 2, 3), mindspore.float32)
        >>> input_x = Tensor(np.random.randint(0, 256, (4, 1, 4, 6)), mindspore.float32)
        >>> batch_mean = Tensor(np.random.randint(-8., 8., (1, 2, 2, 3)), mindspore.float32)
        >>> batch_std = Tensor(np.random.randint(0, 12, (1, 2, 2, 3)), mindspore.float32)
        >>> global_step = Tensor([2], mindspore.int32)
        >>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
    """
    # Channel dimension index of `x` (NC... layout).
    channel_axis = 1

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize BatchNormGrad layer"""
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
                    global_step_shape):
        """`dx` has the same shape as `x`; all per-channel inputs must match in shape."""
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_std shape", batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
                        "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
                    global_step_type):
        """All float inputs must share a float16/float32 dtype; `global_step` must be int32."""
        args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
                "batch_mean": batch_mean_type, "batch_std": batch_std_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return x_type
class CorrectionMul(PrimitiveWithInfer):
    """
    Scales the weights with a correction factor to the long term statistics
    prior to quantization. This ensures that there is no jitter in the quantized weights
    due to batch to batch variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.

    Outputs:
        - **out** (Tensor) - Tensor has the same shape as x.

    Examples:
        >>> correction_mul = ops.CorrectionMul()
        >>> input_x = Tensor(np.random.randint(-8, 12, (3, 4)), mindspore.float32)
        >>> batch_std = Tensor(np.array([1.5, 3, 2]), mindspore.float32)
        >>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
        >>> out = correction_mul(input_x, batch_std, running_std)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """Initialize correction mul layer"""
        if context.get_context('device_target') == "Ascend":
            # Imported for its side effect: registers the Ascend custom kernel.
            from mindspore.ops._op_impl._custom_op import correction_mul
        # Note: channel_axis is stored without validation here, unlike other ops in this file.
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
                                outputs=['out'])

    def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
        """`out` has the same shape as `x`; std vectors must match x's channel size."""
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
                        Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, batch_std_type, running_std_type):
        """All inputs must share a float16/float32 dtype; output dtype equals `x`'s."""
        args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        return x_type
class CorrectionMulGrad(PrimitiveWithInfer):
    r"""
    Performs grad of CorrectionMul operation.

    Examples:
        >>> correction_mul_grad = ops.CorrectionMulGrad()
        >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
        >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
        >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
        >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
        >>> result = correction_mul_grad(dout, input_x, gamma, running_std)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """Initialize correction mul layer"""
        if context.get_context('device_target') == "Ascend":
            # Imported for its side effect: registers the Ascend custom kernel.
            from mindspore.ops._op_impl._custom_op import correction_mul_grad
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
                                outputs=['dx', 'mul_dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
        """Output shapes are device-dependent: on Ascend `mul_dx` has x's shape
        (reduced later by CorrectionMulGradReduce); elsewhere it has gamma's shape."""
        validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
        validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check("running_std_shape[0]", running_std_shape[0],
                        "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
        if context.get_context('device_target') == "Ascend":
            return x_shape, x_shape
        return x_shape, gamma_shape

    def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
        """All inputs must share a float16/float32 dtype; output dtypes mirror infer_shape."""
        args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        if context.get_context('device_target') == "Ascend":
            return x_type, x_type
        return x_type, gamma_type
  752. class CorrectionMulGradReduce(PrimitiveWithInfer):
  753. r"""
  754. Performs grad reduce of CorrectionMul operation.
  755. Examples:
  756. >>> correction_mul_grad_rd = ops.CorrectionMulGradReduce()
  757. >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
  758. >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
  759. >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
  760. >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
  761. >>> result = correction_mul_grad_rd(dout, input_x, gamma, running_std)
  762. """
  763. @prim_attr_register
  764. def __init__(self, channel_axis=0):
  765. """Initialize correction mul reduce layer"""
  766. if context.get_context('device_target') == "Ascend":
  767. from mindspore.ops._op_impl._custom_op import correction_mul_grad
  768. self.channel_axis = channel_axis
  769. self.init_prim_io_names(inputs=['mul_dx'],
  770. outputs=['d_gamma'])
  771. def infer_shape(self, mul_dx_shape):
  772. return [mul_dx_shape[self.channel_axis]]
  773. def infer_dtype(self, mul_dx_type):
  774. return mul_dx_type
  775. class BatchNormFold2(PrimitiveWithInfer):
  776. """
  777. Scales the bias with a correction factor to the long term statistics
  778. prior to quantization. This ensures that there is no jitter in the quantized bias
  779. due to batch to batch variation.
  780. Inputs:
  781. - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
  782. - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
  783. - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
  784. - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
  785. - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  786. - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
  787. - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  788. - **global_step** (Tensor) - Tensor to record current global step.
  789. Outputs:
  790. - **y** (Tensor) - Tensor has the same shape as x.
  791. Examples:
  792. >>> batch_norm_fold2 = ops.BatchNormFold2()
  793. >>> input_x = Tensor(np.random.randint(-6, 6, (4, 3)), mindspore.float32)
  794. >>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
  795. >>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
  796. >>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
  797. >>> batch_mean = Tensor(np.array([0, 0.05, 0.2]), mindspore.float32)
  798. >>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
  799. >>> running_mean = Tensor(np.array([-0.1, 0, -0.1]), mindspore.float32)
  800. >>> global_step = Tensor(np.random.randint(1, 8, (8, )), mindspore.int32)
  801. >>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
  802. >>> running_std, running_mean, global_step)
  803. """
  804. channel_axis = 1
  805. @prim_attr_register
  806. def __init__(self, freeze_bn=0):
  807. """Initialize conv2d fold layer"""
  808. self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
  809. self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
  810. 'running_std', 'running_mean', 'global_step'],
  811. outputs=['y'])
  812. def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
  813. running_mean_shape, global_step_shape):
  814. validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
  815. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
  816. validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
  817. validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
  818. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
  819. validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
  820. Rel.EQ, self.name)
  821. validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
  822. return x_shape
  823. def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
  824. running_mean_type, global_step_type):
  825. args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
  826. "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
  827. validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
  828. validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
  829. return x_type
class BatchNormFold2Grad(PrimitiveWithInfer):
    r"""
    Performs grad of CorrectionAddGrad operation.

    Examples:
        >>> bnf2_grad = ops.BatchNormFold2Grad()
        >>> input_x = Tensor(np.arange(3*3*12*12).reshape(6, 3, 6, 12), mindspore.float32)
        >>> dout = Tensor(np.random.randint(-32, 32, (6, 3, 6, 12)), mindspore.float32)
        >>> gamma = Tensor(np.random.randint(-4, 4, (3, 1, 1, 2)), mindspore.float32)
        >>> batch_std = Tensor(np.random.randint(0, 8, (3, 1, 1, 2)), mindspore.float32)
        >>> batch_mean = Tensor(np.random.randint(-6, 6, (3, 1, 1, 2)), mindspore.float32)
        >>> running_std = Tensor(np.linspace(0, 2, 6).reshape(3, 1, 1, 2), mindspore.float32)
        >>> running_mean = Tensor(np.random.randint(-3, 3, (3, 1, 1, 2)), mindspore.float32)
        >>> global_step = Tensor(np.array([-2]), mindspore.int32)
        >>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
    """
    # Channel dimension index of `dout` (NC... layout).
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """Initialize MulFold layer"""
        # Stored without validation, unlike BatchNormFold2.__init__.
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma',
                                        'batch_std', 'batch_mean',
                                        'running_std', 'running_mean', 'global_step'],
                                outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape,
                    batch_std_shape, batch_mean_shape,
                    running_std_shape, running_mean_shape, global_step_shape):
        """Per-channel grads take gamma's shape; `dx` takes x's shape."""
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

    def infer_dtype(self, dout_type, x_type, gamma_type,
                    batch_std_type, batch_mean_type,
                    running_std_type, running_mean_type, global_step_type):
        """All float inputs must share a float16/float32 dtype; `global_step` must be int32."""
        # NOTE(review): the pairwise validator.check calls below appear redundant with
        # check_tensors_dtypes_same_and_valid, which already enforces a common dtype.
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type)
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type)
        validator.check("batch_std type", batch_std_type,
                        "running_mean type", running_mean_type)
        validator.check("batch_std_type", batch_std_type,
                        "dout type", dout_type)
        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
                "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
  883. class BatchNormFoldD(PrimitiveWithInfer):
  884. """Performs grad of _BatchNormFold operation."""
  885. @prim_attr_register
  886. def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
  887. """Initialize _BatchNormFold layer"""
  888. from mindspore.ops._op_impl._custom_op import batchnorm_fold
  889. self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
  890. self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
  891. self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
  892. self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
  893. self.data_format = "NCHW"
  894. self.init_prim_io_names(inputs=['x', 'x_sum', 'x_square_sum', 'mean', 'variance'],
  895. outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std',
  896. 'mean_updated', 'variance_updated'])
  897. def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
  898. validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
  899. validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
  900. return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape
  901. def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
  902. validator.check("input type", x_type, "mean type", mean_type)
  903. validator.check("input type", x_type, "variance type", variance_type)
  904. args = {"x": x_type, "mean": mean_type, "variance": variance_type}
  905. validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
  906. return x_type, x_type, x_type, x_type, x_type, x_type, x_type
class BatchNormFoldGradD(PrimitiveWithInfer):
    """Performs grad of _BatchNormFoldGrad operation."""

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize _BatchNormFoldGrad layer"""
        # Imported for its side effect: registers the Ascend custom kernel.
        from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape):
        """`dx` takes x's shape; all per-channel inputs must match in shape.

        NOTE(review): no Rel/prim_name is passed to these checks, unlike sibling
        classes — relies on validator.check's default comparison; confirm it is EQ.
        """
        validator.check("d_batch_mean shape", d_batch_mean_shape, "d_batch_std shape", d_batch_std_shape)
        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_mean shape", batch_mean_shape)
        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_std shape", batch_std_shape)
        validator.check("x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[1])
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type):
        """All inputs must share x's float16/float32 dtype; `dx` keeps x's dtype."""
        validator.check("input type", x_type, "d_batch_mean type", d_batch_mean_type)
        validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
        validator.check("input type", x_type, "batch_mean type", batch_mean_type)
        validator.check("input type", x_type, "batch_std type", batch_std_type)
        validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name)
        return x_type
  931. class BatchNormFold2_D(PrimitiveWithInfer):
  932. """
  933. Scales the bias with a correction factor to the long term statistics
  934. prior to quantization. This ensures that there is no jitter in the quantized bias
  935. due to batch to batch variation.
  936. Inputs:
  937. - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
  938. - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
  939. - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
  940. - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
  941. - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  942. - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
  943. - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  944. - **global_step** (Tensor) - Tensor to record current global step.
  945. Outputs:
  946. - **y** (Tensor) - Tensor has the same shape as x.
  947. """
  948. channel_axis = 1
  949. @prim_attr_register
  950. def __init__(self, freeze_bn=0):
  951. """Initialize conv2d fold layer"""
  952. from mindspore.ops._op_impl._custom_op import batchnorm_fold2
  953. self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
  954. outputs=['y'])
  955. def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
  956. validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
  957. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
  958. validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
  959. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
  960. validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
  961. Rel.EQ, self.name)
  962. return x_shape
  963. def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
  964. args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
  965. "beta": beta_type, "gamma": gamma_type, "x": x_type}
  966. validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
  967. return x_type
class BatchNormFold2GradD(PrimitiveWithInfer):
    """Performs grad of CorrectionAddGrad operation."""
    # Channel dimension index of `dout` (NC... layout).
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=False):
        """Initialize MulFold layer"""
        # Imported for its side effect: registers the Ascend custom kernel.
        from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(
            inputs=['dout', 'dout_reduce', 'dout_x_reduce', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
            outputs=['d_batch_std', 'd_batch_mean', 'd_gamma', 'dx'])

    def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
                    batch_mean_shape, running_std_shape):
        """Per-channel grads take gamma's shape; `dx` takes dout's shape."""
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        return gamma_shape, gamma_shape, gamma_shape, dout_shape

    def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
                    batch_mean_type, running_std_type):
        """All float inputs must share a float16/float32 dtype."""
        # NOTE(review): the pairwise validator.check calls below appear redundant with
        # check_tensors_dtypes_same_and_valid, which already enforces a common dtype.
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type)
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type)
        validator.check("batch_std_type", batch_std_type,
                        "dout type", dout_type)
        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
                "running_std": running_std_type, "dout": dout_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        return gamma_type, gamma_type, gamma_type, gamma_type
  1001. class BatchNormFold2GradReduce(PrimitiveWithInfer):
  1002. """Performs grad of CorrectionAddGrad operation."""
  1003. channel_axis = 1
  1004. @prim_attr_register
  1005. def __init__(self, freeze_bn=False):
  1006. """Initialize MulFold layer"""
  1007. from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad_reduce
  1008. self.freeze_bn = freeze_bn
  1009. self.init_prim_io_names(inputs=['dout', 'x'],
  1010. outputs=['dout_reduce', 'dout_x_reduce'])
  1011. def infer_shape(self, dout_shape, x_shape):
  1012. validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
  1013. return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)
  1014. def infer_dtype(self, dout_type, x_type):
  1015. validator.check("dout type", dout_type, "x type", x_type)
  1016. return dout_type, dout_type
  1017. class ActsULQ(PrimitiveWithInfer):
  1018. """
  1019. The ActsULQ(Activation universal learnable quantization).
  1020. Args:
  1021. fixed_min (bool): whether fix clamp min to zero.
  1022. num_bits (int): The bits num used for quantize.
  1023. Inputs:
  1024. - **x** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
  1025. - **clamp_min** (Tensor) - A Tensor of clamp min with the same type as x.
  1026. - **clamp_max** (Tensor) - A Tensor of clamp max with the same type as x.
  1027. Outputs:
  1028. - **y** (Tensor) - A tensor of fake quant of feature map with the same type as `w`.
  1029. - **clamp_min** (Tensor) - A tensor of boolean masks if data in feature map >= clamp_min.
  1030. - **clamp_max** (Tensor) - A tensor of boolean masks if data in feature map <= clamp_max.
  1031. - **x_clamped_loss** (Tensor) - A tensor of clamped loss.
  1032. Examples:
  1033. >>> data_type = np.float32
  1034. >>> x= np.random.uniform(-10, 10, (32, 120)).astype(data_type)
  1035. >>> clamp_max = 0.7 * np.max(x)
  1036. >>> clamp_min = 0.7 * np.min(x)
  1037. >>> clamp_max = np.array([clamp_max], dtype=data_type)
  1038. >>> clamp_min = np.array([clamp_min], dtype=data_type)
  1039. >>> acts_ulq = Q.ActsULQ(fixed_mini=True, num_bits=8)
  1040. >>> quant_x, clamp_min_mask, clamp_max_mask, x_clamped_loss = acts_ulq(Tensor(x), Tensor( clamp_min),
  1041. Tensor(clamp_max))
  1042. """
  1043. @prim_attr_register
  1044. def __init__(self, fixed_min=False, num_bits=8):
  1045. validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
  1046. validator.check_value_type("num_bits", num_bits, [int], self.name)
  1047. validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
  1048. def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
  1049. """infer shape of primitive"""
  1050. validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name)
  1051. validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name)
  1052. x_shape_len = len(x_shape)
  1053. for i in range(x_shape_len):
  1054. validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name)
  1055. validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name)
  1056. return x_shape, x_shape, x_shape, x_shape
  1057. def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype):
  1058. """infer dtype of primitive"""
  1059. valid_types = [mstype.float32, mstype.float16]
  1060. validator.check_tensor_dtype_valid("x", x_dtype, valid_types, self.name)
  1061. validator.check_tensor_dtype_valid("clamp_min", clamp_min_dtype, valid_types, self.name)
  1062. validator.check_tensor_dtype_valid("clamp_max", clamp_max_dtype, valid_types, self.name)
  1063. return x_dtype, mstype.bool_, mstype.bool_, x_dtype
  1064. class ActsULQInputGrad(PrimitiveWithInfer):
  1065. """
  1066. The ActsULQInputGrad(grad of ActsULQ).
  1067. Inputs:
  1068. - **y_grad** (Tensor) - A Tensor of grad. With float16 or float32 data type.
  1069. Outputs:
  1070. - **x_grad** (Tensor) - A tensor of data grad with the same type as `y_grad`.
  1071. """
  1072. @prim_attr_register
  1073. def __init__(self):
  1074. pass
  1075. def infer_shape(self, y_grad_shape, clamp_min_mask_shape, clamp_max_mask_shape):
  1076. return y_grad_shape
  1077. def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type):
  1078. valid_types = [mstype.float32, mstype.float16]
  1079. validator.check_tensor_dtype_valid("y_grad", y_grad_type, valid_types, self.name)
  1080. return y_grad_type
  1081. class ActULQClampMinGrad(PrimitiveWithInfer):
  1082. """
  1083. The ActULQClampMinGrad(Activation Universal Linear Quantization on Clamp Minimum Gradient)
  1084. Inputs:
  1085. - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
  1086. - **clamp_min_mask** - A tensor of mask, only support int8 type.
  1087. - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
  1088. Outputs:
  1089. - **clamp_min_grad** - A tensor of clamp minimum gradient, with the same type as "y_grad".
  1090. The length of tensor is 1.
  1091. Examples:
  1092. >>> data_type = np.float32
  1093. >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
  1094. >>> clamp_min_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
  1095. >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
  1096. >>> act_ulq_clamp_min_grad = Q.ActULQClampMinGrad()
  1097. >>> clamp_min_grad = act_ulq_clamp_min_grad(Tensor(y_grad), Tensor(clamp_min_mask, mindspore.bool_),
  1098. Tensor(x_clamped_loss))
  1099. """
  1100. @prim_attr_register
  1101. def __init__(self):
  1102. pass
  1103. def infer_shape(self, input_x, input_y, input_z):
  1104. input_x_len = len(input_x)
  1105. output_shape = []
  1106. for _ in range(input_x_len):
  1107. output_shape.append(1)
  1108. return tuple(output_shape)
  1109. def infer_dtype(self, input_x, input_y, input_z):
  1110. return mstype.float32
  1111. class ActULQClampMaxGrad(PrimitiveWithInfer):
  1112. """
  1113. The ActULQClampMaxGrad(Activation Universal Linear Quantization on Clamp Maximum Gradient)
  1114. Inputs:
  1115. - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
  1116. - **clamp_max_mask** - A tensor of mask, only support int8 type.
  1117. - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
  1118. Outputs:
  1119. - **clamp_max_grad** - A tensor of clamp maximum gradient, with the same type as "y_grad".
  1120. The length of tensor is 1.
  1121. Examples:
  1122. >>> data_type = np.float32
  1123. >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
  1124. >>> clamp_max_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
  1125. >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
  1126. >>> act_ulq_clamp_max_grad = Q.ActULQClampMaxGrad()
  1127. >>> clamp_max_grad = act_ulq_clamp_max_grad(Tensor(y_grad), Tensor(clamp_max_mask, mindspore.bool_),
  1128. Tensor(x_clamped_loss))
  1129. """
  1130. @prim_attr_register
  1131. def __init__(self):
  1132. pass
  1133. def infer_shape(self, input_x, input_y, input_z):
  1134. input_x_len = len(input_x)
  1135. output_shape = []
  1136. for _ in range(input_x_len):
  1137. output_shape.append(1)
  1138. return tuple(output_shape)
  1139. def infer_dtype(self, input_x, input_y, input_z):
  1140. return mstype.float32
  1141. class WtsARQ(PrimitiveWithInfer):
  1142. """
  1143. The WtsARQ(Weights Adaptive Range Quantization).
  1144. Args:
  1145. axes (list): Specify channels for ARQ algorithm.
  1146. num_bits (int): The bits num used for quantize.
  1147. offset_flag (bool): Whether use offset for quantize.
  1148. Inputs:
  1149. - **w** (Tensor) - A Tensor of weights. With float16 or float32 data type.
  1150. Outputs:
  1151. - **scale** (Tensor) - A tensor of optimal scale, has the same type as `w`.
  1152. - **offset** (Tensor) - A tensor of optimal offset, has the same type as `w`.
  1153. - If axis is [],
  1154. the shape of scale and offset is :math:`(1, )`.
  1155. - If axis is [0],
  1156. the shape of scale and offset is :math:`(w_1, )`.
  1157. - If axis is [1],
  1158. the shape of scale and offset is :math:`(w_2, )`.
  1159. - **y** (Tensor) - A tensor of fakequant weights, has the same type and shape as `w`.
  1160. Examples:
  1161. >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
  1162. >>> wts_arq = Q.WtsARQ(axes=[0], num_bits=8, offset_flag=False)
  1163. >>> scale, offset, y = wts_arq(data)
  1164. """
  1165. @prim_attr_register
  1166. def __init__(self, num_bits, offset_flag):
  1167. validator.check_value_type("num_bits", num_bits, [int], self.name)
  1168. validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
  1169. validator.check_value_type("offset_flag", offset_flag, [bool], self.name)
  1170. def infer_shape(self, w_shape, w_min_shape, w_max_shape):
  1171. validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name)
  1172. validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name)
  1173. return w_shape
  1174. def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
  1175. valid_types = [mstype.float32, mstype.float16]
  1176. validator.check_tensor_dtype_valid("w", w_dtype, valid_types, self.name)
  1177. validator.check_tensor_dtype_valid("w_min", w_min_dtype, valid_types, self.name)
  1178. validator.check_tensor_dtype_valid("w_max", w_max_dtype, valid_types, self.name)
  1179. return w_dtype
  1180. class IFMR(PrimitiveWithInfer):
  1181. """
  1182. The TFMR(Input Feature Map Reconstruction).
  1183. Args:
  1184. min_percentile (float): Min init percentile. Default: 0.999999.
  1185. max_percentile (float): Max init percentile. Default: 0.999999.
  1186. search_range Union[list(float), tuple(float)]: Range of searching. Default: [0.7, 1.3].
  1187. search_step (float): Step size of searching. Default: 0.01.
  1188. with_offset (bool): Whether using offset. Default: True.
  1189. Inputs:
  1190. - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
  1191. - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`.
  1192. With float16 or float32 data type.
  1193. - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`.
  1194. With float16 or float32 data type.
  1195. - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type.
  1196. Outputs:
  1197. - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32.
  1198. - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32.
  1199. Examples:
  1200. >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
  1201. >>> data_min = Tensor([0.1], mindspore.float32)
  1202. >>> data_max = Tensor([0.5], mindspore.float32)
  1203. >>> cumsum = Tensor(np.random.rand(4).astype(np.int32))
  1204. >>> ifmr = Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
  1205. ... search_step=1.0, with_offset=False)
  1206. >>> output = ifmr(data, data_min, data_max, cumsum)
  1207. >>> print(output)
  1208. (Tensor(shape=[1], dtype=Float32, value= [7.87401572e-03]),
  1209. Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]))
  1210. """
  1211. @prim_attr_register
  1212. def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01,
  1213. with_offset=True):
  1214. validator.check_value_type("min_percentile", min_percentile, [float], self.name)
  1215. validator.check_value_type("max_percentile", max_percentile, [float], self.name)
  1216. validator.check_value_type("search_range", search_range, [list, tuple], self.name)
  1217. for item in search_range:
  1218. validator.check_positive_float(item, "item of search_range", self.name)
  1219. validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
  1220. validator.check_value_type("search_step", search_step, [float], self.name)
  1221. validator.check_value_type("offset_flag", with_offset, [bool], self.name)
  1222. def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape):
  1223. validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name)
  1224. validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name)
  1225. validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name)
  1226. validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name)
  1227. validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name)
  1228. return (1,), (1,)
  1229. def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
  1230. tuple(map(partial(validator.check_tensor_dtype_valid,
  1231. valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
  1232. ("input_value", "input_min", "input_max"),
  1233. (data_dtype, data_min_dtype, data_max_dtype)))
  1234. validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
  1235. return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)