_quant_ops.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Operators for quantization."""
  16. import mindspore.context as context
  17. from ..._checkparam import Validator as validator
  18. from ..._checkparam import Rel
  19. from ..primitive import PrimitiveWithInfer, prim_attr_register
  20. from ...common import dtype as mstype
  21. __all__ = ["MinMaxUpdatePerLayer",
  22. "MinMaxUpdatePerChannel",
  23. "FakeQuantWithMinMaxVars",
  24. "FakeQuantWithMinMaxVarsGradient",
  25. "FakeQuantWithMinMaxVarsPerChannel",
  26. "FakeQuantWithMinMaxVarsPerChannelGradient",
  27. "FakeQuantPerLayer",
  28. "FakeQuantPerLayerGrad",
  29. "FakeQuantPerChannel",
  30. "FakeQuantPerChannelGrad",
  31. "BatchNormFold",
  32. "BatchNormFoldGrad",
  33. "CorrectionMul",
  34. "CorrectionMulGrad",
  35. "CorrectionMulGradReduce",
  36. "BatchNormFold2",
  37. "BatchNormFold2Grad",
  38. "BatchNormFoldD",
  39. "BatchNormFoldGradD",
  40. "BatchNormFold2_D",
  41. "BatchNormFold2GradD",
  42. "BatchNormFold2GradReduce"
  43. ]
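

# The fake-quant ops below simulate an integer quantize-dequantize round trip
# on-device. As a mental model only, that round trip can be sketched in plain
# NumPy. This helper is an illustrative assumption (its name and the nudging
# details are not the actual kernels, which live in
# mindspore.ops._op_impl._custom_op).
def _fake_quant_reference(x, x_min, x_max, num_bits=8, narrow_range=False):
    """A minimal quantize-dequantize sketch of per-layer fake quantization."""
    import numpy as np
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    # Map the float range [x_min, x_max] onto the integer grid, round and
    # clamp, then map back to floats so the quantization error is exposed.
    scale = (x_max - x_min) / (quant_max - quant_min)
    zero_point = quant_min - np.round(x_min / scale)
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale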


class MinMaxUpdatePerLayer(PrimitiveWithInfer):
    r"""
    Updates the min and max values per layer.

    Args:
        ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.

    Inputs:
        - **x** (Tensor) - The input tensor, with float32 data type.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - **min_up** (Tensor) - The updated min value.
        - **max_up** (Tensor) - The updated max value.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> min_up, max_up = MinMaxUpdatePerLayer(ema=True)(input_tensor, min_tensor, max_tensor)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self, ema=False, ema_decay=0.999):
        """Initialize FakeQuantMinMaxPerLayerUpdate OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr 'ema' and 'ema_decay' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                outputs=['min_up', 'max_up'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return min_shape, max_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"max": max_type}, valid_types, self.name)
        return min_type, max_type
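

# For intuition, the update rule the MinMaxUpdate ops implement can be
# sketched as below. This is an illustrative reference under assumptions (the
# helper name and the exact EMA formulation), not the custom-op kernel.
def _min_max_update_reference(x, cur_min, cur_max, ema=False, ema_decay=0.999):
    """A minimal sketch of the running min/max update used for calibration."""
    import numpy as np
    batch_min, batch_max = np.min(x), np.max(x)
    if ema:
        # Exponential moving average: old statistics decay toward the batch ones.
        new_min = cur_min * ema_decay + batch_min * (1 - ema_decay)
        new_max = cur_max * ema_decay + batch_max * (1 - ema_decay)
    else:
        new_min, new_max = batch_min, batch_max
    return new_min, new_max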


class MinMaxUpdatePerChannel(PrimitiveWithInfer):
    r"""
    Updates the min and max values per channel.

    Args:
        ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.
        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

    Inputs:
        - **x** (Tensor) - The input tensor, with float32 data type.
        - **min** (Tensor) - Per-channel values of the min range of the input data x.
        - **max** (Tensor) - Per-channel values of the max range of the input data x.

    Outputs:
        - **min_up** (Tensor) - The updated per-channel min values.
        - **max_up** (Tensor) - The updated per-channel max values.

    Examples:
        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> min_up, max_up = MinMaxUpdatePerChannel(channel_axis=1)(x, min, max)
    """
    support_quant_bit = [4, 7, 8]
    ascend_support_x_rank = [2, 4]

    @prim_attr_register
    def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
        """Initialize FakeQuantPerChannelUpdate OP for Ascend"""
        self.is_ascend = context.get_context('device_target') == "Ascend"
        if self.is_ascend:
            from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr 'ema' and 'ema_decay' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        if self.is_ascend:
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(
            inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return min_shape, max_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same(
            {"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"max": max_type}, valid_types, self.name)
        return min_type, max_type


class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
    r"""
    Fake-quantizes the input by the min and max values.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits - 1]. Otherwise, the quantization
            range is [0, 2^num_bits - 1]. Default: False.

    Inputs:
        - **x** (Tensor) - The input tensor, with float32 data type.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - Tensor, with the same data type and shape as the input x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
        ...     input_tensor, min_tensor, max_tensor)
        >>> # output_tensor shape: (3, 16, 5, 5), data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def check_broadcast(self, min_shape, input_shape):
        shape_val = 1
        for shape in input_shape:
            shape_val = shape_val * shape
        if min_shape[0] > 1 and min_shape[0] != shape_val:
            raise ValueError(f"For '{self.name}', the shape of 'min' cannot broadcast to the shape of 'x'.")

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        self.check_broadcast(min_shape, x_shape)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
        return x_type


class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantWithMinMaxVars operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits - 1]. Otherwise, the quantization
            range is [0, 2^num_bits - 1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient propagated from above the FakeQuantWithMinMaxVars node.
        - **x** (Tensor) - The input tensor, with float32 data type.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8, narrow_range=False)(
        ...     gradients, input_tensor, min_tensor, max_tensor)
        >>> # x_gradient shape: (3, 16, 5, 5), data type: mstype.float32
        >>> # min_gradient shape: (1,), data type: mstype.float32
        >>> # max_gradient shape: (1,), data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def check_broadcast(self, min_shape, input_shape):
        shape_val = 1
        for shape in input_shape:
            shape_val = shape_val * shape
        if min_shape[0] > 1 and min_shape[0] != shape_val:
            raise ValueError(f"For '{self.name}', the shape of 'min' cannot broadcast to the shape of 'x'.")

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        self.check_broadcast(min_shape, x_shape)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
        return x_type, min_type, max_type


class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
    r"""
    Fake-quantizes the input, of shape [d], [b, d] or [b, h, w, d], by per-channel min and max values.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits - 1]. Otherwise, the quantization
            range is [0, 2^num_bits - 1]. Default: False.

    Inputs:
        - **x** (Tensor) - The input tensor, with float32 data type.
        - **min** (Tensor) - Per-channel values of the min range of the input data x.
        - **max** (Tensor) - Per-channel values of the max range of the input data x.

    Outputs:
        - Tensor, with the same data type and shape as the input x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> output_tensor = FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False)(
        ...     input_tensor, min_tensor, max_tensor)
        >>> # output_tensor shape: (3, 16, 3, 4), data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
        return x_type


class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantWithMinMaxVarsPerChannel operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits - 1]. Otherwise, the quantization
            range is [0, 2^num_bits - 1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient propagated from above the FakeQuantWithMinMaxVars node.
        - **x** (Tensor) - The input tensor, with float32 data type.
        - **min** (Tensor) - Per-channel values of the min range of the input data x.
        - **max** (Tensor) - Per-channel values of the max range of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient(
        ...     num_bits=8, narrow_range=False)(
        ...     gradients, input_tensor, min_tensor, max_tensor)
        >>> # x_gradient shape: (3, 16, 3, 4), data type: mstype.float32
        >>> # min_gradient shape: (4,), data type: mstype.float32
        >>> # max_gradient shape: (4,), data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
        return x_type, min_type, max_type


class FakeQuantPerLayer(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations at training time.

    Args:
        num_bits (int): Number of bits for quantization aware training. Default: 8.
        ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantization delay parameter. Simulated quantization is not applied
            until `quant_delay` training steps have passed; after that it is applied at every
            step. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        training (bool): Training the network or not. Default: True.

    Inputs:
        - **x** (Tensor) - The input tensor, with float32 data type.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - Tensor, the fake-quantized tensor of x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 training=True):
        """Initialize FakeQuantPerLayer OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr 'num_bits' is not supported.")
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr 'ema' and 'ema_decay' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type('training', training, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"max": max_type}, valid_types, self.name)
        return x_type


class FakeQuantPerLayerGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantPerLayer operation.

    Examples:
        >>> fake_min_max_grad = FakeQuantPerLayerGrad()
        >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
        >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
        >>> _min = Tensor(np.array([-4]), mindspore.float32)
        >>> _max = Tensor(np.array([2]), mindspore.float32)
        >>> result = fake_min_max_grad(dout, input_x, _min, _max)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False):
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr 'num_bits' is not supported.")

        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_value_type(
            'quant_delay', quant_delay, (int,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check("dout shape", dout_shape, "x shape",
                        x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return dout_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same(
            {"dout": dout_type}, valid_types, self.name)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"max": max_type}, valid_types, self.name)
        return dout_type


class FakeQuantPerChannel(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations at training time on a per-channel basis.

    Args:
        num_bits (int): Number of bits for quantization. Default: 8.
        ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantization delay parameter. Simulated quantization is not applied
            to the weights until `quant_delay` training steps have passed; after that it is
            applied at every step. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        training (bool): Training the network or not. Default: True.
        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

    Inputs:
        - **x** (Tensor) - The input tensor, with float32 data type.
        - **min** (Tensor) - Per-channel values of the min range of the input data.
        - **max** (Tensor) - Per-channel values of the max range of the input data.

    Outputs:
        - Tensor, with the same type as the input.

    Examples:
        >>> fake_quant = FakeQuantPerChannel(channel_axis=1)
        >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
        >>> _min = Tensor(np.array([-2., -1.]), mindspore.float32)
        >>> _max = Tensor(np.array([8., 12.]), mindspore.float32)
        >>> result = fake_quant(input_x, _min, _max)
    """
    support_quant_bit = [4, 7, 8]
    ascend_support_x_rank = [2, 4]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 training=True,
                 channel_axis=1):
        """Initialize FakeQuantPerChannel OP"""
        self.is_ascend = context.get_context('device_target') == "Ascend"
        if self.is_ascend:
            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr 'num_bits' is not supported.")
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr 'ema' and 'ema_decay' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
        if self.is_ascend:
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        if len(x_shape) == 1:
            self.channel_axis = 0
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
        validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"max": max_type}, valid_types, self.name)
        return x_type
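

# Per-channel fake quantization applies the per-layer rule independently to
# every slice along `channel_axis`. A hypothetical NumPy sketch, illustrative
# only; it reuses the `_fake_quant_reference` helper defined above, which is
# itself an assumption rather than the device kernel.
def _fake_quant_per_channel_reference(x, mins, maxs, channel_axis=1, num_bits=8):
    """A minimal sketch of per-channel fake quantization."""
    import numpy as np
    # Move the channel axis to the front, fake-quantize each channel slice
    # with its own (min, max) pair, then restore the original layout.
    moved = np.moveaxis(x, channel_axis, 0)
    quantized = np.stack([_fake_quant_reference(c, mins[i], maxs[i], num_bits)
                          for i, c in enumerate(moved)])
    return np.moveaxis(quantized, 0, channel_axis)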


class FakeQuantPerChannelGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantPerChannel operation.

    Examples:
        >>> fqmmpc_grad = FakeQuantPerChannelGrad()
        >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
        >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
        >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
        >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
        >>> result = fqmmpc_grad(dout, input_x, _min, _max)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 channel_axis=1):
        """Initialize FakeQuantPerChannelGrad OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr 'num_bits' is not supported.")

        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_value_type(
            'quant_delay', quant_delay, (int,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check("dout shape", dout_shape, "x shape", x_shape)
        validator.check("min shape", min_shape, "max shape", max_shape)
        return dout_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same(
            {"dout": dout_type}, valid_types, self.name)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same(
            {"max": max_type}, valid_types, self.name)
        return dout_type


class BatchNormFold(PrimitiveWithInfer):
    """
    Batch normalization folded.

    Args:
        momentum (float): Momentum value; must be in [0, 1]. Default: 0.9.
        epsilon (float): A small float number added to avoid dividing by 0; typically
            1e-5 for float32 and 1e-3 otherwise. Default: 1e-5.
        is_training (bool): In training mode set True, else set False. Default: True.
        freeze_bn (int): Delay in steps at which computation switches from regular batch
            norm to frozen mean and std. Default: 0.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        Tuple of 4 Tensors, the batch statistics and the updated parameters.

        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.

    Examples:
        >>> batch_norm_fold = P.BatchNormFold()
        >>> input_x = Tensor(np.array([1, 2, -1, -2, -2, 1]).reshape(2, 3), mindspore.float32)
        >>> mean = Tensor(np.array([0.5, -1, 1]), mindspore.float32)
        >>> variance = Tensor(np.array([0.36, 0.4, 0.49]), mindspore.float32)
        >>> global_step = Tensor(np.arange(6), mindspore.int32)
        >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize batch norm fold layer"""
        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
                                outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])

    def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
        validator.check("mean shape", mean_shape, "variance shape", variance_shape, Rel.EQ, self.name)
        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return mean_shape, mean_shape, mean_shape, mean_shape

    def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
        validator.check("input type", x_type, "mean type", mean_type)
        validator.check("input type", x_type, "variance type", variance_type)
        args = {"x": x_type, "mean": mean_type, "variance": variance_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
        return x_type, x_type, x_type, x_type
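

# The folded statistics this op produces can be sketched as follows for an
# (N, C) input. A hypothetical reference under assumptions (the helper name,
# and that is_training/freeze_bn handling and running-stat updates are
# omitted); the real kernels do considerably more.
def _batch_norm_fold_reference(x, running_mean, running_var, epsilon=1e-5):
    """A minimal sketch of the statistics returned by BatchNormFold."""
    import numpy as np
    batch_mean = x.mean(axis=0)
    # "std" here is the batch-norm denominator sqrt(var + eps), per channel.
    batch_std = np.sqrt(x.var(axis=0) + epsilon)
    running_std = np.sqrt(np.asarray(running_var) + epsilon)
    return batch_mean, batch_std, np.asarray(running_mean), running_std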


class BatchNormFoldGrad(PrimitiveWithInfer):
    r"""
    Performs grad of BatchNormFold operation.

    Examples:
        >>> batch_norm_fold_grad = P.BatchNormFoldGrad()
        >>> d_batch_mean = Tensor(np.random.randint(-2., 2., (1, 2, 2, 3)), mindspore.float32)
        >>> d_batch_std = Tensor(np.random.randn(1, 2, 2, 3), mindspore.float32)
        >>> input_x = Tensor(np.random.randint(0, 256, (4, 1, 4, 6)), mindspore.float32)
        >>> batch_mean = Tensor(np.random.randint(-8., 8., (1, 2, 2, 3)), mindspore.float32)
        >>> batch_std = Tensor(np.random.randint(0, 12, (1, 2, 2, 3)), mindspore.float32)
        >>> global_step = Tensor([2], mindspore.int32)
        >>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize BatchNormGrad layer"""
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
                    global_step_shape):
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_std shape", batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
                        "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
                    global_step_type):
        args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
                "batch_mean": batch_mean_type, "batch_std": batch_std_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
        return x_type


class CorrectionMul(PrimitiveWithInfer):
    """
    Scales the weights with a correction factor to the long-term statistics
    prior to quantization. This ensures that there is no jitter in the quantized weights
    due to batch-to-batch variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.

    Outputs:
        - **out** (Tensor) - Tensor with the same shape as x.

    Examples:
        >>> correction_mul = P.CorrectionMul()
        >>> input_x = Tensor(np.random.randint(-8, 12, (3, 4)), mindspore.float32)
        >>> batch_std = Tensor(np.array([1.5, 3, 2]), mindspore.float32)
        >>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
        >>> out = correction_mul(input_x, batch_std, running_std)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """Initialize correction mul layer"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import correction_mul
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
                                outputs=['out'])

    def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
                        Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, batch_std_type, running_std_type):
        args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        return x_type
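

# The correction itself is a per-channel rescale by batch_std / running_std.
# A hypothetical NumPy sketch; the helper name and broadcasting scheme are
# assumptions for illustration, not the custom-op kernel.
def _correction_mul_reference(x, batch_std, running_std, channel_axis=0):
    """A minimal sketch of the weight correction applied by CorrectionMul."""
    import numpy as np
    factor = np.asarray(batch_std) / np.asarray(running_std)
    # Reshape the per-channel factor so it broadcasts along channel_axis.
    shape = [1] * x.ndim
    shape[channel_axis] = -1
    return x * factor.reshape(shape)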


class CorrectionMulGrad(PrimitiveWithInfer):
    r"""
    Performs grad of CorrectionMul operation.

    Examples:
        >>> correction_mul_grad = P.CorrectionMulGrad()
        >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
        >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
        >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
        >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
        >>> result = correction_mul_grad(dout, input_x, gamma, running_std)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """Initialize correction mul layer"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import correction_mul_grad
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
                                outputs=['dx', 'mul_dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
        validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
        validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check("running_std_shape[0]", running_std_shape[0],
                        "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
        if context.get_context('device_target') == "Ascend":
            return x_shape, x_shape
        return x_shape, gamma_shape

    def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
        args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        if context.get_context('device_target') == "Ascend":
            return x_type, x_type
        return x_type, gamma_type


class CorrectionMulGradReduce(PrimitiveWithInfer):
    r"""
    Performs grad reduce of CorrectionMul operation.

    Examples:
        >>> correction_mul_grad_rd = P.CorrectionMulGradReduce()
        >>> mul_dx = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
        >>> d_gamma = correction_mul_grad_rd(mul_dx)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """Initialize correction mul reduce layer"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import correction_mul_grad
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['mul_dx'],
                                outputs=['d_gamma'])

    def infer_shape(self, mul_dx_shape):
        return [mul_dx_shape[self.channel_axis]]

    def infer_dtype(self, mul_dx_type):
        return mul_dx_type


class BatchNormFold2(PrimitiveWithInfer):
    """
    Scales the bias with a correction factor to the long-term statistics
    prior to quantization. This ensures that there is no jitter in the quantized bias
    due to batch-to-batch variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
        - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        - **y** (Tensor) - Tensor with the same shape as x.

    Examples:
        >>> batch_norm_fold2 = P.BatchNormFold2()
        >>> input_x = Tensor(np.random.randint(-6, 6, (4, 3)), mindspore.float32)
        >>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
        >>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
        >>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
        >>> batch_mean = Tensor(np.array([0, 0.05, 0.2]), mindspore.float32)
        >>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
        >>> running_mean = Tensor(np.array([-0.1, 0, -0.1]), mindspore.float32)
        >>> global_step = Tensor(np.random.randint(1, 8, (8,)), mindspore.int32)
        >>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
        ...                           running_std, running_mean, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """Initialize conv2d fold layer"""
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
                                        'running_std', 'running_mean', 'global_step'],
                                outputs=['y'])

    def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
                    running_mean_shape, global_step_shape):
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
                        Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return x_shape

    def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
                    running_mean_type, global_step_type):
        args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
                "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
        return x_type


class BatchNormFold2Grad(PrimitiveWithInfer):
    r"""
    Performs grad of BatchNormFold2 operation.

    Examples:
        >>> bnf2_grad = P.BatchNormFold2Grad()
        >>> input_x = Tensor(np.arange(3*3*12*12).reshape(6, 3, 6, 12), mindspore.float32)
        >>> dout = Tensor(np.random.randint(-32, 32, (6, 3, 6, 12)), mindspore.float32)
        >>> gamma = Tensor(np.random.randint(-4, 4, (3, 1, 1, 2)), mindspore.float32)
        >>> batch_std = Tensor(np.random.randint(0, 8, (3, 1, 1, 2)), mindspore.float32)
        >>> batch_mean = Tensor(np.random.randint(-6, 6, (3, 1, 1, 2)), mindspore.float32)
        >>> running_std = Tensor(np.linspace(0, 2, 6).reshape(3, 1, 1, 2), mindspore.float32)
        >>> running_mean = Tensor(np.random.randint(-3, 3, (3, 1, 1, 2)), mindspore.float32)
        >>> global_step = Tensor(np.array([-2]), mindspore.int32)
        >>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """Initialize MulFold layer"""
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma',
                                        'batch_std', 'batch_mean',
                                        'running_std', 'running_mean', 'global_step'],
                                outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape,
                    batch_std_shape, batch_mean_shape,
                    running_std_shape, running_mean_shape, global_step_shape):
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
                        Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

    def infer_dtype(self, dout_type, x_type, gamma_type,
                    batch_std_type, batch_mean_type,
                    running_std_type, running_mean_type, global_step_type):
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type)
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type)
        validator.check("batch_std type", batch_std_type,
                        "running_mean type", running_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "dout type", dout_type)
        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
                "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
        return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type


class BatchNormFoldD(PrimitiveWithInfer):
    """The Ascend implementation of BatchNormFold."""

    @prim_attr_register
    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize _BatchNormFold layer"""
        from mindspore.ops._op_impl._custom_op import batchnorm_fold
        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.data_format = "NCHW"
        self.init_prim_io_names(inputs=['x', 'x_sum', 'x_square_sum', 'mean', 'variance'],
                                outputs=['y', 'batch_mean', 'batch_std', 'running_mean', 'running_std',
                                         'mean_updated', 'variance_updated'])

    def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
        validator.check("mean shape", mean_shape, "variance shape", variance_shape, Rel.EQ, self.name)
        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
        return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape

    def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
        validator.check("input type", x_type, "mean type", mean_type)
        validator.check("input type", x_type, "variance type", variance_type)
        args = {"x": x_type, "mean": mean_type, "variance": variance_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        return x_type, x_type, x_type, x_type, x_type, x_type, x_type


class BatchNormFoldGradD(PrimitiveWithInfer):
    """The Ascend implementation of BatchNormFoldGrad."""

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize _BatchNormFoldGrad layer"""
        from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape):
        validator.check("d_batch_mean shape", d_batch_mean_shape, "d_batch_std shape", d_batch_std_shape)
        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_mean shape", batch_mean_shape)
        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_std shape", batch_std_shape)
        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[1])
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type):
        validator.check("input type", x_type, "d_batch_mean type", d_batch_mean_type)
        validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
        validator.check("input type", x_type, "batch_mean type", batch_mean_type)
        validator.check("input type", x_type, "batch_std type", batch_std_type)
        args = {"input type": x_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        return x_type


class BatchNormFold2_D(PrimitiveWithInfer):
    """
    Scales the bias with a correction factor to the long-term statistics
    prior to quantization. This ensures that there is no jitter in the quantized bias
    due to batch-to-batch variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
        - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.

    Outputs:
        - **y** (Tensor) - Tensor with the same shape as x.
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """Initialize conv2d fold layer"""
        from mindspore.ops._op_impl._custom_op import batchnorm_fold2
        self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
                                outputs=['y'])

    def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
                        Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
        args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
                "beta": beta_type, "gamma": gamma_type, "x": x_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        return x_type


class BatchNormFold2GradD(PrimitiveWithInfer):
    """The Ascend implementation of BatchNormFold2Grad."""
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=False):
        """Initialize MulFold layer"""
        from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(
            inputs=['dout', 'dout_reduce', 'dout_x_reduce', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
            outputs=['d_batch_std', 'd_batch_mean', 'd_gamma', 'dx'])

    def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
                    batch_mean_shape, running_std_shape):
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        return gamma_shape, gamma_shape, gamma_shape, dout_shape

    def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
                    batch_mean_type, running_std_type):
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type)
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type)
        validator.check("batch_std type", batch_std_type,
                        "dout type", dout_type)
        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
                "running_std": running_std_type, "dout": dout_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        return gamma_type, gamma_type, gamma_type, gamma_type


class BatchNormFold2GradReduce(PrimitiveWithInfer):
    """Computes the per-channel reductions of dout used by BatchNormFold2Grad."""
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=False):
        """Initialize MulFold layer"""
        from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad_reduce
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(inputs=['dout', 'x'],
                                outputs=['dout_reduce', 'dout_x_reduce'])

    def infer_shape(self, dout_shape, x_shape):
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)

    def infer_dtype(self, dout_type, x_type):
        validator.check("dout type", dout_type, "x type", x_type)
        return dout_type, dout_type