_grad_ops.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Operators for gradients."""

from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator, Rel
from .._utils import _get_concat_offset
from ...common import dtype as mstype


class AbsGrad(PrimitiveWithInfer):
    """Computes gradients for abs operation."""

    @prim_attr_register
    def __init__(self):
        """init AbsGrad"""

    def infer_shape(self, y, dy):
        return y

    def infer_dtype(self, y, dy):
        return y


class ACosGrad(PrimitiveWithInfer):
    """
    Computes ACosGrad of input element-wise.

    Returns:
        Tensor, has the same type as input.
    """

    @prim_attr_register
    def __init__(self):
        """init ACosGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x


class BatchNormGrad(PrimitiveWithInfer):
    """Performs grad of BatchNorm operation."""

    @prim_attr_register
    def __init__(self, is_training=False, epsilon=1e-5):
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
        self.add_prim_attr('data_format', "NCHW")

    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape):
        validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
        return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape)

    def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type):
        return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)


class BiasAddGrad(Primitive):
    """Computes gradients of BiasAdd."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['dout'], outputs=['output'])
        self.add_prim_attr('data_format', 'NCHW')

    def __call__(self, d_output):
        raise NotImplementedError


class BinaryCrossEntropyGrad(PrimitiveWithInfer):
    """Computes gradients for `BinaryCrossEntropy` operation."""

    @prim_attr_register
    def __init__(self, reduction='mean'):
        self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)

    def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
        if weight_shape:
            validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
        args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        if weight_type:
            validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, self.name)
        return x_type


class ConcatOffset(PrimitiveWithInfer):
    """primitive for computing Concat's gradient."""

    @prim_attr_register
    def __init__(self, N=2, axis=0):
        """init ConcatOffset"""

    def __infer__(self, input_x):
        axis = self.axis
        x_shp = input_x['shape']
        x_type = input_x['dtype']
        offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name)
        self.add_prim_attr('T', x_type[0].element_type())
        offset_values = []
        for i in range(len(x_shp)):
            values = []
            for j in range(len(x_shp[0])):
                value = 0
                if j == axis:
                    value = offset[i]
                values.append(value)
            offset_values.append(tuple(values))
        out = {'shape': None,
               'dtype': None,
               'value': tuple(offset_values)}
        return out
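
# Worked example (illustrative values, assuming _get_concat_offset returns the
# cumulative offset of each input along the concat axis): for three inputs of
# shapes (2, 3), (4, 3) and (5, 3) concatenated along axis 0, the offsets are
# 0, 2 and 6, so __infer__ yields
#     {'shape': None, 'dtype': None, 'value': ((0, 0), (2, 0), (6, 0))}
# i.e. each per-input tuple carries the offset at the concat axis and zeros
# elsewhere.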


class Conv2DBackpropFilter(PrimitiveWithInfer):
    """
    Computes the gradients of convolution with respect to the filter.

    Args:
        out_channel (int): The dimensionality of the output space.
        kernel_size (Union[int, tuple[int]]): The size of the convolution window.
        pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
        pad (int): The pad value to fill. Default: 0.
        mode (int): 0 math convolution, 1 cross-correlation convolution,
                    2 deconvolution, 3 depthwise convolution. Default: 1.
        stride (tuple): The stride to apply conv filter. Default: (1, 1).
        dilation (tuple): Specifies the dilation rate to use for dilated convolution. Default: (1, 1, 1, 1).
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the gradients of convolution.
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pad_list=(0, 0, 0, 0),
                 mode=1,
                 stride=(1, 1),
                 dilation=(1, 1, 1, 1),
                 group=1):
        """init Convolution"""
        self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.mode = mode
        pad_mode = pad_mode.upper()
        self.add_prim_attr('pad_mode', pad_mode)
        self.pad = pad
        if isinstance(stride, tuple) and len(stride) == 4:
            self.stride = (stride[2], stride[3])
            self.add_prim_attr('stride', self.stride)
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __infer__(self, doutput, x, w_size):
        w_size_v = w_size['value']
        validator.check_value_type('w_size', w_size_v, [tuple], self.name)
        for i, dim_len in enumerate(w_size_v):
            validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
        args = {"x": x['dtype'], "doutput": doutput['dtype']}
        validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name)
        out = {
            'value': None,
            'shape': w_size_v,
            'dtype': doutput['dtype'],
        }
        return out
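
# Construction sketch (illustrative; the argument values are assumptions): a
# 4-element NCHW stride is reduced to its spatial part, so the registered
# 'stride' attribute becomes (2, 2):
#     grad_filter = Conv2DBackpropFilter(out_channel=32, kernel_size=3,
#                                        stride=(1, 1, 2, 2))
# The inferred output takes its shape from the constant `filter_sizes` input
# and its dtype from `out_backprop`.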


class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
    """
    Returns the gradient of filter for DepthwiseConv2dNative.

    Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.
    Refer to class DepthwiseConv2dNative for more details.

    Args:
        channel_multiplier (int): The multiplier for the original output conv.
        kernel_size (int or tuple): The size of the conv kernel.
        mode (int): 0 math convolution, 1 cross-correlation convolution,
                    2 deconvolution, 3 depthwise convolution. Default: 3.
        pad_mode (str): The mode to fill padding which can be: "valid", "same" or "pad". Default: "valid".
        pad (int): The pad value to fill. Default: 0.
        pads (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        stride (int): The stride to apply conv filter. Default: 1.
        dilation (int): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the value is the gradient of filter for DepthwiseConv2dNative.
    """

    @prim_attr_register
    def __init__(self,
                 channel_multiplier,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pads=(0, 0, 0, 0),
                 mode=3,
                 stride=1,
                 dilation=1,
                 group=1):
        """init Convolution"""
        self.init_prim_io_names(inputs=['input', 'filter_size', 'dout'], outputs=['output'])
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.mode = mode
        self.pad_mode = pad_mode
        self.pad = pad
        self.pads = pads
        self.stride = stride
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __call__(self, x, w_size, dout):
        raise NotImplementedError

    def __infer__(self, x, w_size, dout):
        w_size_v = w_size['value']
        args = {'x': x['dtype'], 'dout': dout['dtype']}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        out = {
            'value': None,
            'shape': w_size_v,
            'dtype': dout['dtype'],
        }
        return out


class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
    """
    Returns the gradient of input for DepthwiseConv2dNative.

    Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.

    Args:
        channel_multiplier (int): The multiplier for the original output conv.
        kernel_size (int or tuple): The size of the conv kernel.
        mode (int): 0 math convolution, 1 cross-correlation convolution,
                    2 deconvolution, 3 depthwise convolution. Default: 3.
        pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
        pad (int): The pad value to fill. Default: 0.
        pads (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        stride (int): The stride to apply conv filter. Default: 1.
        dilation (int): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the value is the gradient of input for DepthwiseConv2dNative.
    """

    @prim_attr_register
    def __init__(self,
                 channel_multiplier,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pads=(0, 0, 0, 0),
                 mode=3,
                 stride=1,
                 dilation=1,
                 group=1):
        """init Convolution"""
        self.init_prim_io_names(inputs=['input_size', 'filter', 'dout'], outputs=['output'])
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.mode = mode
        self.pad_mode = pad_mode
        self.pad = pad
        self.pads = pads
        self.stride = stride
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __call__(self, x_size, w, dout):
        raise NotImplementedError

    def __infer__(self, x_size, w, dout):
        args = {'w': w['dtype'], 'dout': dout['dtype']}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        x_size_v = x_size['value']
        out = {
            'value': None,
            'shape': x_size_v,
            'dtype': dout['dtype'],
        }
        return out
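
# Note on the __infer__ protocol (illustrative reading of the code above):
# each argument arrives as an abstract-value dict with 'shape', 'dtype' and
# 'value' keys rather than as a Tensor. The output shape is read straight
# from x_size['value'], so the target input size must be a compile-time
# constant, while the gradient dtype simply follows dout['dtype'].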


class FlattenGrad(PrimitiveWithInfer):
    """Performs gradients of Flatten."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x', 'shape'], outputs=['output'])

    def __infer__(self, *args):
        out = {
            'value': None,
            'shape': args[1]['value'],
            'dtype': args[0]['dtype'],
        }
        return out


class FusedBatchNormGrad(Primitive):
    """Gradients of FusedBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0, momentum=0.1):
        self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance'],
                                outputs=['dx', 'bn_scale', 'bn_bias'])

    def __call__(self, dy, x, scale, save_mean, save_inv_variance):
        raise NotImplementedError


class GeluGrad(PrimitiveWithInfer):
    """Gradients of Gelu operation."""

    @prim_attr_register
    def __init__(self):
        """init GeluGrad"""

    def infer_shape(self, y_backprop_shape, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
        validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name)
        return x_dtype


class _PoolGrad(PrimitiveWithInfer):
    """Gradients of the max/avg pool operation."""

    @prim_attr_register
    def __init__(self, ksize, strides, padding="VALID"):
        self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
        validator.check_value_type('ksize', ksize, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
        self.add_prim_attr("padding", self.padding)
        self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
        if not self.is_maxpoolgradwithargmax:
            self.add_prim_attr('data_format', "NCHW")

        def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
            validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
            error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be a positive int number "
                                   f"or a tuple of two or four positive int numbers, but got {arg_val}")
            if isinstance(arg_val, int):
                ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val)
            elif len(arg_val) == 2:
                ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1])
            elif len(arg_val) == 4:
                ret = arg_val
            else:
                raise error_msg
            # check that all elements of the tuple are positive integers
            for item in ret:
                if not isinstance(item, int) or item <= 0:
                    raise error_msg
            return ret

        self.ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax)
        self.add_prim_attr("ksize", self.ksize)
        self.strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
        self.add_prim_attr("strides", self.strides)
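
# Normalization sketch (illustrative values): _grad_check_int_or_tuple expands
# an int or 2-tuple into a 4-tuple. For the argmax variant the spatial values
# occupy positions 1 and 2; otherwise they occupy positions 2 and 3:
#     ksize=3,      is_argmax=False  ->  (1, 1, 3, 3)
#     ksize=(2, 3), is_argmax=True   ->  (1, 2, 3, 1)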


class AvgPoolGrad(_PoolGrad):
    """Gradients of the avg pool operation."""

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        super(AvgPoolGrad, self).__init__(ksize, strides, padding)

    def __infer__(self, origin_input, dout):
        out = {
            'value': None,
            'shape': tuple(origin_input['value']),
            'dtype': dout['dtype'],
        }
        return out


class AvgPoolGradGpu(_PoolGrad):
    """Gradients of the avg pool operation for gpu."""

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        super(AvgPoolGradGpu, self).__init__(ksize, strides, padding)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        return x1_dtype


class MaxPoolGrad(_PoolGrad):
    """Performs gradients of the max pool operation."""

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        super(MaxPoolGrad, self).__init__(ksize, strides, padding)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        return x1_dtype


class MaximumGrad(Primitive):
    """Grad for maximum."""

    @prim_attr_register
    def __init__(self, grad_x=True, grad_y=True):
        """Init MaximumGrad"""

    def __call__(self, x, y, dout):
        raise NotImplementedError


class MaxPoolGradWithArgmax(_PoolGrad):
    """Computes the gradients of MaxPoolWithArgmax."""

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
        super(MaxPoolGradWithArgmax, self).__init__(ksize, strides, padding)

    def infer_shape(self, x_shape, grad_shape, argmax_shape):
        if not grad_shape:
            raise TypeError("The dout of MaxPoolGradWithArgmax should be a Tensor.")
        return x_shape

    def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
        return grad_dtype


class MinimumGrad(Primitive):
    """Grad for minimum."""

    @prim_attr_register
    def __init__(self, grad_x=True, grad_y=True):
        """Init MinimumGrad"""

    def __call__(self, x, y, dout):
        raise NotImplementedError


class L2NormalizeGrad(PrimitiveWithInfer):
    r"""
    Gradients of L2 normalize.

    Args:
        axis (int): The begin axis for the input to apply L2 normalize. Default: 0.
        epsilon (float): A small value added for numerical stability. Default: 1e-4.

    Inputs:
        - **input_x** (Tensor) - Should be the input `weight` of forward operator L2Normalize.
        - **out** (Tensor) - Should be the output of forward operator L2Normalize.
        - **dout** (Tensor) - The backprop of the next layer.

    Outputs:
        Tensor, gradients of L2Normalize `input_x`.
    """

    @prim_attr_register
    def __init__(self, axis=0, epsilon=1e-4):
        validator.check_value_type('axis', axis, [int], self.name)
        validator.check_value_type('epsilon', epsilon, [int, float], self.name)

    def infer_shape(self, input_x, out, dout):
        validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name)
        validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name)
        return input_x

    def infer_dtype(self, input_x, out, dout):
        args = {'input_x': input_x, 'out': out, 'dout': dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return input_x


class LayerNormGrad(Primitive):
    """
    Applies the layer normalization to the input array.

    This operator will calculate the input gradients of layernorm.

    Args:
        begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1.
        begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1.

    Returns:
        tuple[Tensor], tuple of 3 values (the gradients of layernorm input, gamma, beta).
    """

    @prim_attr_register
    def __init__(self, begin_norm_axis=1, begin_params_axis=1):
        """init"""
        self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
        self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)

    def __call__(self, x, dy, variance, mean, gamma):
        raise NotImplementedError
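
# Call sketch (illustrative; argument names taken from the signature above,
# and the Python __call__ itself raises NotImplementedError since the
# computation is expected to run in a backend kernel):
#     dx, dgamma, dbeta = LayerNormGrad()(x, dy, variance, mean, gamma)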


class LogSoftmaxGrad(PrimitiveWithInfer):
    """Computes gradient for the Log Softmax activation."""

    @prim_attr_register
    def __init__(self, axis=-1):
        """init LogSoftmaxGrad"""
        validator.check_value_type("axis", axis, [int], self.name)

    def infer_shape(self, dout, logits):
        rank = len(logits)
        validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name)
        return logits

    def infer_dtype(self, dout, logits):
        validator.check_subclass("logits", logits, mstype.tensor, self.name)
        return logits


class LSTMGradData(PrimitiveWithInfer):
    """Computes the data gradients of LSTM."""

    @prim_attr_register
    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
        self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
        self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
        self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)

        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

    def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
                    hx_shape, cx_shape, reserve_shape, state_shape):
        # dhy and dcy should be same shape
        validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
        validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
        validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
        validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
        validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)

        validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
        validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)

        # dy: (seq_len, batch_size, hidden_size * num_directions)
        validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
        validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
        validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)

        # dx: (seq_len, batch_size, input_size)
        dx_shape = (y_shape[0], y_shape[1], self.input_size)
        dhx_shape = dhy_shape
        dcx_shape = dcy_shape

        return (dx_shape, dhx_shape, dcx_shape)

    def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
                    hx_dtype, cx_dtype, reserve_dtype, state_dtype):
        args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
        validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
        return (dy_dtype, dy_dtype, dy_dtype)
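
# Shape sketch (illustrative values): for a single-layer unidirectional LSTM
# with seq_len=5, batch=4, input_size=10 and hidden_size=16, infer_shape
# expects dhy and dcy of shape (1, 4, 16) and dy of shape (5, 4, 16), and
# returns
#     dx  = (5, 4, 10)
#     dhx = (1, 4, 16)
#     dcx = (1, 4, 16)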


class LSTMGradWeight(PrimitiveWithInfer):
    """Computes the weight gradients of LSTM."""

    @prim_attr_register
    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
        self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
        self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
        self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)

        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

    def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape):
        weight_size = 0
        gate_size = 4 * self.hidden_size
        for layer in range(self.num_layers):
            for _ in range(self.num_directions):
                input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
                weight_size += gate_size * input_layer_size
                weight_size += gate_size * self.hidden_size
                if self.has_bias:
                    weight_size += 2 * gate_size
        return (weight_size, 1, 1)

    def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype):
        return hx_dtype
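
# Worked example (illustrative values): one unidirectional layer with
# input_size=10, hidden_size=16 and has_bias=True gives
#     gate_size   = 4 * 16 = 64
#     weight_size = 64 * 10 (input-hidden) + 64 * 16 (hidden-hidden)
#                   + 2 * 64 (biases) = 1792
# so infer_shape returns (1792, 1, 1), the shape of the flat weight tensor.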


class PReLUGrad(PrimitiveWithInfer):
    r"""
    Gradients of PReLU operation.

    Inputs:
        - **y_backprop** (Tensor) - Representing the backprop of the next layer.
        - **input_x** (Tensor) - Should be the input `input_x` of forward operator PRelu.
        - **weight** (Tensor) - Float Tensor, w > 0, should be the input `weight` of forward operator PRelu.

    Outputs:
        Tensor, with the same type as `input_x`.
    """

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, y_backprop_shape, A_shape, w_shape):
        return y_backprop_shape, w_shape

    def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name)
        validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name)
        validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name)
        return y_backprop_dtype, w_dtype


class ReluGrad(Primitive):
    """Performs grad of Relu operation."""

    @prim_attr_register
    def __init__(self):
        """init ReluGrad"""
        self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])

    def __call__(self, y_backprop, x):
        raise NotImplementedError


class ReLU6Grad(PrimitiveWithInfer):
    """Performs grad of ReLU6 operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])

    def __call__(self, y_grad, x):
        raise NotImplementedError

    def infer_shape(self, y_grad_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
        return x_dtype


class ReluGradV2(PrimitiveWithInfer):
    """Performs grad of ReLUV2 operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])

    def __call__(self, gradients, mask):
        raise NotImplementedError

    def infer_shape(self, gradients_shape, mask_shape):
        return gradients_shape

    def infer_dtype(self, gradients_dtype, mask_dtype):
        validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name)
        validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name)
        return gradients_dtype


class EluGrad(PrimitiveWithInfer):
    """Performs grad of Elu operation."""

    @prim_attr_register
    def __init__(self):
        """Init EluGrad"""

    def infer_shape(self, y_grad_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
        args = {'y_grad': y_grad_dtype, 'x': x_dtype}
        validator.check_tensor_type_same(args, mstype.float_type, self.name)
        return x_dtype


class ResizeBilinearGrad(PrimitiveWithInfer):
    """Performs grad of ResizeBilinear operation."""

    @prim_attr_register
    def __init__(self, align_corners=False):
        """init"""

    def infer_shape(self, dout_shape, orig_shape):
        return orig_shape

    def infer_dtype(self, dout_dtype, orig_type):
        return dout_dtype


class ResizeNearestNeighborGrad(PrimitiveWithInfer):
    """
    Compute gradient of `ResizeNearestNeighbor` operator.

    Note:
        The shape of input parameter `size` must be (height, width).

    Args:
        align_corners (bool): Whether the centers of the 4 corner pixels of the input
            and output tensors are aligned. Default: False.
    """

    @prim_attr_register
    def __init__(self, align_corners=False):
        """Init ResizeNearestNeighborGrad"""
        self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])

    def __infer__(self, grads, size):
        shp = (grads['shape'][0],) + (grads['shape'][1],) + size['value']
        return {'shape': shp,
                'dtype': grads['dtype'],
                'value': None}
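
# Shape sketch (illustrative values): the output keeps the batch and channel
# dimensions of `grads` and takes its spatial dimensions from the constant
# `size`, e.g. grads of shape (2, 3, 8, 8) with size=(16, 16) infers an
# output of shape (2, 3, 16, 16).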


class ROIAlignGrad(PrimitiveWithInfer):
    """
    ROIAlignGrad operator.

    Args:
        xdiff_shape (tuple): The shape of the gradient with respect to the input features.
        pooled_height (int): The output feature height.
        pooled_width (int): The output feature width.
        spatial_scale (float): The feature stride.
        sample_num (int): Number of sampling points. Default: 2.
    """

    @prim_attr_register
    def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2):
        """init ROIAlignGrad"""
        validator.check_value_type("pooled_height", pooled_height, [int], self.name)
        validator.check_value_type("pooled_width", pooled_width, [int], self.name)
        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
        validator.check_value_type("sample_num", sample_num, [int], self.name)
        validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name)
        self.xdiff_shape = xdiff_shape
        self.pooled_height = pooled_height
        self.pooled_width = pooled_width
        self.spatial_scale = spatial_scale
        self.sample_num = sample_num

    def infer_shape(self, ydiff_shape, rois_shape):
        return self.xdiff_shape

    def infer_dtype(self, ydiff_type, rois_type):
        return ydiff_type


class SigmoidGrad(PrimitiveWithInfer):
    """Gets the gradient of Sigmoid operation."""

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, out, dout):
        return out

    def infer_dtype(self, out, dout):
        args = {'out': out, 'dout': dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return out


class HSigmoidGrad(PrimitiveWithInfer):
    """Gets the gradient of HSigmoid operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])

    def infer_shape(self, y_grad_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
        return x_dtype


class HSwishGrad(PrimitiveWithInfer):
    """Gets the gradient of HSwish operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])

    def infer_shape(self, y_grad_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
        return x_dtype


class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
    """Computes the gradients of `SigmoidCrossEntropyWithLogits`."""

    @prim_attr_register
    def __init__(self):
        """Init SigmoidCrossEntropyWithLogitsGrad"""
        self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])

    def infer_shape(self, x_shape, y_shape, dout_shape):
        validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
        validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
        args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return dout_dtype


class SliceGrad(PrimitiveWithInfer):
    """Reverse of slice."""

    @prim_attr_register
    def __init__(self):
        """init SliceGrad"""
        self.init_prim_io_names(inputs=['dy', 'x', 'begin', 'size'], outputs=['dx'])

    def __infer__(self, dy, x, begin, size):
        dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value']
        dy_shape_len = len(dy_shape)
        for i in range(dy_shape_len):
            validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name)
            validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name)
        return {'shape': x_shape,
                'dtype': x['dtype'],
                'value': None}


class SmoothL1LossGrad(PrimitiveWithInfer):
    """Computes gradient for prediction on SmoothL1Loss."""

    @prim_attr_register
    def __init__(self, sigma=1.0):
        pass

    def infer_shape(self, prediction, target, dloss):
        validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
        validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name)
        return prediction

    def infer_dtype(self, prediction, target, dloss):
        args = {"prediction": prediction, "target": target, 'dloss': dloss}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return dloss


class StridedSliceGrad(PrimitiveWithInfer):
    """
    Performs grad of StridedSlice operation.

    Args:
        begin_mask (int): Start indexing the slice. Default: 0.
        end_mask (int): End indexing the slice. Default: 0.
        ellipsis_mask (int): An int32 mask. Default: 0.
        new_axis_mask (int): An int32 mask. Default: 0.
        shrink_axis_mask (int): An int32 mask. Default: 0.

    Returns:
        Tensor, has the same shape of input.
    """

    @prim_attr_register
    def __init__(self,
                 begin_mask=0,
                 end_mask=0,
                 ellipsis_mask=0,
                 new_axis_mask=0,
                 shrink_axis_mask=0):
        """init StridedSliceGrad"""
        validator.check_value_type('begin_mask', begin_mask, [int], self.name)
        validator.check_value_type('end_mask', end_mask, [int], self.name)
        validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
        validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
        validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
        self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])

    def __infer__(self, dy, shapex, begin, end, strides):
        return {'shape': shapex['value'],
                'dtype': dy['dtype'],
                'value': None}


class TanhGrad(PrimitiveWithInfer):
    """Computes gradient of hyperbolic tangent of input element-wise."""

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, out, dout):
        return out

    def infer_dtype(self, out, dout):
        args = {"out": out, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return out


class MirrorPadGrad(PrimitiveWithInfer):
    """Gradients of MirrorPad operation."""

    @prim_attr_register
    def __init__(self, mode="REFLECT"):
        """init MirrorPad"""
        validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
        self.mode = mode

    def __infer__(self, dout, paddings, x):
        validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name)
        validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
        validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
        return {'shape': x['shape'],
                'dtype': dout['dtype'],
                'value': None}


class RefToEmbed(Primitive):
    r"""
    Make a key from a Ref.

    The key is a symbolic_key, an embedding of a Parameter, which is used as the key of the
    variable in env_type; items are fetched by the operation `env_get_item` with the
    symbolic_key instance. The `Parameter` is a ref.

    Inputs:
        - **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.

    Outputs:
        symbolic_key, made from the Ref.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.weight = mindspore.Parameter(1.0, name='weight')
        >>>
        >>>     def construct(self):
        >>>         key = RefToEmbed()(self.weight)
        >>>         return key, self.weight
    """
    __mindspore_signature__ = (
        ('variable', sig_rw.RW_REF, sig_kind.KIND_POSITIONAL_KEYWORD),
    )

    @prim_attr_register
    def __init__(self):
        pass