You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

grad_math_ops.py 36 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the grad rules of math related operations."""
  16. from functools import reduce
  17. import numpy as np
  18. import mindspore as ms
  19. from mindspore import nn
  20. from .. import functional as F
  21. from .. import operations as P
  22. from ..operations import _grad_ops as G
  23. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  24. from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
  25. from .grad_base import bprop_getters
  26. from ..primitive import constexpr
  27. from ..composite.multitype_ops import _constexpr_utils as const_utils
  28. from ..operations._inner_ops import DynamicStitch, DynamicBroadcastGradientArgs
  29. from ...common import Tensor
  30. from .._utils.utils import is_shape_unknown
# Reusable primitive instances shared by the grad rules in this module.
shape_op = P.Shape()
dyn_shape_op = P.DynamicShape()
reduce_prod = P.ReduceProd()
reduce_sum = P.ReduceSum()
reshape = P.Reshape()
tile = P.Tile()
is_sub_class = P.IsSubClass()
to_array = P.TupleToArray()
real_div = P.RealDiv()
  40. def dyn_binop_grad_common(x, y, dx, dy):
  41. """
  42. Common grad definition for binary operations when the input is dynamic shape.
  43. The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
  44. """
  45. shape_of_x = dyn_shape_op(x)
  46. shape_of_y = dyn_shape_op(y)
  47. rx, ry = DynamicBroadcastGradientArgs()(shape_of_x, shape_of_y)
  48. dx = reduce_sum(dx, rx)
  49. dy = reduce_sum(dy, ry)
  50. reduce_dx = reshape(dx, shape_of_x)
  51. reduce_dy = reshape(dy, shape_of_y)
  52. return reduce_dx, reduce_dy
def binop_grad_common(x, y, dx, dy):
    """
    Common grad definition for binary operations.

    The function is usually used in backprop op to reduce additional dimensions
    created by broadcasting: each incoming gradient is summed over the axes its
    input was broadcast along, then reshaped back to that input's shape.
    """
    shape_of_x = shape_op(x)
    shape_of_y = shape_op(y)
    # if input shape is the same as dout shape, do not need to reduce
    reduce_dx = dx
    reduce_dy = dy
    if not (is_shape_unknown(shape_of_x) or is_shape_unknown(shape_of_y)):
        # Static-shape path: rx[0]/rx[1] are the broadcast axes of x and y.
        rx = broadcast_gradient_args(shape_of_x, shape_of_y)
        if rx[0]:
            # if dx is scalar whose shape is (), do not need reduce
            if shape_op(dx):
                dx = reduce_sum(dx, rx[0])
            reduce_dx = reshape(dx, shape_of_x)
        if rx[1]:
            # if dy is scalar whose shape is (), do not need reduce
            if shape_op(dy):
                dy = reduce_sum(dy, rx[1])
            reduce_dy = reshape(dy, shape_of_y)
        return reduce_dx, reduce_dy
    if not shape_of_x or not shape_of_y:
        # x or y is scalar: sum the corresponding gradient over all axes.
        if not shape_of_x:
            reduce_dx = reduce_sum(dx, ())
        if not shape_of_y:
            reduce_dy = reduce_sum(dy, ())
        return reduce_dx, reduce_dy
    # At least one shape has unknown dims: fall back to the dynamic-shape path.
    return dyn_binop_grad_common(x, y, dx, dy)
def _dyn_reduced_shape(input_shape, axis):
    """Dynamic reduce shape: `input_shape` with every reduced axis set to 1."""
    input_shape = P.Cast()(input_shape, ms.int32)
    if isinstance(axis, Tensor):
        # NOTE(review): `input_shape` here is the 1-D shape tensor, so Rank()
        # returns 1 rather than the rank of x — confirm this normalization is
        # intentional for the tensor-axis case.
        input_rank = P.Rank()(input_shape)
        real_axis = (axis + input_rank) % input_rank
        axis_shape = shape_op(real_axis)
    else:
        # Static axis: normalize negatives and expand an empty axis to all dims.
        real_axis = ()
        input_rank = len(input_shape)
        if isinstance(axis, int):
            axis = (axis,)
        elif not axis:
            axis = range(input_rank)
        for i in axis:
            real_axis += ((i + input_rank)%input_rank,)
        axis_shape = (len(real_axis),)
    # Overwrite the reduced positions of the shape with ones via DynamicStitch.
    return DynamicStitch()([to_array(range(input_rank)), to_array(axis)],
                           [input_shape, P.Fill()(ms.int32, axis_shape, 1)])
  103. def _sum_grad(x, axis, dout):
  104. """Grad definition for `Sum` operation."""
  105. input_shape = shape_op(x)
  106. if is_shape_unknown(input_shape):
  107. input_shape = dyn_shape_op(x)
  108. output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis)
  109. tile_scaling = real_div(input_shape, output_shape_kept_dims)
  110. else:
  111. output_shape_kept_dims = reduced_shape(input_shape, axis)
  112. tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
  113. grad = reshape(dout, output_shape_kept_dims)
  114. return tile(grad, tile_scaling)
  115. def _min_or_max_grad(x, axis, out, dout):
  116. """Grad definition for `Min` and `Max` operations."""
  117. input_shape = shape_op(x)
  118. output_shape_kept_dims = reduced_shape(input_shape, axis)
  119. y = reshape(out, output_shape_kept_dims)
  120. grad = reshape(dout, output_shape_kept_dims)
  121. indicators = F.cast(F.equal(y, x), F.dtype(grad))
  122. min_num = F.cast(F.scalar_to_array(1e-24), F.dtype(grad))
  123. num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
  124. return indicators / num_selected * grad
def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
    """ArgMinWithValue and ArgMaxWithValue grad: scatter the value gradient to the selected index."""
    expand = P.ExpandDims()
    x_shape = F.shape(x)
    x_dim = len(x_shape)
    x_axis = axis
    # Normalize a negative axis to its positive equivalent.
    if x_axis < 0:
        x_axis = axis + x_dim
    onehot_axis = x_axis
    depth = x_shape[x_axis]
    if keep_dims:
        dout_expand = dout[1]
        # Recompute the forward output without kept dims so indices align below.
        out = op(x)
    else:
        dout_expand = expand(dout[1], onehot_axis)
    # If the axis falls past the index tensor's rank, one-hot on the last axis.
    if onehot_axis >= len(shape_op(out[0])):
        onehot_axis = -1
    onehot = P.OneHot(onehot_axis)
    type_x = F.dtype(x)
    on_value = F.cast(F.scalar_to_array(1.0), type_x)
    off_value = F.cast(F.scalar_to_array(0.0), type_x)
    # out[0] holds the argmin/argmax indices; route the value-grad through them.
    dx = dout_expand * onehot(out[0], depth, on_value, off_value)
    return dx
  148. @bprop_getters.register(P.MatMul)
  149. def bprop_matmul(self):
  150. """Grad definition for `MatMul` operation."""
  151. ta = self.transpose_a
  152. tb = self.transpose_b
  153. mul1 = P.MatMul(transpose_a=(ta and tb),
  154. transpose_b=(ta or (not tb)))
  155. mul2 = P.MatMul(transpose_a=((not ta) or tb),
  156. transpose_b=(ta and tb))
  157. def bprop(x, w, out, dout):
  158. if ta:
  159. dx = mul1(w, dout)
  160. else:
  161. dx = mul1(dout, w)
  162. if tb:
  163. dw = mul2(dout, x)
  164. else:
  165. dw = mul2(x, dout)
  166. return dx, dw
  167. return bprop
  168. @bprop_getters.register(P.BatchMatMul)
  169. def bprop_batchmatmul(self):
  170. """Grad definition for `BatchMatMul` operation."""
  171. ta = self.transpose_a
  172. tb = self.transpose_b
  173. mul1 = P.BatchMatMul(transpose_a=(ta and tb),
  174. transpose_b=(ta or (not tb)))
  175. mul2 = P.BatchMatMul(transpose_a=((not ta) or tb),
  176. transpose_b=(ta and tb))
  177. def bprop(x, w, out, dout):
  178. if ta:
  179. dx = mul1(w, dout)
  180. else:
  181. dx = mul1(dout, w)
  182. if tb:
  183. dw = mul2(dout, x)
  184. else:
  185. dw = mul2(x, dout)
  186. return dx, dw
  187. return bprop
  188. @bprop_getters.register(P.Add)
  189. def get_bprop_add(self):
  190. """Grad definition for `Add` operation."""
  191. def bprop(x, y, out, dout):
  192. return binop_grad_common(x, y, dout, dout)
  193. return bprop
@bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self):
    """Grad definition for `TensorAdd` operation (legacy alias of `Add`)."""
    def bprop(x, y, out, dout):
        # Same rule as Add: dout flows to both inputs, reduced over broadcast axes.
        return binop_grad_common(x, y, dout, dout)
    return bprop
  200. @bprop_getters.register(P.MatrixInverse)
  201. def get_bprop_matrix_inverse(self):
  202. """Grad definition for `MatrixInverse` operation."""
  203. matmul_x1 = nn.MatMul(transpose_x1=True)
  204. matmul_x2 = nn.MatMul(transpose_x2=True)
  205. neg = P.Neg()
  206. def bprop(x, out, dout):
  207. dx = matmul_x2(dout, out)
  208. dx = matmul_x1(out, dx)
  209. dx = neg(dx)
  210. return (dx,)
  211. return bprop
  212. @bprop_getters.register(P.Neg)
  213. def get_bprop_neg(self):
  214. """Grad definition for `Neg` operation."""
  215. neg_grad = P.Neg()
  216. def bprop(x, out, dout):
  217. dx = neg_grad(dout)
  218. return (dx,)
  219. return bprop
  220. @bprop_getters.register(P.Sub)
  221. def get_bprop_sub(self):
  222. """Grad definition for `Sub` operation."""
  223. neg_func = P.Neg()
  224. def bprop(x, y, out, dout):
  225. return binop_grad_common(x, y, dout, neg_func(dout))
  226. return bprop
  227. @bprop_getters.register(P.Mul)
  228. def get_bprop_mul(self):
  229. """Grad definition for `Mul` operation."""
  230. mul_func = P.Mul()
  231. def bprop(x, y, out, dout):
  232. bc_dx = mul_func(y, dout)
  233. bc_dy = mul_func(x, dout)
  234. return binop_grad_common(x, y, bc_dx, bc_dy)
  235. return bprop
  236. @bprop_getters.register(P.RealDiv)
  237. def get_bprop_real_div(self):
  238. """Grad definition for `RealDiv` operation."""
  239. div_op = P.RealDiv()
  240. neg = P.Neg()
  241. mul_op = P.Mul()
  242. def bprop(x, y, out, dout):
  243. bc_x = div_op(dout, y)
  244. bc_y = neg(mul_op(bc_x, out))
  245. return binop_grad_common(x, y, bc_x, bc_y)
  246. return bprop
  247. @bprop_getters.register(P.Div)
  248. def get_bprop_div(self):
  249. """Grad definition for `Div` operation."""
  250. div_op = P.Div()
  251. neg = P.Neg()
  252. mul_op = P.Mul()
  253. def bprop(x, y, out, dout):
  254. bc_x = div_op(dout, y)
  255. bc_y = neg(mul_op(bc_x, out))
  256. return binop_grad_common(x, y, bc_x, bc_y)
  257. return bprop
  258. @bprop_getters.register(P.DivNoNan)
  259. def get_bprop_div_no_nan(self):
  260. """Grad definition for `DivNoNan` operation."""
  261. div_no_nan_op = P.DivNoNan()
  262. neg = P.Neg()
  263. mul_op = P.Mul()
  264. def bprop(x, y, out, dout):
  265. bc_x = div_no_nan_op(dout, y)
  266. bc_y = neg(mul_op(bc_x, out))
  267. return binop_grad_common(x, y, bc_x, bc_y)
  268. return bprop
  269. @bprop_getters.register(P.Xdivy)
  270. def get_bprop_xdivy(self):
  271. """Grad definition for `Xdivy` operation."""
  272. div_op = P.Xdivy()
  273. def bprop(x, y, out, dout):
  274. x_dtype = F.dtype(x)
  275. not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
  276. bc_x = div_op(not_zero_x, y) * dout
  277. bc_y = div_op(-x, F.square(y)) * dout
  278. return binop_grad_common(x, y, bc_x, bc_y)
  279. return bprop
  280. @bprop_getters.register(P.Floor)
  281. def get_bprop_floor(self):
  282. """Grad definition for `floor` operation."""
  283. fill_ = P.Fill()
  284. shape_ = P.Shape()
  285. dtype_ = P.DType()
  286. def bprop(x, out, dout):
  287. bc_x = fill_(dtype_(x), shape_(x), 0.)
  288. return (bc_x,)
  289. return bprop
  290. @bprop_getters.register(P.Ceil)
  291. def get_bprop_ceil(self):
  292. """Grad definition for `ceil` operation."""
  293. fill_ = P.Fill()
  294. shape_ = P.Shape()
  295. dtype_ = P.DType()
  296. def bprop(x, out, dout):
  297. bc_x = fill_(dtype_(x), shape_(x), 0.)
  298. return (bc_x,)
  299. return bprop
  300. @bprop_getters.register(P.FloorDiv)
  301. def get_bprop_floordiv(self):
  302. """Grad definition for `FloorDiv` operation."""
  303. def bprop(x, y, out, dout):
  304. return zeros_like(x), zeros_like(y)
  305. return bprop
  306. @bprop_getters.register(P.FloorMod)
  307. def get_bprop_floormod(self):
  308. """Grad definition for `FloorMod` operation."""
  309. def bprop(x, y, out, dout):
  310. bc_x = dout
  311. bc_y = -dout * (x // y)
  312. return binop_grad_common(x, y, bc_x, bc_y)
  313. return bprop
  314. @bprop_getters.register(P.TruncateDiv)
  315. def get_bprop_truncate_div(self):
  316. """Grad definition for `TruncateDiv` operation."""
  317. def bprop(x, y, out, dout):
  318. return zeros_like(x), zeros_like(y)
  319. return bprop
  320. @bprop_getters.register(P.TruncateMod)
  321. def get_bprop_truncate_mod(self):
  322. """Grad definition for `TruncateMod` operation."""
  323. div_op = P.TruncateDiv()
  324. def bprop(x, y, out, dout):
  325. bc_x = dout
  326. bc_y = -dout * div_op(x, y)
  327. return binop_grad_common(x, y, bc_x, bc_y)
  328. return bprop
  329. @bprop_getters.register(P.Mod)
  330. def get_bprop_mod(self):
  331. """Grad definition for `Mod` operation."""
  332. def bprop(x, y, out, dout):
  333. bc_x = dout
  334. bc_y = -dout * (x // y)
  335. return binop_grad_common(x, y, bc_x, bc_y)
  336. return bprop
  337. @bprop_getters.register(P.Square)
  338. def get_bprop_square(self):
  339. """Grad definition for `Square` operation."""
  340. mul_func = P.Mul()
  341. fill_func = P.Fill()
  342. dtype = P.DType()
  343. def bprop(x, out, dout):
  344. temp = mul_func(dout, x)
  345. dx = mul_func(fill_func(dtype(temp), shape_op(x), 2.0), temp)
  346. return (dx,)
  347. return bprop
  348. @bprop_getters.register(P.SquaredDifference)
  349. def get_bprop_squared_difference(self):
  350. """Grad definition for `SquaredDifference` operation."""
  351. neg = P.Neg()
  352. def bprop(x, y, out, dout):
  353. x_grad = 2 * dout * (x - y)
  354. bc_x = x_grad
  355. bc_y = neg(x_grad)
  356. return binop_grad_common(x, y, bc_x, bc_y)
  357. return bprop
  358. @bprop_getters.register(P.Xlogy)
  359. def get_bprop_xlogy(self):
  360. """Grad definition for `Xlogy` operation."""
  361. log_op = P.Xlogy()
  362. div_op = P.Xdivy()
  363. def bprop(x, y, out, dout):
  364. x_dtype = F.dtype(x)
  365. not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
  366. bc_x = log_op(not_zero_x, y) * dout
  367. bc_y = div_op(x, y) * dout
  368. return binop_grad_common(x, y, bc_x, bc_y)
  369. return bprop
  370. @bprop_getters.register(P.SquareSumAll)
  371. def get_bprop_square_sum_all(self):
  372. """Grad definition for `SquareSumAll` operation."""
  373. mul_func = P.Mul()
  374. fill_func = P.Fill()
  375. dtype = P.DType()
  376. def bprop(x, y, out, dout):
  377. temp_x = mul_func(dout[0], x)
  378. temp_y = mul_func(dout[1], y)
  379. dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
  380. dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
  381. return (dx, dy)
  382. return bprop
  383. @bprop_getters.register(P.Sqrt)
  384. def get_bprop_sqrt(self):
  385. """Grad definition for `Sqrt` operation."""
  386. sqrt_grad = G.SqrtGrad()
  387. def bprop(x, out, dout):
  388. dx = sqrt_grad(out, dout)
  389. return (dx,)
  390. return bprop
  391. @bprop_getters.register(G.SqrtGrad)
  392. def get_bprop_sqrt_grad(self):
  393. """Grad definition for `SqrtGrad` operation."""
  394. def bprop(y, grad, out, dout):
  395. gy = dout / y
  396. dy = -gy * out
  397. dgrad = 0.5 * gy
  398. return dy, dgrad
  399. return bprop
  400. @bprop_getters.register(P.Rsqrt)
  401. def get_bprop_rsqrt(self):
  402. """Grad definition for `Rsqrt` operation."""
  403. rsqrt_grad = G.RsqrtGrad()
  404. def bprop(x, out, dout):
  405. dx = rsqrt_grad(out, dout)
  406. return (dx,)
  407. return bprop
  408. @bprop_getters.register(P.Reciprocal)
  409. def get_bprop_reciprocal(self):
  410. """Grad definition for `Reciprocal` operation."""
  411. reciprocal_grad = G.ReciprocalGrad()
  412. def bprop(x, out, dout):
  413. dx = reciprocal_grad(out, dout)
  414. return (dx,)
  415. return bprop
  416. @bprop_getters.register(P.Log)
  417. def get_bprop_log(self):
  418. """Grad definition for `Log` operation."""
  419. reciprocal = P.Reciprocal()
  420. def bprop(x, out, dout):
  421. g = reciprocal(x)
  422. dx = g * dout
  423. return (dx,)
  424. return bprop
  425. @bprop_getters.register(P.Log1p)
  426. def get_bprop_log1p(self):
  427. """Grad definition for `Log1p` operation."""
  428. reciprocal = P.Reciprocal()
  429. def bprop(x, out, dout):
  430. x_1p = x + 1
  431. g = reciprocal(x_1p)
  432. dx = g * dout
  433. return (dx,)
  434. return bprop
  435. @bprop_getters.register(P.Erf)
  436. def get_bprop_erf(self):
  437. """Grad definition for `Erf` operation."""
  438. exp = P.Exp()
  439. square = P.Square()
  440. sqrt = P.Sqrt()
  441. cast = P.Cast()
  442. dtype = P.DType()
  443. neg = P.Neg()
  444. def bprop(x, out, dout):
  445. half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
  446. x_square = square(x)
  447. dx = dout * half_root_pi * exp(neg(x_square))
  448. return (dx,)
  449. return bprop
  450. @bprop_getters.register(P.Erfc)
  451. def get_bprop_erfc(self):
  452. """Grad definition for `Erfc` operation."""
  453. exp = P.Exp()
  454. square = P.Square()
  455. sqrt = P.Sqrt()
  456. cast = P.Cast()
  457. dtype = P.DType()
  458. neg = P.Neg()
  459. def bprop(x, out, dout):
  460. half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
  461. x_square = square(x)
  462. dx = dout * (neg(half_root_pi) * exp(neg(x_square)))
  463. return (dx,)
  464. return bprop
@bprop_getters.register(P.Pow)
def get_bprop_pow(self):
    """Grad definition for `Pow` operation."""
    pow_op = P.Pow()
    ln = P.Log()

    def bprop(x, power, out, dout):
        # d/dx x^p = p * x^(p-1) * dout.
        bc_dx = power * pow_op(x, power - 1.0) * dout
        # Replace negative bases with 1 before the log: d/dp uses ln(x), which
        # is undefined for x < 0. This must happen after bc_dx is computed.
        x = F.select(x < 0, F.fill(F.dtype(x), F.shape(x), 1), x)
        # d/dp x^p = x^p * ln(x) * dout, with `out` = x^p.
        bc_dpower = out * ln(x) * dout
        return binop_grad_common(x, power, bc_dx, bc_dpower)

    return bprop
  476. @bprop_getters.register(P.Exp)
  477. def get_bprop_exp(self):
  478. """Grad definition for `Exp` operation."""
  479. exp_ = P.Exp()
  480. def bprop(x, out, dout):
  481. g = exp_(x)
  482. dx = g * dout
  483. return (dx,)
  484. return bprop
  485. @bprop_getters.register(P.Expm1)
  486. def get_bprop_expm1(self):
  487. """Grad definition for `Expm1` operation."""
  488. exp_ = P.Exp()
  489. def bprop(x, out, dout):
  490. g = exp_(x)
  491. dx = g * dout
  492. return (dx,)
  493. return bprop
  494. @bprop_getters.register(P.Minimum)
  495. def get_bprop_minimum(self):
  496. """Grad definition for `Minimum` operation."""
  497. input_grad = G.MinimumGrad()
  498. def bprop(x, y, out, dout):
  499. dx, dy = input_grad(x, y, dout)
  500. return dx, dy
  501. return bprop
  502. @bprop_getters.register(P.Maximum)
  503. def get_bprop_maximum(self):
  504. """Grad definition for `Maximum` operation."""
  505. input_grad = G.MaximumGrad()
  506. def bprop(x, y, out, dout):
  507. dx, dy = input_grad(x, y, dout)
  508. return dx, dy
  509. return bprop
  510. @bprop_getters.register(P.ReduceSum)
  511. def get_bprop_reducesum(self):
  512. """Grad definition for `ReduceSum` operation."""
  513. def bprop(x, axis, out, dout):
  514. dx = _sum_grad(x, axis, dout)
  515. return dx, zeros_like(axis)
  516. return bprop
  517. @bprop_getters.register(P.CumSum)
  518. def get_bprop_cumsum(self):
  519. """Grad definition for `CumSum` operation."""
  520. cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
  521. def bprop(x, axis, out, dout):
  522. return cumsum(dout, axis), zeros_like(axis)
  523. return bprop
  524. @constexpr
  525. def _split_shape_index(input_shape, axis):
  526. """Calculate reduce_prod grad transpose indices and perm shape."""
  527. rank = len(input_shape)
  528. if isinstance(axis, int):
  529. axis = tuple([axis])
  530. reduction_indices = tuple([(i + rank) % rank for i in axis])
  531. other_indices = tuple(set(range(rank)) - set(reduction_indices))
  532. reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
  533. other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
  534. perm = reduction_indices + other_indices
  535. return tuple([reduced_num, other_num]), perm
  536. @constexpr
  537. def _invert_permutation(perm):
  538. """Calculate invert permutation."""
  539. out = [0] * len(perm)
  540. for i, value in enumerate(perm):
  541. out[value] = i
  542. return tuple(out)
@bprop_getters.register(P.ReduceProd)
def get_bprop_reduceprod(self):
    """Grad definition for `ReduceProd` operation."""
    transpose = P.Transpose()
    left_cumprod = P.CumProd(exclusive=True)
    right_cumprod = P.CumProd(exclusive=True, reverse=True)

    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # Expand dout to full input shape
        input_shape = shape_op(x)
        output_shape_kept_dims = reduced_shape(input_shape, axis)
        dout = reshape(dout, output_shape_kept_dims)
        tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
        grad = tile(dout, tile_scaling)
        # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
        pack_shape, perm = _split_shape_index(input_shape, axis)
        permuted = transpose(x, perm)
        permuted_shape = shape_op(permuted)
        reshaped = reshape(permuted, pack_shape)
        # Calculate product, leaving out the current entry: the product of all
        # other elements is (exclusive cumprod from the left) * (from the right),
        # which avoids dividing by x and so stays finite when x contains zeros.
        left = left_cumprod(reshaped, 0)
        right = right_cumprod(reshaped, 0)
        y = reshape(left * right, permuted_shape)
        # Invert the transpose and reshape operations.
        # Make sure to set the statically known shape information through a reshape.
        out = transpose(y, _invert_permutation(perm)) * grad
        dx = reshape(out, input_shape)
        return dx, zeros_like(axis)

    return bprop
@bprop_getters.register(P.CumProd)
def get_bprop_cumprod(self):
    """Grad definition for `CumProd` operation."""
    cumprod = P.CumProd(exclusive=self.exclusive, reverse=self.reverse)
    # Gradients of a cumulative op accumulate in the opposite direction.
    cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)

    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # NOTE: this will fail when x contains 0 — the `out / x` below divides by it.
        prod = cumprod(x, axis)
        out = cumsum(prod * dout, axis)
        return out / x, zeros_like(axis)

    return bprop
  584. @bprop_getters.register(P.ReduceAll)
  585. def get_bprop_reduceall(self):
  586. """Grad definition for `ReduceAll` operation."""
  587. def bprop(x, axis, out, dout):
  588. return zeros_like(x), zeros_like(axis)
  589. return bprop
  590. @bprop_getters.register(P.ReduceAny)
  591. def get_bprop_reduceany(self):
  592. """Grad definition for `ReduceAny` operation."""
  593. def bprop(x, axis, out, dout):
  594. return zeros_like(x), zeros_like(axis)
  595. return bprop
  596. @bprop_getters.register(P.ReduceMax)
  597. def get_bprop_reducemax(self):
  598. """Grad definition for `Max` operation."""
  599. def bprop(x, axis, out, dout):
  600. dx = _min_or_max_grad(x, axis, out, dout)
  601. return (dx, zeros_like(axis))
  602. return bprop
  603. @bprop_getters.register(P.ArgMaxWithValue)
  604. def get_bprop_argmaxwithvalue(self):
  605. """Grad definition for `ArgMaxWithValue` operation."""
  606. axis = self.axis
  607. keep_dims = self.keep_dims
  608. op = P.ArgMaxWithValue(axis)
  609. def bprop(x, out, dout):
  610. dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
  611. return (dx,)
  612. return bprop
  613. @bprop_getters.register(P.ReduceMin)
  614. def get_bprop_reducemin(self):
  615. """Grad definition for `ReduceMin` operation."""
  616. def bprop(x, axis, out, dout):
  617. dx = _min_or_max_grad(x, axis, out, dout)
  618. return (dx, zeros_like(axis))
  619. return bprop
  620. @bprop_getters.register(P.ArgMinWithValue)
  621. def get_bprop_argminwithvalue(self):
  622. """Generate bprop for ArgMinWithValue"""
  623. axis = self.axis
  624. keep_dims = self.keep_dims
  625. op = P.ArgMinWithValue(axis)
  626. def bprop(x, out, dout):
  627. dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
  628. return (dx,)
  629. return bprop
@bprop_getters.register(P.ReduceMean)
def get_bprop_reduce_mean(self):
    """Grad definition for `ReduceMean` operation."""
    div_op = P.RealDiv()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, axis, out, dout):
        # Mean grad = sum grad divided by the number of reduced elements.
        grad = _sum_grad(x, axis, dout)
        shape_x = shape_op(x)
        shape_out = shape_op(out)
        if -1 in shape_x:
            # Dynamic shape: compute the element-count ratio as a runtime tensor.
            shape_x = dyn_shape_op(x)
            shape_out = dyn_shape_op(out)
            div_shape = reduce_prod(shape_x) / reduce_prod(shape_out)
            dx = div_op(grad, cast(div_shape, dtype(grad)))
        else:
            # Static shape: the ratio is a Python scalar folded at compile time.
            div_shape = F.shape_mul(shape_x) / F.shape_mul(shape_out)
            dx = div_op(grad, cast(F.scalar_to_array(div_shape), dtype(grad)))
        return dx, zeros_like(axis)

    return bprop
  650. @bprop_getters.register(P.IsFinite)
  651. def get_bprop_isfinite(self):
  652. """Grad definition for `IsFinite` operation."""
  653. def bprop(x, out, dout):
  654. return (zeros_like(x),)
  655. return bprop
  656. @bprop_getters.register(P.IsNan)
  657. def get_bprop_isnan(self):
  658. """Grad definition for `IsNan` operation."""
  659. def bprop(x, out, dout):
  660. return (zeros_like(x),)
  661. return bprop
  662. @bprop_getters.register(P.IsInf)
  663. def get_bprop_isinf(self):
  664. """Grad definition for `IsInf` operation."""
  665. def bprop(x, out, dout):
  666. return (zeros_like(x),)
  667. return bprop
  668. @bprop_getters.register(P.Equal)
  669. def get_bprop_equal(self):
  670. """Grad definition for `Equal` operation."""
  671. def bprop(x, y, out, dout):
  672. return zeros_like(x), zeros_like(y)
  673. return bprop
  674. @bprop_getters.register(P.NotEqual)
  675. def get_bprop_not_equal(self):
  676. """Grad definition for `NotEqual` operation."""
  677. def bprop(x, y, out, dout):
  678. return zeros_like(x), zeros_like(y)
  679. return bprop
  680. @bprop_getters.register(P.ApproximateEqual)
  681. def get_bprop_approximate_equal(self):
  682. """Grad definition for `ApproximateEqual` operation."""
  683. def bprop(x, y, out, dout):
  684. return zeros_like(x), zeros_like(y)
  685. return bprop
  686. @bprop_getters.register(P.Greater)
  687. def get_bprop_greater(self):
  688. """Grad definition for `Greater` operation."""
  689. def bprop(x, y, out, dout):
  690. return zeros_like(x), zeros_like(y)
  691. return bprop
  692. @bprop_getters.register(P.GreaterEqual)
  693. def get_bprop_greater_equal(self):
  694. """Grad definition for `GreaterEqual` operation."""
  695. def bprop(x, y, out, dout):
  696. return zeros_like(x), zeros_like(y)
  697. return bprop
  698. @bprop_getters.register(P.Less)
  699. def get_bprop_less(self):
  700. """Grad definition for `Less` operation."""
  701. def bprop(x, y, out, dout):
  702. return zeros_like(x), zeros_like(y)
  703. return bprop
  704. @bprop_getters.register(P.LessEqual)
  705. def get_bprop_less_equal(self):
  706. """Grad definition for `LessEqual` operation."""
  707. def bprop(x, y, out, dout):
  708. return zeros_like(x), zeros_like(y)
  709. return bprop
  710. @bprop_getters.register(P.LogicalNot)
  711. def get_bprop_logical_not(self):
  712. """Grad definition for `LogicalNot` operation."""
  713. def bprop(x, out, dout):
  714. return (zeros_like(x),)
  715. return bprop
  716. @bprop_getters.register(P.LogicalAnd)
  717. def get_bprop_logical_and(self):
  718. """Grad definition for `LogicalAnd` operation."""
  719. def bprop(x, y, out, dout):
  720. return zeros_like(x), zeros_like(y)
  721. return bprop
  722. @bprop_getters.register(P.LogicalOr)
  723. def get_bprop_logical_or(self):
  724. """Grad definition for `LogicalOr` operation."""
  725. def bprop(x, y, out, dout):
  726. return zeros_like(x), zeros_like(y)
  727. return bprop
  728. @bprop_getters.register(P.NPUAllocFloatStatus)
  729. def get_bprop_npu_alloc_float_status(self):
  730. """Grad definition for `NPUAllocFloatStatus` operation."""
  731. def bprop(out, dout):
  732. return ()
  733. return bprop
  734. @bprop_getters.register(P.NPUGetFloatStatus)
  735. def get_bprop_npu_get_float_status(self):
  736. """Grad definition for `NPUGetFloatStatus` operation."""
  737. def bprop(x, out, dout):
  738. return (zeros_like(x),)
  739. return bprop
  740. @bprop_getters.register(P.NPUClearFloatStatus)
  741. def get_bprop_npu_clear_float_status(self):
  742. """Grad definition for `NPUClearFloatStatus` operation."""
  743. def bprop(x, out, dout):
  744. return (zeros_like(x),)
  745. return bprop
  746. @bprop_getters.register(P.AssignAdd)
  747. def get_bprop_assign_add(self):
  748. """Grad definition for `AssignAdd` operation."""
  749. def bprop(x, y, out, dout):
  750. return zeros_like(x), zeros_like(y)
  751. return bprop
  752. @bprop_getters.register(P.AssignSub)
  753. def get_bprop_assign_sub(self):
  754. """Grad definition for `AssignSub` operation."""
  755. def bprop(x, y, out, dout):
  756. return zeros_like(x), zeros_like(y)
  757. return bprop
  758. @bprop_getters.register(P.Sin)
  759. def get_bprop_sin(self):
  760. """Grad definition for `Sin` operation."""
  761. cos = P.Cos()
  762. def bprop(x, out, dout):
  763. dx = dout * cos(x)
  764. return (dx,)
  765. return bprop
  766. @bprop_getters.register(P.Asin)
  767. def get_bprop_asin(self):
  768. """Grad definition for `Asin` operation."""
  769. input_grad = G.AsinGrad()
  770. def bprop(x, out, dout):
  771. dx = input_grad(x, dout)
  772. return (dx,)
  773. return bprop
  774. @bprop_getters.register(G.AsinGrad)
  775. def get_bprop_asin_grad(self):
  776. """Grad definition for `AsinGrad` operation."""
  777. input_grad = G.AsinGrad()
  778. p_pow = P.Pow()
  779. def bprop(x, grad, out, dout):
  780. d2x = dout * grad * x * p_pow((1 - x * x), - 1.5)
  781. ddy = input_grad(x, dout)
  782. return (d2x, ddy)
  783. return bprop
  784. @bprop_getters.register(P.Asinh)
  785. def get_bprop_asinh(self):
  786. """Grad definition for `Asinh` operation."""
  787. input_grad = G.AsinhGrad()
  788. def bprop(x, out, dout):
  789. dx = input_grad(out, dout)
  790. return (dx,)
  791. return bprop
  792. @bprop_getters.register(G.AsinhGrad)
  793. def get_bprop_asinh_grad(self):
  794. """Grad definition for `AsinhGrad` operation."""
  795. input_grad = G.AsinhGrad()
  796. tanh = P.Tanh()
  797. def bprop(y, grad, out, dout):
  798. dy = dout * out * -1.0 * tanh(y)
  799. dgrad = input_grad(y, dout)
  800. return dy, dgrad
  801. return bprop
  802. @bprop_getters.register(P.Sinh)
  803. def get_bprop_sinh(self):
  804. """Grad definition for `Sinh` operation."""
  805. cosh = P.Cosh()
  806. def bprop(x, out, dout):
  807. dx = cosh(x) * dout
  808. return (dx,)
  809. return bprop
  810. @bprop_getters.register(P.Cos)
  811. def get_bprop_cos(self):
  812. """Grad definition for `Cos` operation."""
  813. sin = P.Sin()
  814. neg = P.Neg()
  815. def bprop(x, out, dout):
  816. dx = dout * neg(sin(x))
  817. return (dx,)
  818. return bprop
  819. @bprop_getters.register(P.ACos)
  820. def get_bprop_acos(self):
  821. """Grad definition for `ACos` operation."""
  822. input_grad = G.ACosGrad()
  823. def bprop(x, out, dout):
  824. dx = input_grad(x, dout)
  825. return (dx,)
  826. return bprop
  827. @bprop_getters.register(G.ACosGrad)
  828. def get_bprop_acos_grad(self):
  829. """Grad definition for `ACosGrad` operation."""
  830. input_grad = G.ACosGrad()
  831. p_pow = P.Pow()
  832. def bprop(x, grad, out, dout):
  833. d2x = -dout * grad * x * p_pow((1 - x * x), - 1.5)
  834. ddy = input_grad(x, dout)
  835. return (d2x, ddy)
  836. return bprop
  837. @bprop_getters.register(P.Acosh)
  838. def get_bprop_acosh(self):
  839. """Grad definition for `Acosh` operation."""
  840. input_grad = G.AcoshGrad()
  841. def bprop(x, out, dout):
  842. dx = input_grad(out, dout)
  843. return (dx,)
  844. return bprop
  845. @bprop_getters.register(G.AcoshGrad)
  846. def get_bprop_acosh_grad(self):
  847. """Grad definition for `AcoshGrad` operation."""
  848. input_grad = G.AcoshGrad()
  849. tanh = P.Tanh()
  850. def bprop(y, grad, out, dout):
  851. dy = dout * out * -1.0 / tanh(y)
  852. dgrad = input_grad(y, dout)
  853. return dy, dgrad
  854. return bprop
  855. @bprop_getters.register(P.Cosh)
  856. def get_bprop_cosh(self):
  857. """Grad definition for `Cosh` operation."""
  858. sinh = P.Sinh()
  859. def bprop(x, out, dout):
  860. dx = sinh(x) * dout
  861. return (dx,)
  862. return bprop
  863. @bprop_getters.register(P.Abs)
  864. def get_bprop_abs(self):
  865. """Grad definition for `Abs` operation."""
  866. abs_grad = G.AbsGrad()
  867. def bprop(x, out, dout):
  868. dx = abs_grad(x, dout)
  869. return (dx,)
  870. return bprop
  871. @bprop_getters.register(P.Conj)
  872. def get_bprop_conj(self):
  873. """Grad definition for `Conj` operation."""
  874. conj = P.Conj()
  875. def bprop(x, out, dout):
  876. dx = conj(dout)
  877. return (dx,)
  878. return bprop
  879. @bprop_getters.register(P.Real)
  880. def get_bprop_real(self):
  881. """Grad definition for `Real` operation."""
  882. cast = P.Cast()
  883. dtype = P.DType()
  884. def bprop(x, out, dout):
  885. return (cast(dout, dtype(x)),)
  886. return bprop
  887. @bprop_getters.register(P.Imag)
  888. def get_bprop_imag(self):
  889. """Grad definition for `Imag` operation."""
  890. complex_op = P.Complex()
  891. def bprop(x, out, dout):
  892. zeros = zeros_like(dout)
  893. return (complex_op(zeros, zeros-1) * dout,)
  894. return bprop
  895. @bprop_getters.register(P.ScalarCast)
  896. def get_bprop_scalar_cast(self):
  897. """Generate bprop for ScalarCast"""
  898. def bprop(x, t, out, dout):
  899. return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)
  900. return bprop
  901. @bprop_getters.register(P.AccumulateNV2)
  902. def get_bprop_scalar_accumulatenv2(self):
  903. """Generate bprop for AccumulateNV2"""
  904. def bprop(x, out, dout):
  905. dx = ()
  906. for _ in range(len(x)):
  907. dx = dx + (dout,)
  908. return (dx,)
  909. return bprop
  910. @bprop_getters.register(P.AddN)
  911. def get_bprop_scalar_addn(self):
  912. """Generate bprop for AddN"""
  913. def bprop(x, out, dout):
  914. if is_sub_class(F.typeof(x), ms.list_):
  915. dx = []
  916. for _ in range(len(x)):
  917. dx.append(dout)
  918. return (dx,)
  919. dx = ()
  920. for _ in range(len(x)):
  921. dx = dx + (dout,)
  922. return (dx,)
  923. return bprop
  924. @bprop_getters.register(P.Sign)
  925. def get_bprop_sign(self):
  926. """Generate bprop for Sign"""
  927. def bprop(x, out, dout):
  928. return (zeros_like(x),)
  929. return bprop
  930. @bprop_getters.register(P.Round)
  931. def get_bprop_round(self):
  932. """Generate bprop for Round"""
  933. def bprop(x, out, dout):
  934. return (zeros_like(x),)
  935. return bprop
  936. @bprop_getters.register(P.Atan2)
  937. def get_bprop_atan2(self):
  938. """Generate bprop for Atan2"""
  939. square = P.Square()
  940. def bprop(x, y, out, dout):
  941. tmp = dout / (square(x) + square(y))
  942. bc_dx = tmp * y
  943. bc_dy = tmp * (-x)
  944. return binop_grad_common(x, y, bc_dx, bc_dy)
  945. return bprop
  946. @bprop_getters.register(P.BesselI0e)
  947. def get_bprop_bessel_i0e(self):
  948. """Generate bprop for BesselI0e"""
  949. sign = P.Sign()
  950. bessel_i1e = P.BesselI1e()
  951. def bprop(x, out, dout):
  952. dx = dout * (bessel_i1e(x) - sign(x) * out)
  953. return (dx,)
  954. return bprop
  955. @bprop_getters.register(P.Atan)
  956. def get_bprop_atan(self):
  957. """Grad definition for `Atan` operation."""
  958. input_grad = G.AtanGrad()
  959. def bprop(x, out, dout):
  960. dx = input_grad(x, dout)
  961. return (dx,)
  962. return bprop
  963. @bprop_getters.register(G.AtanGrad)
  964. def get_bprop_atan_grad(self):
  965. """Grad definition for `AtanGrad` operation."""
  966. input_grad = G.AtanGrad()
  967. def bprop(x, grad, out, dout):
  968. dgrad = input_grad(x, dout)
  969. dx = out * dgrad * -2.0 * x
  970. return dx, dgrad
  971. return bprop
  972. @bprop_getters.register(P.Tan)
  973. def get_bprop_tan(self):
  974. """Grad definition for `Tan` operation."""
  975. reciprocal = P.Reciprocal()
  976. square = P.Square()
  977. cos = P.Cos()
  978. def bprop(x, out, dout):
  979. cosx = cos(x)
  980. secx2 = square(reciprocal(cosx))
  981. dx = secx2 * dout
  982. return (dx,)
  983. return bprop
  984. @bprop_getters.register(P.BesselI1e)
  985. def get_bprop_bessel_i1e(self):
  986. """Generate bprop for BesselI1e"""
  987. sign = P.Sign()
  988. bessel_i0e = P.BesselI0e()
  989. less = P.Less()
  990. select = P.Select()
  991. reciprocal = P.Reciprocal()
  992. cast = P.Cast()
  993. dtype = P.DType()
  994. abs_ops = P.Abs()
  995. def bprop(x, out, dout):
  996. zeros = zeros_like(x)
  997. np_eps = const_utils.get_np_eps(dtype(x))
  998. eps = cast(np_eps, dtype(x))
  999. x_is_valid = less(eps, abs_ops(x))
  1000. x_safe = select(x_is_valid, x, eps + zeros)
  1001. tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
  1002. dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
  1003. return (dx,)
  1004. return bprop
  1005. @bprop_getters.register(P.Atanh)
  1006. def get_bprop_atanh(self):
  1007. """Grad definition for `Atanh` operation."""
  1008. power = P.Pow()
  1009. div = P.Div()
  1010. def bprop(x, out, dout):
  1011. tmp = 1 - power(x, 2)
  1012. dx = div(1, tmp) * dout
  1013. return (dx,)
  1014. return bprop
  1015. @bprop_getters.register(P.Inv)
  1016. def get_bprop_inv(self):
  1017. """Grad definition for 'Inv' operation"""
  1018. inv_grad = G.InvGrad()
  1019. def bprop(x, out, dout):
  1020. dx = inv_grad(out, dout)
  1021. return (dx,)
  1022. return bprop
  1023. @bprop_getters.register(P.LinSpace)
  1024. def get_bprop_lin_space(self):
  1025. """Grad definition for `LinSpace` operation."""
  1026. def bprop(start, stop, num, out, dout):
  1027. return zeros_like(start), zeros_like(stop), zeros_like(num)
  1028. return bprop
  1029. @bprop_getters.register(P.IndexAdd)
  1030. def get_bprop_index_add(self):
  1031. """Generate bprop for IndexAdd"""
  1032. gather = P.Gather()
  1033. _axis = self.axis
  1034. def bprop(input_x, indices, input_y, out, dout):
  1035. return dout, zeros_like(indices), gather(dout, indices, _axis)
  1036. return bprop