
_quant_ops.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Operators for quantization."""
  16. from ..._checkparam import Validator as validator
  17. from ..._checkparam import Rel
  18. from ..primitive import PrimitiveWithInfer, prim_attr_register
  19. from ...common import dtype as mstype
  20. __all__ = ["FakeQuantWithMinMax",
  21. "FakeQuantWithMinMaxGrad",
  22. "FakeQuantWithMinMaxPerChannel",
  23. "FakeQuantWithMinMaxPerChannelGrad",
  24. "BatchNormFold",
  25. "BatchNormFoldGrad",
  26. "CorrectionMul",
  27. "CorrectionMulGrad",
  28. "BatchNormFold2",
  29. "BatchNormFold2Grad",
  30. ]


class FakeQuantWithMinMax(PrimitiveWithInfer):
    r"""
    Simulate the quantize and dequantize operations at training time.

    Args:
        num_bits (int): Number of bits used to simulate quantization. Default: 8.
        ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantization delay parameter. The simulated quantization function is not
            applied until the training step exceeds this value. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.
        training (bool): Whether the network is in training mode. Default: True.

    Inputs:
        - **x** (Tensor) - float32 tensor to be quantized.
        - **min** (Tensor) - Minimum value of the quantization range for the input data x.
        - **max** (Tensor) - Maximum value of the quantization range for the input data x.

    Outputs:
        - Tensor, the simulated quantized tensor of x, with the same shape and type as x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
                 training=True):
        """init FakeQuantWithMinMax OP"""
        if num_bits not in self.support_quant_bit:
            raise ValueError(f"For '{self.name}' attr 'num_bits' is not supported.")
        if ema and not ema_decay:
            raise ValueError(f"For '{self.name}' attr 'ema' and 'ema_decay' should be set together.")
        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type('training', training, (bool,), self.name)
        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
        return x_type


class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
    """Performs grad of FakeQuantWithMinMax operation."""
    support_quant_bit = [4, 8]

    @prim_attr_register
    def __init__(self, num_bits=8, quant_delay=0):
        if num_bits not in self.support_quant_bit:
            raise ValueError(f"For '{self.name}' attr 'num_bits' is not supported.")
        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
        return dout_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
        return dout_type


class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
    r"""
    Simulate the quantize and dequantize operations at training time on a per-channel basis.

    Args:
        num_bits (int): Number of bits used to simulate quantization. Default: 8.
        ema (bool): Whether to use the EMA algorithm to update the per-channel min and max values. Default: False.
        ema_decay (float): EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantization delay parameter. The simulated quantization of the weight data is not
            applied until the training step exceeds this value. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.
        training (bool): Whether the network is in training mode. Default: True.

    Inputs:
        - **x** (Tensor) - 4-D float32 tensor to be quantized.
        - **min** (Tensor) - Per-channel minimum values of the quantization range for the input data.
        - **max** (Tensor) - Per-channel maximum values of the quantization range for the input data.

    Outputs:
        - Tensor, has the same shape and type as the input.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(4, 3, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6.0, -6.5, -4.0, -5.0]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6.0, 6.5, 4.0, 5.0]), mstype.float32)
        >>> output_tensor = P.FakeQuantWithMinMaxPerChannel(num_bits=8)(input_tensor, min_tensor, max_tensor)
    """
    support_quant_bit = [4, 8]
    channel_idx = 0

    @prim_attr_register
    def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False,
                 training=True):
        """init FakeQuantWithMinMaxPerChannel OP"""
        if num_bits not in self.support_quant_bit:
            raise ValueError(f"For '{self.name}' attr 'num_bits' is not supported.")
        if ema and not ema_decay:
            raise ValueError(f"For '{self.name}' attr 'ema' and 'ema_decay' should be set together.")
        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type('training', training, (bool,), self.name)
        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
        validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name)
        validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
        return x_type


class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
    """Performs grad of FakeQuantWithMinMaxPerChannel operation."""
    support_quant_bit = [4, 8]

    @prim_attr_register
    def __init__(self, num_bits=8, quant_delay=0):
        """init FakeQuantWithMinMaxPerChannelGrad OP"""
        if num_bits not in self.support_quant_bit:
            raise ValueError(f"For '{self.name}' attr 'num_bits' is not supported.")
        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        return dout_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
        return dout_type


class BatchNormFold(PrimitiveWithInfer):
    """
    Batch normalization folded.

    Args:
        momentum (float): Momentum value, must be in the range [0, 1]. Default: 0.1.
        epsilon (float): A small float number added to the variance to avoid dividing by zero.
            Should be 1e-5 when the dtype is float32, otherwise 1e-3. Default: 1e-5.
        is_training (bool): Set to True in training mode, otherwise False. Default: True.
        freeze_bn (int): Delay in steps at which computation switches from regular batch
            norm to frozen mean and std. Default: 0.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor recording the current global step.

    Outputs:
        Tuple of 4 Tensors, the batch statistics and the updated running statistics.

        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
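
    Examples:
        A minimal usage sketch; the shapes and values below are illustrative assumptions only.

        >>> batch_norm_fold = P.BatchNormFold(momentum=0.1, epsilon=1e-5)
        >>> x = Tensor(np.random.rand(32, 3), mstype.float32)  # assumed (N, C) input
        >>> mean = Tensor(np.zeros(3), mstype.float32)
        >>> variance = Tensor(np.ones(3), mstype.float32)
        >>> global_step = Tensor(np.array([0]), mstype.int32)
        >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(x, mean, variance, global_step)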
  204. """
  205. channel = 1
  206. @prim_attr_register
  207. def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
  208. """init batch norm fold layer"""
  209. self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
  210. self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
  211. self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
  212. self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
  213. self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
  214. outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])
  215. def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
  216. validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
  217. validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, self.name)
  218. validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
  219. return mean_shape, mean_shape, mean_shape, mean_shape
  220. def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
  221. validator.check("input type", x_type, "mean type", mean_type)
  222. validator.check("input type", x_type, "variance type", variance_type)
  223. args = {"x": x_type, "mean": mean_type, "variance": variance_type}
  224. validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
  225. validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
  226. return x_type, x_type, x_type, x_type


class BatchNormFoldGrad(PrimitiveWithInfer):
    """Performs grad of BatchNormFold operation."""
    channel = 1

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """init BatchNormFoldGrad layer"""
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
                    global_step_shape):
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_std shape", batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ,
                        self.name)
        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
                    global_step_type):
        args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
                "batch_mean": batch_mean_type, "batch_std": batch_std_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
        return x_type


class CorrectionMul(PrimitiveWithInfer):
    """
    Scale the weights with a correction factor based on the long-term statistics prior to
    quantization, so that there is no jitter in the quantized weights due to batch-to-batch
    variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.

    Outputs:
        - **out** (Tensor) - Tensor with the same shape as x.
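
    Examples:
        A minimal usage sketch; the weight layout and statistic shapes below are illustrative assumptions only.

        >>> correction_mul = P.CorrectionMul()
        >>> weight = Tensor(np.random.rand(4, 3, 3, 3), mstype.float32)  # assumed output-channel-first weight
        >>> batch_std = Tensor(np.ones(4), mstype.float32)
        >>> running_std = Tensor(np.ones(4), mstype.float32)
        >>> out = correction_mul(weight, batch_std, running_std)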
  268. """
  269. channel = 0
  270. @prim_attr_register
  271. def __init__(self):
  272. """init correction mul layer"""
  273. self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
  274. outputs=['out'])
  275. def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
  276. validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
  277. validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel],
  278. Rel.EQ, self.name)
  279. return x_shape
  280. def infer_dtype(self, x_type, batch_std_type, running_std_type):
  281. args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
  282. validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
  283. return x_type


class CorrectionMulGrad(PrimitiveWithInfer):
    """Performs grad of CorrectionMul operation."""
    channel = 0

    @prim_attr_register
    def __init__(self):
        """init correction mul grad layer"""
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
                                outputs=['dx', 'd_gamma'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel],
                        Rel.EQ, self.name)
        validator.check("running_std_shape[0]", running_std_shape[0], "dout channel size", dout_shape[self.channel],
                        Rel.EQ, self.name)
        return x_shape, gamma_shape

    def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
        args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        return x_type, x_type


class BatchNormFold2(PrimitiveWithInfer):
    """
    Scale the bias with a correction factor based on the long-term statistics prior to
    quantization, so that there is no jitter in the quantized bias due to batch-to-batch
    variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
        - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor recording the current global step.

    Outputs:
        - **y** (Tensor) - Tensor with the same shape as x.
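
    Examples:
        A minimal usage sketch; the shapes and values below are illustrative assumptions only.

        >>> batch_norm_fold2 = P.BatchNormFold2(freeze_bn=0)
        >>> x = Tensor(np.random.rand(32, 3), mstype.float32)  # assumed (N, C) input
        >>> beta = Tensor(np.zeros(3), mstype.float32)
        >>> gamma = Tensor(np.ones(3), mstype.float32)
        >>> batch_std = Tensor(np.ones(3), mstype.float32)
        >>> batch_mean = Tensor(np.zeros(3), mstype.float32)
        >>> running_std = Tensor(np.ones(3), mstype.float32)
        >>> running_mean = Tensor(np.zeros(3), mstype.float32)
        >>> global_step = Tensor(np.array([0]), mstype.int32)
        >>> y = batch_norm_fold2(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step)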
  319. """
  320. channel = 1
  321. @prim_attr_register
  322. def __init__(self, freeze_bn=0):
  323. """init conv2d fold layer"""
  324. self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
  325. self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
  326. 'running_std', 'running_mean', 'global_step'],
  327. outputs=['y'])
  328. def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
  329. running_mean_shape, global_step_shape):
  330. validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
  331. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
  332. validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
  333. validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
  334. validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
  335. validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel],
  336. Rel.EQ, self.name)
  337. validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
  338. return x_shape
  339. def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
  340. running_mean_type, global_step_type):
  341. args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
  342. "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
  343. validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
  344. validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
  345. return x_type


class BatchNormFold2Grad(PrimitiveWithInfer):
    """Performs grad of BatchNormFold2 operation."""
    channel = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """init BatchNormFold2Grad layer"""
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma',
                                        'batch_std', 'batch_mean',
                                        'running_std', 'running_mean', 'global_step'],
                                outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape,
                    batch_std_shape, batch_mean_shape,
                    running_std_shape, running_mean_shape, global_step_shape):
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel],
                        Rel.EQ, self.name)
        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
        return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

    def infer_dtype(self, dout_type, x_type, gamma_type,
                    batch_std_type, batch_mean_type,
                    running_std_type, running_mean_type, global_step_type):
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type, Rel.EQ, self.name)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type, Rel.EQ, self.name)
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type, Rel.EQ, self.name)
        validator.check("batch_std type", batch_std_type,
                        "running_mean type", running_mean_type, Rel.EQ, self.name)
        validator.check("batch_std type", batch_std_type,
                        "dout type", dout_type, Rel.EQ, self.name)
        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
                "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
        return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type