
_grad_ops.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for gradients."""
from .. import signature as sig
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator, Rel
from .._utils import get_concat_offset
from ...common import dtype as mstype
from .. import functional as F

class AbsGrad(PrimitiveWithInfer):
    """Computes gradients for abs operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize AbsGrad"""

    def infer_shape(self, y, dy):
        return y

    def infer_dtype(self, y, dy):
        return y


class ACosGrad(PrimitiveWithInfer):
    """
    Computes ACosGrad of input element-wise.

    Returns:
        Tensor, has the same type as input.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ACosGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x


class AcoshGrad(PrimitiveWithInfer):
    """Performs grad of Acosh operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize AcoshGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x


class AsinGrad(PrimitiveWithInfer):
    """
    Computes AsinGrad of input element-wise.

    Returns:
        Tensor, has the same type as input.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize AsinGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x


class AsinhGrad(PrimitiveWithInfer):
    """Performs grad of Asinh operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize AsinhGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x

class ReciprocalGrad(PrimitiveWithInfer):
    """Performs grad of Reciprocal operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize ReciprocalGrad"""

    def infer_shape(self, x_shape, dout_shape):
        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
        return x_dtype


class RsqrtGrad(PrimitiveWithInfer):
    """Performs grad of Rsqrt operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize RsqrtGrad"""

    def infer_shape(self, x_shape, dout_shape):
        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8],
                                         self.name)
        return x_dtype


class SoftmaxGrad(PrimitiveWithInfer):
    """Performs grad of Softmax operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize SoftmaxGrad"""

    def infer_shape(self, x_shape, dout_shape):
        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
        return x_dtype


class SqrtGrad(PrimitiveWithInfer):
    """Performs grad of Sqrt operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize SqrtGrad"""

    def infer_shape(self, x_shape, dout_shape):
        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
        return x_dtype

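# Reference identities for ReciprocalGrad, RsqrtGrad and SqrtGrad above,
# written in terms of the forward output y (standard calculus, shown for
# orientation; these are not the device kernels):
#   ReciprocalGrad: y = 1 / x    ->  dx = -y**2 * dout
#   RsqrtGrad:      y = x**-0.5  ->  dx = -0.5 * y**3 * dout
#   SqrtGrad:       y = x**0.5   ->  dx = 0.5 * dout / y
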
class BatchNormGrad(PrimitiveWithInfer):
    """Performs grad of BatchNorm operation."""

    @prim_attr_register
    def __init__(self, is_training=False, epsilon=1e-5):
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        self.add_prim_attr('data_format', "NCHW")

    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
        validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
        return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape)

    def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type):
        return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)


class BiasAddGrad(Primitive):
    """Computes gradients of BiasAdd."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['dout'], outputs=['output'])
        self.add_prim_attr('data_format', 'NCHW')

    def __call__(self, d_output):
        raise NotImplementedError


class KLDivLossGrad(PrimitiveWithInfer):
    """Computes gradients for `KLDivLoss` operation."""

    @prim_attr_register
    def __init__(self, reduction='mean'):
        self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)

    def infer_shape(self, x_shape, y_shape, doutput_shape):
        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
        return x_shape, y_shape

    def infer_dtype(self, x_type, y_type, doutput_type):
        args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        return x_type, y_type


class BinaryCrossEntropyGrad(PrimitiveWithInfer):
    """Computes gradients for `BinaryCrossEntropy` operation."""

    @prim_attr_register
    def __init__(self, reduction='mean'):
        self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)

    def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
        if weight_shape:
            validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
        args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
        if weight_type:
            validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError)
        return x_type

class ConcatOffset(PrimitiveWithInfer):
    """primitive for computing Concat's gradient."""

    @prim_attr_register
    def __init__(self, N=2, axis=0):
        """Initialize ConcatOffset"""

    def __infer__(self, input_x):
        axis = self.axis
        x_shp = input_x['shape']
        x_type = input_x['dtype']
        offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
        self.add_prim_attr('T', x_type[0].element_type())
        offset_values = []
        for i in range(len(x_shp)):
            values = []
            for j in range(len(x_shp[0])):
                value = 0
                if j == axis:
                    value = offset[i]
                values.append(value)
            offset_values.append(tuple(values))
        out = {'shape': None,
               'dtype': None,
               'value': tuple(offset_values)}
        return out

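# A framework-free sketch of the offset computation in ConcatOffset.__infer__
# above (illustrative only; this helper is hypothetical and not part of the
# module's API): along the concat axis each input is offset by the running
# sum of the preceding inputs' extents, and every other axis gets offset 0.
def _concat_offsets_sketch(shapes, axis):
    offsets, running = [], 0
    for shape in shapes:
        offsets.append(tuple(running if j == axis else 0 for j in range(len(shape))))
        running += shape[axis]
    return tuple(offsets)
# e.g. two (2, 3) inputs concatenated along axis 0 -> ((0, 0), (2, 0)).
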
class Conv2DBackpropFilter(PrimitiveWithInfer):
    """
    Computes the gradients of convolution with respect to the filter.

    Args:
        out_channel (int): The dimensionality of the output space.
        kernel_size (Union[int, tuple[int]]): The size of the convolution window.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (int): The pad value to be filled. Default: 0.
        mode (int): Modes for different convolutions. 0 math convolution, 1 cross-correlation convolution,
                    2 deconvolution, 3 depthwise convolution. Default: 1.
        stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
        dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the gradients of convolution.
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pad_list=(0, 0, 0, 0),
                 mode=1,
                 stride=(1, 1),
                 dilation=(1, 1, 1, 1),
                 group=1):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.mode = mode
        pad_mode = pad_mode.upper()
        self.add_prim_attr('pad_mode', pad_mode)
        self.pad = pad
        if isinstance(stride, tuple) and len(stride) == 4:
            self.stride = (stride[2], stride[3])
            self.add_prim_attr('stride', self.stride)
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('groups', group)
        self.add_prim_attr('data_format', "NCHW")

    def __infer__(self, doutput, x, w_size):
        w_size_v = w_size['value']
        validator.check_value_type('w_size', w_size_v, [tuple], self.name)
        for i, dim_len in enumerate(w_size_v):
            validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
        args = {"x": x['dtype'], "doutput": doutput['dtype']}
        validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name)
        out = {
            'value': None,
            'shape': w_size_v,
            'dtype': doutput['dtype'],
        }
        return out

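# A minimal NumPy sketch of the quantity Conv2DBackpropFilter produces, for
# the simplest configuration (stride 1, no padding, no dilation, one group,
# cross-correlation). Assumed semantics for illustration only; this helper
# is hypothetical and does not mirror every attribute handled above.
def _conv2d_filter_grad_sketch(x, dout, kh, kw):
    """x: (N, C_in, H, W); dout: (N, C_out, H - kh + 1, W - kw + 1)."""
    import numpy as np
    _, c_in, _, _ = x.shape
    _, c_out, oh, ow = dout.shape
    dw = np.zeros((c_out, c_in, kh, kw), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            # The window of x seen by kernel position (i, j) at every output
            # location, correlated with the incoming gradient.
            patch = x[:, :, i:i + oh, j:j + ow]  # (N, C_in, OH, OW)
            dw[:, :, i, j] = np.einsum('nchw,nkhw->kc', patch, dout)
    return dw
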
class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
    """
    Returns the gradient of filter for DepthwiseConv2dNative.

    Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.

    Refer to class DepthwiseConv2dNative for more details.

    Args:
        channel_multiplier (int): The multiplier for the original output conv.
        kernel_size (int or tuple): The size of the conv kernel.
        mode (int): Modes for different convolutions. 0 math convolution, 1 cross-correlation convolution,
                    2 deconvolution, 3 depthwise convolution. Default: 3.
        pad_mode (str): The mode to fill padding which can be: "valid", "same" or "pad". Default: "valid".
        pad (int): The pad value to be filled. Default: 0.
        pads (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        stride (int): The stride to be applied to the convolution filter. Default: 1.
        dilation (int): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the value is the gradient of filter for DepthwiseConv2dNative.
    """

    @prim_attr_register
    def __init__(self,
                 channel_multiplier,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pads=(0, 0, 0, 0),
                 mode=3,
                 stride=1,
                 dilation=1,
                 group=1):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['input', 'filter_size', 'dout'], outputs=['output'])
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.mode = mode
        self.pad_mode = pad_mode
        self.pad = pad
        self.pads = pads
        self.stride = stride
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __call__(self, x, w_size, dout):
        raise NotImplementedError

    def __infer__(self, x, w_size, dout):
        w_size_v = w_size['value']
        args = {'x': x['dtype'], 'dout': dout['dtype']}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        out = {
            'value': None,
            'shape': w_size_v,
            'dtype': dout['dtype'],
        }
        return out

class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
    """
    Returns the gradient of input for DepthwiseConv2dNative.

    Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.

    Args:
        channel_multiplier (int): The multiplier for the original output conv.
        kernel_size (int or tuple): The size of the conv kernel.
        mode (int): Modes for different convolutions. 0 math convolution, 1 cross-correlation convolution,
                    2 deconvolution, 3 depthwise convolution. Default: 3.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (int): The pad value to be filled. Default: 0.
        pads (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        stride (int): The stride to be applied to the convolution filter. Default: 1.
        dilation (int): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the value is the gradient of input for DepthwiseConv2dNative.
    """

    @prim_attr_register
    def __init__(self,
                 channel_multiplier,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pads=(0, 0, 0, 0),
                 mode=3,
                 stride=1,
                 dilation=1,
                 group=1):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['input_size', 'filter', 'dout'], outputs=['output'])
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.mode = mode
        self.pad_mode = pad_mode
        self.pad = pad
        self.pads = pads
        self.stride = stride
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __call__(self, x_size, w, dout):
        raise NotImplementedError

    def __infer__(self, x_size, w, dout):
        args = {'w': w['dtype'], 'dout': dout['dtype']}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        x_size_v = x_size['value']
        out = {
            'value': None,
            'shape': x_size_v,
            'dtype': dout['dtype'],
        }
        return out

class DropoutGrad(PrimitiveWithInfer):
    """
    The gradient of Dropout. During training, the forward pass randomly zeroes
    some of the elements of the input tensor with probability 1 - keep_prob.

    Args:
        keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
            means dropping out 10% of input units.

    Inputs:
        - **dy** (Tensor) - The gradient flowing from the next layer.
        - **mask** (Tensor) - The mask produced by the forward Dropout.

    Outputs:
        Tensor, the gradient of the Dropout input, with the same shape and dtype as `dy`.

    Examples:
        >>> dropout_grad = P.DropoutGrad(keep_prob=0.5)
        >>> output = dropout_grad(dy, mask)
    """

    @prim_attr_register
    def __init__(self, keep_prob=0.5):
        self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)

    def infer_shape(self, dy_shape, mask_shape):
        return dy_shape

    def infer_dtype(self, dy_dtype, mask_dtype):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name)
        validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
        validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
        return dy_dtype

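# The dropout gradient is elementwise (a sketch of the usual inverted-dropout
# convention, assumed here for illustration; on some backends the mask is a
# packed bit tensor rather than a 0/1 float tensor): the incoming gradient
# passes through where the forward mask kept the activation, scaled by
# 1/keep_prob, and is zero elsewhere.
def _dropout_grad_sketch(dy, mask, keep_prob):
    return dy * mask / keep_prob
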
class FlattenGrad(PrimitiveWithInfer):
    """Performs gradients of Flatten."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x', 'shape'], outputs=['output'])

    def __infer__(self, *args):
        out = {
            'value': None,
            'shape': args[1]['value'],
            'dtype': args[0]['dtype'],
        }
        return out

class FusedBatchNormGrad(Primitive):
    """Gradients of FusedBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0, momentum=0.1):
        self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance'],
                                outputs=['dx', 'bn_scale', 'bn_bias'])

    def __call__(self, dy, x, scale, save_mean, save_inv_variance):
        raise NotImplementedError


class FusedBatchNormGradEx(PrimitiveWithInfer):
    """Gradients of FusedBatchNormEx operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0, momentum=0.1):
        self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance', 'reserve'],
                                outputs=['dx', 'bn_scale', 'bn_bias'])
        self.add_prim_attr('data_format', "NCHW")

    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve_shape):
        return (x_shape, scale_shape, scale_shape)

    def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_type, save_variance_type, reserve_type):
        return (x_type, scale_type, scale_type)


class UniqueGrad(Primitive):
    """Gradients of Unique operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])

    def __call__(self, dy, y):
        raise NotImplementedError


class BNTrainingReduceGrad(PrimitiveWithInfer):
    """Gradients of FusedBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0001):
        _inputs = ['grads', 'x', 'diff_scale', 'diff_offset', 'scale', 'batch_mean', 'batch_variance']
        self.init_prim_io_names(inputs=_inputs, outputs=['y'])

    def infer_shape(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
        return grads

    def infer_dtype(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
        return grads


class BNTrainingUpdateGrad(PrimitiveWithInfer):
    """Gradients of FusedBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0001):
        self.init_prim_io_names(inputs=['grads', 'x', 'batch_mean', 'batch_variance'],
                                outputs=['diff_scale', 'diff_offset'])

    def infer_shape(self, grads, x, batch_mean, batch_variance):
        return (batch_mean, batch_variance)

    def infer_dtype(self, grads, x, batch_mean, batch_variance):
        return (batch_mean, batch_variance)


class GeluGrad(PrimitiveWithInfer):
    """Gradients of Gelu operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize GeluGrad"""

    def infer_shape(self, y_backprop_shape, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
        validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name)
        return x_dtype

class _PoolGrad(PrimitiveWithInfer):
    """Gradients of the max/avg pool operation."""

    @prim_attr_register
    def __init__(self, ksize, strides, padding="VALID"):
        self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
        validator.check_value_type('ksize', ksize, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
        self.add_prim_attr("padding", self.padding)
        self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
        if not self.is_maxpoolgradwithargmax:
            self.add_prim_attr('data_format', "NCHW")

        def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
            validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
            error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be a positive int number "
                                   f"or a tuple of two or four positive int numbers, but got {arg_val}")
            if isinstance(arg_val, int):
                ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val)
            elif len(arg_val) == 2:
                ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1])
            elif len(arg_val) == 4:
                ret = arg_val
            else:
                raise error_msg
            # Check whether all elements of the tuple are positive integers.
            for item in ret:
                if not isinstance(item, int) or item <= 0:
                    raise error_msg
            return ret

        self.ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax)
        self.add_prim_attr("ksize", self.ksize)
        self.strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
        self.add_prim_attr("strides", self.strides)

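# Standalone illustration of the normalization performed by
# _grad_check_int_or_tuple above (same expansion rule for the int and
# 2-tuple cases, minus the validation; 4-tuples pass through unchanged in
# the real check). This helper is hypothetical, not part of the module API.
def _expand_pool_arg_sketch(arg_val, is_argmax):
    h, w = (arg_val, arg_val) if isinstance(arg_val, int) else arg_val
    return (1, h, w, 1) if is_argmax else (1, 1, h, w)
# e.g. 3 -> (1, 1, 3, 3); (2, 3) with is_argmax=True -> (1, 2, 3, 1).
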
class AvgPoolGrad(_PoolGrad):
    """Gradients of the avg pool operation for ge."""

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        super(AvgPoolGrad, self).__init__(ksize, strides, padding)

    def __infer__(self, origin_input, dout):
        out = {
            'value': None,
            'shape': tuple(origin_input['value']),
            'dtype': dout['dtype'],
        }
        return out


class AvgPoolGradVm(_PoolGrad):
    """Gradients of the avg pool operation for vm."""

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        super(AvgPoolGradVm, self).__init__(ksize, strides, padding)
        self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output'])

    def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix):
        out = {
            'value': None,
            'shape': tuple(origin_input['value']),
            'dtype': dout['dtype'],
        }
        return out


class AvgPoolGradGpu(_PoolGrad):
    """Gradients of the avg pool operation for gpu."""

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        super(AvgPoolGradGpu, self).__init__(ksize, strides, padding)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        return x1_dtype

class MaxPoolGrad(_PoolGrad):
    """Performs gradients of the max pool operation."""

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        super(MaxPoolGrad, self).__init__(ksize, strides, padding)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        return x1_dtype


class MaxPoolGradGrad(_PoolGrad):
    r"""
    Performs gradients of the MaxPoolGrad operation.

    Args:
        ksize (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
            is an int number that represents height and width are both ksize, or a tuple
            of two int numbers that represent height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the height and width of movement are both strides, or a tuple of two int numbers that
            represent height and width of movement respectively. Default: 1.
        padding (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
            Default: "valid".

            - same: Adopts the way of completion. The height and width of the output will be the same as
              the input. The total number of padding will be calculated in horizontal and vertical
              directions and evenly distributed to top and bottom, left and right if possible.
              Otherwise, the last extra padding will be done from the bottom and the right side.
            - valid: Adopts the way of discarding. The possible largest height and width of output
              will be returned without padding. Extra pixels will be discarded.

    Inputs:
        - **origin_input** (Tensor) - Tensor with data format "NCHW", data type must be float16.
        - **origin_output** (Tensor) - Data type same as `origin_input`.
        - **grad** (Tensor) - Data type same as `origin_input`.

    Outputs:
        Tensor, with data type same as `origin_input`.
    """

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        super(MaxPoolGradGrad, self).__init__(ksize, strides, padding)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
        validator.check_tensor_type_same(args, [mstype.float16], self.name)
        return x1_dtype

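# Output-size arithmetic behind the "same"/"valid" padding modes described
# in MaxPoolGradGrad's docstring above (TF-style convention, assumed here
# for illustration; this helper is hypothetical):
def _pool_out_size_sketch(size, ksize, stride, padding):
    import math
    if padding == "VALID":
        # Largest output that fits without padding; extra pixels discarded.
        return math.ceil((size - ksize + 1) / stride)
    return math.ceil(size / stride)  # "SAME": input is padded as needed.
# e.g. size=5, ksize=2, stride=2: "VALID" -> 2, "SAME" -> 3.
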
class MaximumGrad(Primitive):
    """Grad for maximum."""

    @prim_attr_register
    def __init__(self, grad_x=True, grad_y=True):
        """Initialize MaximumGrad"""

    def __call__(self, x, y, dout):
        raise NotImplementedError


class MaxPoolGradWithArgmax(_PoolGrad):
    """Computes the gradients of MaxPoolWithArgmax."""

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
        super(MaxPoolGradWithArgmax, self).__init__(ksize, strides, padding)

    def infer_shape(self, x_shape, grad_shape, argmax_shape):
        if not grad_shape:
            raise TypeError("The dout of MaxPoolGradWithArgmax should be a Tensor.")
        return x_shape

    def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
        return grad_dtype


class MaxPoolGradGradWithArgmax(_PoolGrad):
    r"""
    Computes the gradients of MaxPoolGradWithArgmax.

    Args:
        ksize (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
            is an int number that represents height and width are both ksize, or a tuple
            of two int numbers that represent height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the height and width of movement are both strides, or a tuple of two int numbers that
            represent height and width of movement respectively. Default: 1.
        padding (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
            Default: "valid".

            - same: Adopts the way of completion. The height and width of the output will be the same as
              the input. The total number of padding will be calculated in horizontal and vertical
              directions and evenly distributed to top and bottom, left and right if possible.
              Otherwise, the last extra padding will be done from the bottom and the right side.
            - valid: Adopts the way of discarding. The possible largest height and width of output
              will be returned without padding. Extra pixels will be discarded.

    Inputs:
        - **x** (Tensor) - Tensor with data format "NCHW", data type must be float16.
        - **grad** (Tensor) - Data type same as `x`.
        - **argmax** (Tensor) - Data type must be uint16 or int64.

    Outputs:
        Tensor, with data type same as `x`.
    """

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
        super(MaxPoolGradGradWithArgmax, self).__init__(ksize, strides, padding)

    def infer_shape(self, x_shape, grad_shape, argmax_shape):
        if not grad_shape:
            raise TypeError("The dout of MaxPoolGradGradWithArgmax should be a Tensor.")
        return x_shape

    def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
        args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
        validator.check_tensor_type_same(args, [mstype.float16], self.name)
        return grad_dtype


class MinimumGrad(Primitive):
    """Grad for minimum."""

    @prim_attr_register
    def __init__(self, grad_x=True, grad_y=True):
        """Initialize MinimumGrad"""

    def __call__(self, x, y, dout):
        raise NotImplementedError

class L2NormalizeGrad(PrimitiveWithInfer):
    r"""
    Gradients of L2 normalize.

    Args:
        axis (int): The begin axis for the input to apply L2 normalize. Default: 0.
        epsilon (float): A small value added for numerical stability. Default: 1e-4.

    Inputs:
        - **input_x** (Tensor) - Must be the input `weight` of forward operator L2Normalize.
        - **out** (Tensor) - Must be the output of forward operator L2Normalize.
        - **dout** (Tensor) - The backprop of the next layer.

    Outputs:
        Tensor, gradients of L2Normalize `input_x`.
    """

    @prim_attr_register
    def __init__(self, axis=0, epsilon=1e-4):
        validator.check_value_type('axis', axis, [int], self.name)
        validator.check_value_type('epsilon', epsilon, [int, float], self.name)

    def infer_shape(self, input_x, out, dout):
        validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name)
        validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name)
        return input_x

    def infer_dtype(self, input_x, out, dout):
        args = {'input_x': input_x, 'out': out, 'dout': dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return input_x


class LayerNormGrad(Primitive):
    """
    Applies the layer normalization to the input array.

    This operator will calculate the input gradients of layernorm.

    Args:
        begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1.
        begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1.

    Returns:
        tuple[Tensor], tuple of 3 values (the gradients of layernorm input, gamma, beta).
    """

    @prim_attr_register
    def __init__(self, begin_norm_axis=1, begin_params_axis=1):
        """Initialize LayerNormGrad"""
        self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
        self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)

    def __call__(self, x, dy, variance, mean, gamma):
        raise NotImplementedError

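# A NumPy sketch of the three gradients LayerNormGrad describes, for the
# common case of normalization over the last axis with elementwise affine
# parameters. These are the standard layer-norm backward identities, shown
# for illustration; this hypothetical helper is not the device kernel.
def _layer_norm_grad_sketch(dy, x, gamma, eps=1e-7):
    """dy, x: (..., D); gamma: (D,). Returns (dx, dgamma, dbeta)."""
    import numpy as np
    mean = x.mean(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    x_hat = (x - mean) * inv_std
    reduce_axes = tuple(range(x.ndim - 1))
    dgamma = (dy * x_hat).sum(axis=reduce_axes)
    dbeta = dy.sum(axis=reduce_axes)
    dx_hat = dy * gamma
    # dx = (dx_hat - mean(dx_hat) - x_hat * mean(dx_hat * x_hat)) / std
    dx = inv_std * (dx_hat
                    - dx_hat.mean(axis=-1, keepdims=True)
                    - x_hat * (dx_hat * x_hat).mean(axis=-1, keepdims=True))
    return dx, dgamma, dbeta
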
class LogSoftmaxGrad(PrimitiveWithInfer):
    """Computes gradient for the Log Softmax activation."""

    @prim_attr_register
    def __init__(self, axis=-1):
        """Initialize LogSoftmaxGrad"""
        validator.check_value_type("axis", axis, [int], self.name)

    def infer_shape(self, dout, logits):
        rank = len(logits)
        validator.check_int_range(self.axis, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
        return logits

    def infer_dtype(self, dout, logits):
        validator.check_subclass("logits", logits, mstype.tensor, self.name)
        return logits

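# The textbook identity behind the log-softmax gradient (a sketch, assuming
# the second input holds the forward logits; this helper is hypothetical):
# with s = softmax(logits), dx = dout - s * sum(dout, axis, keepdims=True).
def _log_softmax_grad_sketch(dout, logits, axis=-1):
    import numpy as np
    e = np.exp(logits - logits.max(axis=axis, keepdims=True))
    s = e / e.sum(axis=axis, keepdims=True)
    return dout - s * dout.sum(axis=axis, keepdims=True)
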
class LSTMGradData(PrimitiveWithInfer):
    """Computes the data gradients of LSTM."""

    @prim_attr_register
    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

    def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
                    hx_shape, cx_shape, reserve_shape, state_shape):
        # dhy and dcy should have the same shape
        validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
        validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
        validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
        validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
        validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
        validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
        validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
        # dy: (seq_len, batch_size, hidden_size * num_directions)
        validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
        validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
        validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
        # dx: (seq_len, batch_size, input_size)
        dx_shape = (y_shape[0], y_shape[1], self.input_size)
        dhx_shape = dhy_shape
        dcx_shape = dcy_shape
        return (dx_shape, dhx_shape, dcx_shape)

    def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
                    hx_dtype, cx_dtype, reserve_dtype, state_dtype):
        args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
        validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
        return (dy_dtype, dy_dtype, dy_dtype)

class LSTMGradWeight(PrimitiveWithInfer):
    """Computes the weight gradients of LSTM."""

    @prim_attr_register
    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

    def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape):
        weight_size = 0
        gate_size = 4 * self.hidden_size
        for layer in range(self.num_layers):
            for _ in range(self.num_directions):
                input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
                weight_size += gate_size * input_layer_size
                weight_size += gate_size * self.hidden_size
                if self.has_bias:
                    weight_size += 2 * gate_size
        return (weight_size, 1, 1)

    def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype):
        return hx_dtype

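# Worked example of the weight-size formula in LSTMGradWeight.infer_shape
# above: a single-layer, unidirectional LSTM with input_size=10,
# hidden_size=20 and has_bias=True has gate_size = 4 * 20 = 80, so
#   weight_size = 80 * 10   (input-to-hidden weights)
#               + 80 * 20   (hidden-to-hidden weights)
#               + 2 * 80    (two bias vectors)
#               = 2560,
# and the op reports the flat shape (2560, 1, 1).
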
class LSTMGrad(PrimitiveWithInfer):
    """Computes the data and weight gradients of LSTM."""

    @prim_attr_register
    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

    def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
                    dcy_shape, reserve_shape):
        # dhy and dcy should have the same shape
        validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
        validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
        validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
        validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
        validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
        validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
        validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
        # dy: (seq_len, batch_size, hidden_size * num_directions)
        validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
        validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
        validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
        # dx: (seq_len, batch_size, input_size)
        dx_shape = (y_shape[0], y_shape[1], self.input_size)
        dhx_shape = dhy_shape
        dcx_shape = dcy_shape
        weight_size = 0
        gate_size = 4 * self.hidden_size
        for layer in range(self.num_layers):
            for _ in range(self.num_directions):
                input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
                weight_size += gate_size * input_layer_size
                weight_size += gate_size * self.hidden_size
                if self.has_bias:
                    weight_size += gate_size
        return (dx_shape, dhx_shape, dcx_shape, (weight_size, 1, 1))

    def infer_dtype(self, x_dtype, hx_dtype, cx_dtype, w_dtype, y_dtype, hy_dtype, cy_dtype, dy_dtype, dhy_dtype,
                    dcy_dtype, reserve_dtype):
        return (dy_dtype, dy_dtype, dy_dtype, hx_dtype)

class DynamicRNNGrad(PrimitiveWithInfer):
    """Computes the input gradients of DynamicRNN."""

    @prim_attr_register
    def __init__(self,
                 cell_type='LSTM',
                 direction='UNIDIRECTIONAL',
                 cell_depth=0,
                 use_peephole=False,
                 keep_prob=-1.0,
                 cell_clip=-1.0,
                 num_proj=0,
                 time_major=True,
                 forget_bias=0.0):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
                    c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
        validator.check_equal_int(len(x_shape), 3, "x_shape", self.name)
        num_step, batch_size, input_size = x_shape
        hidden_size = w_shape[-1] // 4
        if w_shape[-1] % 4 != 0:
            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
                        input_size + hidden_size, Rel.EQ, self.name)
        valid_shape = [num_step, batch_size, hidden_size]
        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check("y_shape", y_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("h_shape", h_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("c_shape", c_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("i_shape", i_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("j_shape", j_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("f_shape", f_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("o_shape", o_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("tanhc_shape", tanhc_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("dy_shape", dy_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("dh_shape", dh_shape, "expected shape", [batch_size, hidden_size], Rel.EQ, self.name)
        validator.check("dc_shape", dc_shape, "expected shape", [batch_size, hidden_size], Rel.EQ, self.name)
        return w_shape, (w_shape[1],), x_shape, dh_shape, dc_shape

    def infer_dtype(self, x_dtype, w_dtype, b_dtype, y_dtype, init_h_dtype, init_c_dtype, h_dtype,
                    c_dtype, dy_dtype, dh_dtype, dc_dtype, i_dtype, j_dtype, f_dtype, o_dtype, tanhc_dtype):
        return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype

class PReLUGrad(PrimitiveWithInfer):
    r"""
    Gradients of PReLU operation.

    Note:
        1-dimensional input_x is not supported.

    Inputs:
        - **y_backprop** (Tensor) - Representing the backprop of the next layer.
        - **input_x** (Tensor) - Must be the input `input_x` of forward operator PRelu.
        - **weight** (Tensor) - Float Tensor, w > 0, must be the input `weight` of forward operator PRelu.

    Outputs:
        Tensor, with the same type as `input_x`.
    """

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, y_backprop_shape, A_shape, w_shape):
        if len(A_shape) == 1:
            raise ValueError(f'For \'{self.name}\' input_x rank 1 is not supported.')
        return y_backprop_shape, w_shape

    def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name)
        validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name)
        validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name)
        return y_backprop_dtype, w_dtype

class ReluGrad(Primitive):
    """Performs grad of Relu operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize ReluGrad"""
        self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])

    def __call__(self, y_backprop, x):
        raise NotImplementedError


class ReLU6Grad(PrimitiveWithInfer):
    """Performs grad of ReLU6 operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])

    def __call__(self, y_grad, x):
        raise NotImplementedError

    def infer_shape(self, y_grad_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
        return x_dtype


class ReluGradV2(PrimitiveWithInfer):
    """Performs grad of ReLUV2 operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])

    def __call__(self, gradients, mask):
        raise NotImplementedError

    def infer_shape(self, gradients_shape, mask_shape):
        return gradients_shape

    def infer_dtype(self, gradients_dtype, mask_dtype):
        validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name)
        validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name)
        return gradients_dtype


class EluGrad(PrimitiveWithInfer):
    """Performs grad of Elu operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize EluGrad"""

    def infer_shape(self, y_grad_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
        args = {'y_grad': y_grad_dtype, 'x': x_dtype}
        validator.check_tensor_type_same(args, mstype.float_type, self.name)
        return x_dtype

class ResizeBilinearGrad(PrimitiveWithInfer):
    """Performs grad of ResizeBilinear operation."""

    @prim_attr_register
    def __init__(self, align_corners=False):
        """Initialize ResizeBilinearGrad"""

    def infer_shape(self, dout_shape, orig_shape):
        return orig_shape

    def infer_dtype(self, dout_dtype, orig_type):
        return orig_type

class ResizeNearestNeighborGrad(PrimitiveWithInfer):
    """
    Compute gradient of `ResizeNearestNeighbor` operator.

    Note:
        The shape of input parameter `size` must be (height, width).

    Args:
        align_corners (bool): Whether the centers of the 4 corner pixels of the input
            and output tensors are aligned. Default: False.
    """

    @prim_attr_register
    def __init__(self, align_corners=False):
        """Initialize ResizeNearestNeighborGrad"""
        self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])

    def __infer__(self, grads, size):
        shp = (grads['shape'][0],) + (grads['shape'][1],) + size['value']
        return {'shape': shp,
                'dtype': grads['dtype'],
                'value': None}

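# Shape example for the inference above (assumed values): grads of shape
# (1, 3, 4, 4) with size=(8, 8) yields output shape (1, 3, 8, 8); batch and
# channel dims pass through, the spatial dims come from `size`.
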
class ROIAlignGrad(PrimitiveWithInfer):
    """
    ROIAlignGrad operator.

    Args:
        xdiff_shape (tuple): The shape of the gradient with respect to the input features.
        pooled_height (int): The output feature height.
        pooled_width (int): The output feature width.
        spatial_scale (float): The feature stride.
        sample_num (int): Number of sampling points. Default: 2.
    """

    @prim_attr_register
    def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2):
        """Initialize ROIAlignGrad"""
        validator.check_value_type("pooled_height", pooled_height, [int], self.name)
        validator.check_value_type("pooled_width", pooled_width, [int], self.name)
        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
        validator.check_value_type("sample_num", sample_num, [int], self.name)
        validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name)
        self.xdiff_shape = xdiff_shape
        self.pooled_height = pooled_height
        self.pooled_width = pooled_width
        self.spatial_scale = spatial_scale
        self.sample_num = sample_num

    def infer_shape(self, ydiff_shape, rois_shape):
        return self.xdiff_shape

    def infer_dtype(self, ydiff_type, rois_type):
        return ydiff_type

class SigmoidGrad(PrimitiveWithInfer):
    """Gets the gradient of Sigmoid operation."""

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, out, dout):
        return out

    def infer_dtype(self, out, dout):
        args = {'out': out, 'dout': dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return out

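# Reference identity for the sigmoid gradient (standard calculus, in terms
# of the forward output; this helper is hypothetical): with y = sigmoid(x),
# dx = dout * y * (1 - y).
def _sigmoid_grad_sketch(out, dout):
    return dout * out * (1.0 - out)
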
class HSigmoidGrad(PrimitiveWithInfer):
    """Gets the gradient of HSigmoid operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])

    def infer_shape(self, y_grad_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
        return x_dtype


class HSwishGrad(PrimitiveWithInfer):
    """Gets the gradient of HSwish operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])

    def infer_shape(self, y_grad_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_grad_dtype, x_dtype):
        validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
        return x_dtype



class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
    """Computes the gradients of `SigmoidCrossEntropyWithLogits`."""

    @prim_attr_register
    def __init__(self):
        """Initialize SigmoidCrossEntropyWithLogitsGrad"""
        self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])

    def infer_shape(self, x_shape, y_shape, dout_shape):
        validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
        validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
        args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return dout_dtype
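
# Math sketch (standard result for sigmoid cross entropy, stated for reference):
# with logits x and labels y, dLoss/dx = sigmoid(x) - y, so the backward pass is
# x_grad = (sigmoid(x) - y) * dout, which matches the shape checks above.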


class SliceGrad(PrimitiveWithInfer):
    """Reverse of the Slice operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize SliceGrad"""
        self.init_prim_io_names(inputs=['dy', 'x', 'begin', 'size'], outputs=['dx'])

    def __infer__(self, dy, x, begin, size):
        dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value']
        dy_shape_len = len(dy_shape)
        for i in range(dy_shape_len):
            validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name)
            validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name)
        return {'shape': x_shape,
                'dtype': x['dtype'],
                'value': None}
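
# Behavior sketch (inferred from the shape checks above, not stated in the source):
# the gradient of Slice scatters `dy` back into a zeros tensor shaped like `x` at
# offset `begin`, so `dx` takes x's shape while `dy` must match the sliced `size`:
#     slice_grad = SliceGrad()
#     dx = slice_grad(dy, x, (0, 1), (2, 2))    # hypothetical begin/size values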


class SmoothL1LossGrad(PrimitiveWithInfer):
    """
    Computes the gradient with respect to the prediction for SmoothL1Loss.

    Args:
        beta (float): The point at which the loss changes from quadratic to linear. Default: 1.0.
    """

    @prim_attr_register
    def __init__(self, beta=1.0):
        pass

    def infer_shape(self, prediction, target, dloss):
        validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
        validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name)
        return prediction

    def infer_dtype(self, prediction, target, dloss):
        args = {"prediction": prediction, "target": target, 'dloss': dloss}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return dloss
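
# Math sketch (the standard SmoothL1 derivative with threshold beta, for reference):
# d/dp SmoothL1(p, t) = (p - t) / beta  when |p - t| < beta, else sign(p - t);
# the primitive multiplies this by `dloss` element-wise.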


class StridedSliceGrad(PrimitiveWithInfer):
    """
    Performs the gradient of the StridedSlice operation.

    Args:
        begin_mask (int): An int32 bit mask; a set bit `i` means `begin[i]` is ignored
            and the fullest possible range is used for that axis. Default: 0.
        end_mask (int): An int32 bit mask; a set bit `i` means `end[i]` is ignored. Default: 0.
        ellipsis_mask (int): An int32 bit mask; a set bit `i` means position `i` acts
            as an ellipsis, absorbing the unspecified axes. Default: 0.
        new_axis_mask (int): An int32 bit mask; a set bit `i` means a new axis was
            inserted at position `i` in the forward pass. Default: 0.
        shrink_axis_mask (int): An int32 bit mask; a set bit `i` means axis `i` was
            shrunk to a scalar index in the forward pass. Default: 0.

    Returns:
        Tensor, with the shape of the forward input, as given by `shapex`.
    """

    @prim_attr_register
    def __init__(self,
                 begin_mask=0,
                 end_mask=0,
                 ellipsis_mask=0,
                 new_axis_mask=0,
                 shrink_axis_mask=0):
        """Initialize StridedSliceGrad"""
        validator.check_value_type('begin_mask', begin_mask, [int], self.name)
        validator.check_value_type('end_mask', end_mask, [int], self.name)
        validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
        validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
        validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
        self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])

    def __infer__(self, dy, shapex, begin, end, strides):
        args = {"dy": dy['dtype']}
        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)

        for idx, item in enumerate(shapex['value']):
            validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
        for idx, item in enumerate(begin['value']):
            validator.check_value_type("begin[%d]" % idx, item, [int], self.name)
        for idx, item in enumerate(end['value']):
            validator.check_value_type("end[%d]" % idx, item, [int], self.name)
        for idx, item in enumerate(strides['value']):
            validator.check_value_type("strides[%d]" % idx, item, [int], self.name)

        return {'shape': shapex['value'],
                'dtype': dy['dtype'],
                'value': None}
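
# Usage sketch (hypothetical values, not from the original source): `shapex` is the
# shape of the forward input, and the incoming gradient is scattered back into it:
#     strided_slice_grad = StridedSliceGrad()
#     dx = strided_slice_grad(dy, (4, 8), (0, 0), (2, 4), (1, 1))    # dx: (4, 8)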


class SoftplusGrad(PrimitiveWithInfer):
    """Computes the gradient for the Softplus activation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['dout', 'x'], outputs=['output'])

    def infer_shape(self, dout_shape, x_shape):
        validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, dout_dtype, x_dtype):
        args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
        validator.check_tensor_type_same(args, mstype.float_type, self.name)
        return x_dtype
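
# Math sketch (the standard softplus derivative, stated for reference):
# softplus(x) = log(1 + exp(x)) has derivative sigmoid(x), so the backward pass is
# output = dout * sigmoid(x).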


class TanhGrad(PrimitiveWithInfer):
    """Computes gradient of hyperbolic tangent of input element-wise."""

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, out, dout):
        return out

    def infer_dtype(self, out, dout):
        args = {"out": out, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return out
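
# Math sketch (the standard tanh derivative, stated for reference): with the forward
# output out = tanh(x), the backward pass is dx = dout * (1 - out * out):
#     tanh_grad = TanhGrad()
#     dx = tanh_grad(out, dout)    # assumed call pattern, inputs are (out, dout)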


class MirrorPadGrad(PrimitiveWithInfer):
    """Gradients of MirrorPad operation."""

    @prim_attr_register
    def __init__(self, mode="REFLECT"):
        """Initialize MirrorPadGrad"""
        validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
        self.mode = mode

    def __infer__(self, dout, paddings):
        validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name)
        validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
        validator.check("paddings rank", len(paddings['shape']), "expected", 2, Rel.EQ, self.name)
        validator.check("paddings dim_1", paddings['shape'][1], "expected", 2, Rel.EQ, self.name)

        if paddings['value'] is None:
            raise ValueError(f"For {self.name}, paddings must be const.")
        paddings_value = paddings['value'].asnumpy()
        y_shape = ()
        dout_shape = dout['shape']
        for i, val in enumerate(dout_shape):
            y_shape += (val - paddings_value[i][0] - paddings_value[i][1],)
        return {'shape': y_shape,
                'dtype': dout['dtype'],
                'value': None}
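
# Shape sketch (illustrative): if `dout` has shape (1, 1, 6, 6) and `paddings` holds
# [[0, 0], [0, 0], [1, 1], [1, 1]], __infer__ returns (1, 1, 4, 4) -- each output dim
# is the dout dim minus the leading and trailing pad amounts for that axis.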


class EmbeddingLookupCommGrad(PrimitiveWithInfer):
    """
    Performs the gradient for the communication part of the EmbeddingLookup operator.

    This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
    this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
        self.add_prim_attr('primitive_target', 'CPU')

    def __infer__(self, dy, split_num):
        """
        This primitive is implemented by three steps:
            1) Splits `dy` along dimension 0 into 'split_num' parts.
            2) For each part, performs _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
            3) After _HostAllGather, there are still 'split_num' parts in each process; performs Concat
               on them along dimension 0.

        The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
        """
        dy_shape = tuple(dy['shape'])
        split_num_value = split_num['value']
        validator.check_value_type("split_num_value", split_num_value, [int], self.name)
        dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
        return {'shape': dy_shape_all,
                'dtype': dy['dtype'],
                'value': None}
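
# Shape sketch (follows from the docstring above): with the fixed 8-device
# _HostAllGather group, a `dy` of shape (1024, 80) infers an output of shape
# (8192, 80) -- only dimension 0 is scaled, and only by the constant factor 8.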


class RefToEmbed(Primitive):
    r"""
    Makes a key from a Ref.

    The key is a symbolic_key, an embedding of the Parameter, which is used as the key of
    the variable in env_type; items can be fetched by the operation `env_get_item` with the
    symbolic_key instance. The `Parameter` is a ref.

    Inputs:
        - **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.

    Outputs:
        symbolic_key, made from the Ref.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.weight = mindspore.Parameter(1.0, name='weight')
        >>>
        >>>     def construct(self):
        >>>         key = RefToEmbed()(self.weight)
        >>>         return key, self.weight
    """
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_REF),
    )

    @prim_attr_register
    def __init__(self):
        pass


class AtanGrad(PrimitiveWithInfer):
    """
    Computes the gradient of Atan of the input element-wise.

    Returns:
        Tensor, has the same type as the input.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize AtanGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x
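
# Math sketch (the standard arctangent derivative, stated for reference):
# d/dx atan(x) = 1 / (1 + x^2), so the backward pass is dx = dout / (1 + x * x).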


class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
    """Computes the state gradients of BasicLSTMCell."""

    @prim_attr_register
    def __init__(self, forget_bias, activation):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
        # every gate/state gradient must have the same shape as c
        validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
        validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("dht shape", dht_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("dct shape", dct_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("it shape", it_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("jt shape", jt_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("ft shape", ft_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("ot shape", ot_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, Rel.EQ, self.name)

        dgate_shape = (c_shape[0], 4 * c_shape[1])
        dct_1_shape = c_shape
        return (dgate_shape, dct_1_shape)

    def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
        validator.check_subclass("dht", dht_dtype, [mstype.tensor], self.name)
        validator.check_subclass("dct", dct_dtype, [mstype.tensor], self.name)
        validator.check_subclass("it", it_dtype, [mstype.tensor], self.name)
        validator.check_subclass("jt", jt_dtype, [mstype.tensor], self.name)
        validator.check_subclass("ft", ft_dtype, [mstype.tensor], self.name)
        validator.check_subclass("ot", ot_dtype, [mstype.tensor], self.name)
        validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor], self.name)
        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
        return (c_dtype, c_dtype)
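
# Shape sketch (follows from infer_shape above): with c of shape (batch, hidden) =
# (16, 128), the returned dgate has shape (16, 512) -- the four gate gradients
# (i, j, f, o) packed along dimension 1 -- and dct_1 keeps c's shape, (16, 128).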


class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
    """Computes the weight gradients of BasicLSTM."""

    @prim_attr_register
    def __init__(self):
        self.add_prim_attr("io_format", "HWCN")

    def infer_shape(self, x_shape, h_shape, dgate_shape):
        validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
        validator.check("h rank", len(h_shape), "x rank", len(x_shape), Rel.EQ, self.name)
        validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
        validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
        validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
        validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)

        input_size = x_shape[1]
        hidden_size = h_shape[1]
        dw_shape = (input_size + hidden_size, 4 * hidden_size)
        db_shape = (4 * hidden_size,)
        return (dw_shape, db_shape)

    def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
        validator.check_subclass("h", h_dtype, mstype.tensor, self.name)
        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
        return (x_dtype, x_dtype)
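
# Shape sketch (follows from infer_shape above): with x of shape (batch, input_size) =
# (16, 64) and h of shape (16, 128), dgate must be (16, 512); the outputs are then
# dw of shape (64 + 128, 512) = (192, 512) and db of shape (512,).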


class BasicLSTMCellInputGrad(PrimitiveWithInfer):
    """Computes the input gradients of BasicLSTM."""

    @prim_attr_register
    def __init__(self, keep_prob):
        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
        self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, dgate_shape, w_shape):
        validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
        validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)

        batch_size = dgate_shape[0]
        hidden_size = dgate_shape[1] // 4
        input_size = w_shape[0] - hidden_size
        dxt_shape = (batch_size, input_size)
        dht_shape = (batch_size, hidden_size)
        return (dxt_shape, dht_shape)

    def infer_dtype(self, dgate_dtype, w_dtype):
        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
        validator.check_subclass("w", w_dtype, mstype.tensor, self.name)
        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
        return (dgate_dtype, dgate_dtype)
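
# Shape sketch (inverts the weight layout above): with dgate of shape (16, 512) and
# w of shape (192, 512), hidden_size = 512 // 4 = 128 and input_size = 192 - 128 = 64,
# giving dxt of shape (16, 64) and dht of shape (16, 128).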


class InvGrad(PrimitiveWithInfer):
    """Computes gradients for the Inv operation."""

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, x, grad):
        validator.check("x_shape", x, "grad_shape", grad, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, grad):
        validator.check_type_name("x", x, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
        validator.check_type_name("grad", grad, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
        return x
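
# Math sketch (the standard reciprocal derivative, for reference; assumes the first
# input carries the forward output y = 1 / x, as is conventional for Inv-style grads):
# dy/dx = -1 / x^2 = -y * y, so the backward pass computes dx = -grad * y * y.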


class LRNGrad(PrimitiveWithInfer):
    """Computes gradients for LRN operation."""

    @prim_attr_register
    def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
        self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z'])
        validator.check_value_type("depth_radius", depth_radius, [int], self.name)
        validator.check_value_type("bias", bias, [float], self.name)
        validator.check_value_type("alpha", alpha, [float], self.name)
        validator.check_value_type("beta", beta, [float], self.name)

    def infer_dtype(self, grads, x, y):
        args = {"grads": grads, "x": x, "y": y}
        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name)
        return x

    def infer_shape(self, grads, x, y):
        return x
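
# Usage sketch (hypothetical parameter values, not from the original source): LRNGrad
# consumes the incoming gradient, the forward input x, and the forward output y, and
# returns a gradient shaped like x:
#     lrn_grad = LRNGrad(depth_radius=5, bias=1.0, alpha=1e-4, beta=0.75)
#     dz = lrn_grad(grads, x, y)    # dz has the shape and dtype of x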