You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

grad_math_ops.py 30 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the grad rules of math related operations."""
  16. from functools import reduce
  17. import numpy as np
  18. import mindspore as ms
  19. from mindspore.ops import _selected_grad_ops as SG
  20. from .. import functional as F
  21. from .. import operations as P
  22. from ..operations import _grad_ops as G
  23. from ..operations import _inner_ops as inner
  24. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  25. from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
  26. from .grad_base import bprop_getters
  27. from ..primitive import constexpr
  28. from ..composite.multitype_ops import _constexpr_utils as const_utils
# Reusable primitive instances shared by the grad definitions below.
shape_op = P.Shape()        # static shape of a tensor, as a tuple
reduce_sum = P.ReduceSum()  # sum over the given axes
reshape = P.Reshape()
tile = P.Tile()
is_sub_class = P.IsSubClass()  # runtime subtype check (used by AddN's bprop)
def binop_grad_common(x, y, dx, dy):
    """
    Common grad definition for binary operations.

    The function is usually used in backprop op to reduce additional dimensions
    created by broadcasting: `dx`/`dy` have the broadcast output shape, and are
    summed over the broadcast axes and reshaped back to the shapes of `x`/`y`.
    """
    shape_of_x = shape_op(x)
    shape_of_y = shape_op(y)
    # rx[0]/rx[1] are the axes along which x / y were broadcast.
    rx = broadcast_gradient_args(shape_of_x, shape_of_y)
    # if input shape is the same as dout shape, do not need to reduce
    reduce_dx = dx
    reduce_dy = dy
    if rx[0]:
        # if dx is scalar whose shape is (), do not need reduce
        if shape_op(dx):
            dx = reduce_sum(dx, rx[0])
        reduce_dx = reshape(dx, shape_of_x)
    if rx[1]:
        # if dy is scalar whose shape is (), do not need reduce
        if shape_op(dy):
            dy = reduce_sum(dy, rx[1])
        reduce_dy = reshape(dy, shape_of_y)
    return reduce_dx, reduce_dy
def _sum_grad(x, axis, dout):
    """Grad definition for `Sum` operation: broadcast `dout` back over the reduced axes of `x`."""
    input_shape = shape_op(x)
    # Shape of the reduction output with the reduced axes kept as size 1.
    output_shape_kept_dims = reduced_shape(input_shape, axis)
    # Per-axis repetition factors needed to expand back to input_shape.
    tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
    grad = reshape(dout, output_shape_kept_dims)
    return tile(grad, tile_scaling)
def _min_or_max_grad(x, axis, out, dout):
    """Grad definition for `Min` and `Max` operations.

    The gradient flows only to positions equal to the extremum; when several
    positions tie, the gradient is split evenly among them.
    """
    input_shape = shape_op(x)
    output_shape_kept_dims = reduced_shape(input_shape, axis)
    y = reshape(out, output_shape_kept_dims)
    grad = reshape(dout, output_shape_kept_dims)
    # 1.0 where x attains the extremum, 0.0 elsewhere.
    indicators = F.cast(F.equal(y, x), F.dtype(grad))
    # Tiny epsilon guards the division below against a zero count.
    min_num = F.cast(F.scalar_to_array(1e-24), F.dtype(grad))
    num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
    return indicators / num_selected * grad
def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
    """Common grad for ArgMinWithValue and ArgMaxWithValue.

    Builds a one-hot mask at the arg position along `axis` and scales it by
    the gradient of the value output (`dout[1]`); the index output gets no
    gradient.
    """
    expand = P.ExpandDims()
    x_shape = F.shape(x)
    x_dim = len(x_shape)
    x_axis = axis
    # Normalize a negative axis to its non-negative equivalent.
    if x_axis < 0:
        x_axis = axis + x_dim
    onehot_axis = x_axis
    depth = x_shape[x_axis]
    if keep_dims:
        dout_expand = dout[1]
        # NOTE(review): re-runs the forward op to obtain indices without the
        # kept dimension; the `out` argument is discarded here — confirm this
        # is intentional rather than squeezing `out`.
        out = op(x)
    else:
        dout_expand = expand(dout[1], onehot_axis)
    # If the reduced output no longer has the axis, one-hot along the last axis.
    if onehot_axis >= len(shape_op(out[0])):
        onehot_axis = -1
    onehot = P.OneHot(onehot_axis)
    type_x = F.dtype(x)
    on_value = F.cast(F.scalar_to_array(1.0), type_x)
    off_value = F.cast(F.scalar_to_array(0.0), type_x)
    dx = dout_expand * onehot(out[0], depth, on_value, off_value)
    return dx
@bprop_getters.register(P.MatMul)
def bprop_matmul(self):
    """Grad definition for `MatMul` operation.

    For out = op_a(x) @ op_b(w), each input gradient is itself a matmul of
    `dout` with the other operand. The transpose flags and argument orders
    below are chosen so that all four (ta, tb) combinations yield gradients
    with the operands' original layouts.
    """
    ta = self.transpose_a
    tb = self.transpose_b
    mul1 = P.MatMul(transpose_a=(ta and tb),
                    transpose_b=(ta or (not tb)))
    mul2 = P.MatMul(transpose_a=((not ta) or tb),
                    transpose_b=(ta and tb))

    def bprop(x, w, out, dout):
        # dx: combine dout with w; operand order depends on transpose_a.
        if ta:
            dx = mul1(w, dout)
        else:
            dx = mul1(dout, w)
        # dw: combine dout with x; operand order depends on transpose_b.
        if tb:
            dw = mul2(dout, x)
        else:
            dw = mul2(x, dout)
        return dx, dw

    return bprop
@bprop_getters.register(P.BatchMatMul)
def bprop_batchmatmul(self):
    """Grad definition for `BatchMatMul` operation.

    Identical in structure to the `MatMul` grad above, using batched matmuls;
    the leading batch dimensions pass through unchanged.
    """
    ta = self.transpose_a
    tb = self.transpose_b
    mul1 = P.BatchMatMul(transpose_a=(ta and tb),
                         transpose_b=(ta or (not tb)))
    mul2 = P.BatchMatMul(transpose_a=((not ta) or tb),
                         transpose_b=(ta and tb))

    def bprop(x, w, out, dout):
        # dx: combine dout with w; operand order depends on transpose_a.
        if ta:
            dx = mul1(w, dout)
        else:
            dx = mul1(dout, w)
        # dw: combine dout with x; operand order depends on transpose_b.
        if tb:
            dw = mul2(dout, x)
        else:
            dw = mul2(x, dout)
        return dx, dw

    return bprop
  136. @bprop_getters.register(P.TensorAdd)
  137. def get_bprop_tensor_add(self):
  138. """Grad definition for `TensorAdd` operation."""
  139. def bprop(x, y, out, dout):
  140. return binop_grad_common(x, y, dout, dout)
  141. return bprop
  142. @bprop_getters.register(P.Neg)
  143. def get_bprop_neg(self):
  144. """Grad definition for `Neg` operation."""
  145. neg_grad = P.Neg()
  146. def bprop(x, out, dout):
  147. dx = neg_grad(dout)
  148. return (dx,)
  149. return bprop
  150. @bprop_getters.register(P.Sub)
  151. def get_bprop_sub(self):
  152. """Grad definition for `Sub` operation."""
  153. neg_func = P.Neg()
  154. def bprop(x, y, out, dout):
  155. return binop_grad_common(x, y, dout, neg_func(dout))
  156. return bprop
  157. @bprop_getters.register(P.Mul)
  158. def get_bprop_mul(self):
  159. """Grad definition for `Mul` operation."""
  160. mul_func = P.Mul()
  161. def bprop(x, y, out, dout):
  162. bc_dx = mul_func(y, dout)
  163. bc_dy = mul_func(x, dout)
  164. return binop_grad_common(x, y, bc_dx, bc_dy)
  165. return bprop
@bprop_getters.register(P.RealDiv)
def get_bprop_real_div(self):
    """Grad definition for `RealDiv`: d(x/y)/dx = 1/y, d(x/y)/dy = -x/y**2."""
    div_op = P.RealDiv()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_op(dout, y)
        # -x/y**2 * dout == -(dout/y) * out, reusing the forward output.
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
@bprop_getters.register(P.Div)
def get_bprop_div(self):
    """Grad definition for `Div`: d(x/y)/dx = 1/y, d(x/y)/dy = -x/y**2."""
    div_op = P.Div()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_op(dout, y)
        # -x/y**2 * dout == -(dout/y) * out, reusing the forward output.
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
@bprop_getters.register(P.DivNoNan)
def get_bprop_div_no_nan(self):
    """Grad definition for `DivNoNan`: like Div's grad, but 0 wherever y == 0 (via DivNoNan itself)."""
    div_no_nan_op = P.DivNoNan()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_no_nan_op(dout, y)
        # -(dout/y) * out; already 0 where y == 0 because bc_x is.
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
@bprop_getters.register(P.Xdivy)
def get_bprop_xdivy(self):
    """Grad definition for `Xdivy` (returns 0 when x == 0, even for y == 0)."""
    div_op = P.Xdivy()

    def bprop(x, y, out, dout):
        x_dtype = F.dtype(x)
        # 1.0 where x != 0, 0.0 where x == 0: kills the gradient at x == 0,
        # matching Xdivy's forward definition.
        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
        bc_x = div_op(not_zero_x, y) * dout
        bc_y = div_op(-x, F.square(y)) * dout
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
  210. @bprop_getters.register(P.Floor)
  211. def get_bprop_floor(self):
  212. """Grad definition for `floor` operation."""
  213. fill_ = P.Fill()
  214. shape_ = P.Shape()
  215. dtype_ = P.DType()
  216. def bprop(x, out, dout):
  217. bc_x = fill_(dtype_(x), shape_(x), 0.)
  218. return (bc_x,)
  219. return bprop
  220. @bprop_getters.register(P.Ceil)
  221. def get_bprop_ceil(self):
  222. """Grad definition for `ceil` operation."""
  223. fill_ = P.Fill()
  224. shape_ = P.Shape()
  225. dtype_ = P.DType()
  226. def bprop(x, out, dout):
  227. bc_x = fill_(dtype_(x), shape_(x), 0.)
  228. return (bc_x,)
  229. return bprop
  230. @bprop_getters.register(P.FloorDiv)
  231. def get_bprop_floordiv(self):
  232. """Grad definition for `FloorDiv` operation."""
  233. def bprop(x, y, out, dout):
  234. return zeros_like(x), zeros_like(y)
  235. return bprop
@bprop_getters.register(P.FloorMod)
def get_bprop_floormod(self):
    """Grad definition for `FloorMod`: x mod y = x - y*(x//y), treating x//y as constant."""

    def bprop(x, y, out, dout):
        bc_x = dout
        # d/dy (x - y*(x//y)) with (x//y) held constant.
        bc_y = -dout * (x // y)
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
  244. @bprop_getters.register(P.TruncateDiv)
  245. def get_bprop_truncate_div(self):
  246. """Grad definition for `TruncateDiv` operation."""
  247. def bprop(x, y, out, dout):
  248. return zeros_like(x), zeros_like(y)
  249. return bprop
@bprop_getters.register(P.TruncateMod)
def get_bprop_truncate_mod(self):
    """Grad definition for `TruncateMod`: x mod y = x - y*trunc(x/y), treating trunc(x/y) as constant."""
    div_op = P.TruncateDiv()

    def bprop(x, y, out, dout):
        bc_x = dout
        # d/dy (x - y*trunc(x/y)) with trunc(x/y) held constant.
        bc_y = -dout * div_op(x, y)
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
@bprop_getters.register(P.Mod)
def get_bprop_mod(self):
    """Grad definition for `Mod`: x mod y = x - y*(x//y), treating x//y as constant."""

    def bprop(x, y, out, dout):
        bc_x = dout
        # d/dy (x - y*(x//y)) with (x//y) held constant.
        bc_y = -dout * (x // y)
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
  267. @bprop_getters.register(P.Square)
  268. def get_bprop_square(self):
  269. """Grad definition for `Square` operation."""
  270. mul_func = P.Mul()
  271. fill_func = P.Fill()
  272. dtype = P.DType()
  273. def bprop(x, out, dout):
  274. temp = mul_func(dout, x)
  275. dx = mul_func(fill_func(dtype(temp), shape_op(x), 2.0), temp)
  276. return (dx,)
  277. return bprop
  278. @bprop_getters.register(P.SquaredDifference)
  279. def get_bprop_squared_difference(self):
  280. """Grad definition for `SquaredDifference` operation."""
  281. neg = P.Neg()
  282. def bprop(x, y, out, dout):
  283. x_grad = 2 * dout * (x - y)
  284. bc_x = x_grad
  285. bc_y = neg(x_grad)
  286. return binop_grad_common(x, y, bc_x, bc_y)
  287. return bprop
@bprop_getters.register(P.Xlogy)
def get_bprop_xlogy(self):
    """Grad definition for `Xlogy` (x*log(y), defined as 0 when x == 0)."""
    log_op = P.Xlogy()
    div_op = P.Xdivy()

    def bprop(x, y, out, dout):
        x_dtype = F.dtype(x)
        # 1.0 where x != 0, 0.0 where x == 0: the x-gradient is log(y) only
        # where the forward result is live.
        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
        bc_x = log_op(not_zero_x, y) * dout
        # d/dy (x*log y) = x/y, via Xdivy so it stays 0 when x == 0.
        bc_y = div_op(x, y) * dout
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
@bprop_getters.register(P.SquareSumAll)
def get_bprop_square_sum_all(self):
    """Grad definition for `SquareSumAll` operation: each output is sum(input**2), so dx = 2*x*dout."""
    mul_func = P.Mul()
    fill_func = P.Fill()
    dtype = P.DType()

    def bprop(x, y, out, dout):
        # dout[0]/dout[1] are scalar gradients of the two summed outputs.
        temp_x = mul_func(dout[0], x)
        temp_y = mul_func(dout[1], y)
        dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
        dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
        return (dx, dy)

    return bprop
  313. @bprop_getters.register(P.Sqrt)
  314. def get_bprop_sqrt(self):
  315. """Grad definition for `Sqrt` operation."""
  316. sqrt_grad = G.SqrtGrad()
  317. def bprop(x, out, dout):
  318. dx = sqrt_grad(out, dout)
  319. return (dx,)
  320. return bprop
  321. @bprop_getters.register(P.Rsqrt)
  322. def get_bprop_rsqrt(self):
  323. """Grad definition for `Rsqrt` operation."""
  324. rsqrt_grad = G.RsqrtGrad()
  325. def bprop(x, out, dout):
  326. dx = rsqrt_grad(out, dout)
  327. return (dx,)
  328. return bprop
@bprop_getters.register(P.Reciprocal)
def get_bprop_reciprocal(self):
    """Grad definition for `Reciprocal`: d(1/x)/dx = -1/x**2."""
    # GPU has no ReciprocalGrad kernel here, so compose it from primitives;
    # other targets use the fused grad kernel.
    if self.target == "GPU":
        neg = P.Neg()
        mul = P.Mul()
        square = P.Square()
        reciprocal = P.Reciprocal()

        def bprop(x, out, dout):
            g = neg(reciprocal(square(x)))
            dx = mul(dout, g)
            return (dx,)
    else:
        reciprocal_grad = G.ReciprocalGrad()

        def bprop(x, out, dout):
            dx = reciprocal_grad(out, dout)
            return (dx,)
    return bprop
  347. @bprop_getters.register(P.Log)
  348. def get_bprop_log(self):
  349. """Grad definition for `Log` operation."""
  350. reciprocal = P.Reciprocal()
  351. def bprop(x, out, dout):
  352. g = reciprocal(x)
  353. dx = g * dout
  354. return dx, 0
  355. return bprop
  356. @bprop_getters.register(P.Log1p)
  357. def get_bprop_log1p(self):
  358. """Grad definition for `Log1p` operation."""
  359. reciprocal = P.Reciprocal()
  360. def bprop(x, out, dout):
  361. x_1p = x + 1
  362. g = reciprocal(x_1p)
  363. dx = g * dout
  364. return dx, 0
  365. return bprop
@bprop_getters.register(P.Erf)
def get_bprop_erf(self):
    """Grad definition for `Erf` operation: d(erf x)/dx = 2/sqrt(pi) * exp(-x**2)."""
    exp = P.Exp()
    square = P.Square()
    sqrt = P.Sqrt()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, out, dout):
        # 2/sqrt(pi), the normalization constant of the Gaussian kernel.
        half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
        x_square = square(x)
        dx = dout * half_root_pi * exp(-x_square)
        return (dx,)

    return bprop
@bprop_getters.register(P.Erfc)
def get_bprop_erfc(self):
    """Grad definition for `Erfc` operation: erfc = 1 - erf, so its derivative is -2/sqrt(pi)*exp(-x**2)."""
    exp = P.Exp()
    square = P.Square()
    sqrt = P.Sqrt()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, out, dout):
        # 2/sqrt(pi), the normalization constant of the Gaussian kernel.
        half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
        x_square = square(x)
        # Negated relative to Erf's grad.
        dx = dout * (-half_root_pi * exp(-x_square))
        return (dx,)

    return bprop
  394. @bprop_getters.register(P.Pow)
  395. def get_bprop_pow(self):
  396. """Grad definition for `Pow` operation."""
  397. pow_op = P.Pow()
  398. ln = P.Log()
  399. def bprop(x, power, out, dout):
  400. bc_dx = power * pow_op(x, power - 1.0) * dout
  401. x[x < 0] = 1
  402. bc_dpower = out * ln(x) * dout
  403. return binop_grad_common(x, power, bc_dx, bc_dpower)
  404. return bprop
  405. @bprop_getters.register(P.Exp)
  406. def get_bprop_exp(self):
  407. """Grad definition for `Exp` operation."""
  408. exp_ = P.Exp()
  409. def bprop(x, out, dout):
  410. g = exp_(x)
  411. dx = g * dout
  412. return (dx,)
  413. return bprop
  414. @bprop_getters.register(P.Expm1)
  415. def get_bprop_expm1(self):
  416. """Grad definition for `Expm1` operation."""
  417. exp_ = P.Exp()
  418. def bprop(x, out, dout):
  419. g = exp_(x)
  420. dx = g * dout
  421. return (dx,)
  422. return bprop
  423. @bprop_getters.register(P.Minimum)
  424. def get_bprop_minimum(self):
  425. """Grad definition for `Minimum` operation."""
  426. input_grad = SG.MinimumGrad()
  427. def bprop(x, y, out, dout):
  428. dx, dy = input_grad(x, y, dout)
  429. return dx, dy
  430. return bprop
  431. @bprop_getters.register(P.Maximum)
  432. def get_bprop_maximum(self):
  433. """Grad definition for `Maximum` operation."""
  434. input_grad = SG.MaximumGrad()
  435. def bprop(x, y, out, dout):
  436. dx, dy = input_grad(x, y, dout)
  437. return dx, dy
  438. return bprop
  439. @bprop_getters.register(P.ReduceSum)
  440. def get_bprop_reducesum(self):
  441. """Grad definition for `ReduceSum` operation."""
  442. def bprop(x, axis, out, dout):
  443. dx = _sum_grad(x, axis, dout)
  444. return dx, zeros_like(axis)
  445. return bprop
  446. @bprop_getters.register(P.CumSum)
  447. def get_bprop_cumsum(self):
  448. """Grad definition for `CumSum` operation."""
  449. cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
  450. def bprop(x, axis, out, dout):
  451. return cumsum(dout, axis), zeros_like(axis)
  452. return bprop
  453. @constexpr
  454. def _split_shape_index(input_shape, axis):
  455. """Calculate reduce_prod grad transpose indices and perm shape."""
  456. rank = len(input_shape)
  457. if isinstance(axis, int):
  458. axis = tuple([axis])
  459. reduction_indices = tuple([(i + rank) % rank for i in axis])
  460. other_indices = tuple(set(range(rank)) - set(reduction_indices))
  461. reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
  462. other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
  463. perm = reduction_indices + other_indices
  464. return tuple([reduced_num, other_num]), perm
  465. @constexpr
  466. def _invert_permutation(perm):
  467. """Calculate invert permutation."""
  468. out = [0] * len(perm)
  469. for i, value in enumerate(perm):
  470. out[value] = i
  471. return tuple(out)
@bprop_getters.register(P.ReduceProd)
def get_bprop_reduceprod(self):
    """Grad definition for `ReduceProd` operation.

    dx[i] = dout * (product of all other reduced elements), computed as the
    product of an exclusive left cumprod and an exclusive right cumprod after
    packing the reduced axes into one leading dimension.
    """
    transpose = P.Transpose()
    left_cumprod = P.CumProd(exclusive=True)
    right_cumprod = P.CumProd(exclusive=True, reverse=True)

    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # Expand dout to full input shape
        input_shape = shape_op(x)
        output_shape_kept_dims = reduced_shape(input_shape, axis)
        dout = reshape(dout, output_shape_kept_dims)
        tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
        grad = tile(dout, tile_scaling)
        # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
        pack_shape, perm = _split_shape_index(input_shape, axis)
        permuted = transpose(x, perm)
        permuted_shape = shape_op(permuted)
        reshaped = reshape(permuted, pack_shape)
        # Calculate product, leaving out the current entry:
        # left[i] * right[i] = product of all entries except i.
        left = left_cumprod(reshaped, 0)
        right = right_cumprod(reshaped, 0)
        y = reshape(left * right, permuted_shape)
        # Invert the transpose and reshape operations.
        # Make sure to set the statically known shape information through a reshape.
        out = transpose(y, _invert_permutation(perm)) * grad
        dx = reshape(out, input_shape)
        return dx, zeros_like(axis)

    return bprop
@bprop_getters.register(P.CumProd)
def get_bprop_cumprod(self):
    """Grad definition for `CumProd` operation."""
    cumprod = P.CumProd(exclusive=self.exclusive, reverse=self.reverse)
    # Adjoint of a cumulative op runs in the opposite direction.
    cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)

    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # NOTE: this will fail when x contains 0 (division by x below).
        prod = cumprod(x, axis)
        out = cumsum(prod * dout, axis)
        return out / x, zeros_like(axis)

    return bprop
  513. @bprop_getters.register(P.ReduceAll)
  514. def get_bprop_reduceall(self):
  515. """Grad definition for `ReduceAll` operation."""
  516. def bprop(x, axis, out, dout):
  517. return zeros_like(x), zeros_like(axis)
  518. return bprop
  519. @bprop_getters.register(P.ReduceAny)
  520. def get_bprop_reduceany(self):
  521. """Grad definition for `ReduceAny` operation."""
  522. def bprop(x, axis, out, dout):
  523. return zeros_like(x), zeros_like(axis)
  524. return bprop
  525. @bprop_getters.register(P.ReduceMax)
  526. def get_bprop_reducemax(self):
  527. """Grad definition for `Max` operation."""
  528. def bprop(x, axis, out, dout):
  529. dx = _min_or_max_grad(x, axis, out, dout)
  530. return (dx, zeros_like(axis))
  531. return bprop
@bprop_getters.register(P.ArgMaxWithValue)
def get_bprop_argmaxwithvalue(self):
    """Grad definition for `ArgMaxWithValue` operation (one-hot routing of the value gradient)."""
    axis = self.axis
    keep_dims = self.keep_dims
    # Forward op instance, re-run by the helper when keep_dims is true.
    op = P.ArgMaxWithValue(axis)

    def bprop(x, out, dout):
        dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
        return (dx,)

    return bprop
  542. @bprop_getters.register(P.ReduceMin)
  543. def get_bprop_reducemin(self):
  544. """Grad definition for `ReduceMin` operation."""
  545. def bprop(x, axis, out, dout):
  546. dx = _min_or_max_grad(x, axis, out, dout)
  547. return (dx, zeros_like(axis))
  548. return bprop
@bprop_getters.register(P.ArgMinWithValue)
def get_bprop_argminwithvalue(self):
    """Generate bprop for ArgMinWithValue (one-hot routing of the value gradient)."""
    axis = self.axis
    keep_dims = self.keep_dims
    # Forward op instance, re-run by the helper when keep_dims is true.
    op = P.ArgMinWithValue(axis)

    def bprop(x, out, dout):
        dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
        return (dx,)

    return bprop
@bprop_getters.register(P.ReduceMean)
def get_bprop_reduce_mean(self):
    """Grad definition for `ReduceMean` operation: ReduceSum's grad divided by the element count."""
    div_op = P.RealDiv()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, axis, out, dout):
        grad = _sum_grad(x, axis, dout)
        # Number of elements averaged over = total(x) / total(out).
        div_shape = F.shape_mul(shape_op(x)) / F.shape_mul(shape_op(out))
        dx = div_op(grad, cast(F.scalar_to_array(div_shape), dtype(grad)))
        return dx, zeros_like(axis)

    return bprop
  571. @bprop_getters.register(P.Equal)
  572. def get_bprop_equal(self):
  573. """Grad definition for `Equal` operation."""
  574. def bprop(x, y, out, dout):
  575. return zeros_like(x), zeros_like(y)
  576. return bprop
  577. @bprop_getters.register(P.NotEqual)
  578. def get_bprop_not_equal(self):
  579. """Grad definition for `NotEqual` operation."""
  580. def bprop(x, y, out, dout):
  581. return zeros_like(x), zeros_like(y)
  582. return bprop
  583. @bprop_getters.register(P.ApproximateEqual)
  584. def get_bprop_approximate_equal(self):
  585. """Grad definition for `ApproximateEqual` operation."""
  586. def bprop(x, y, out, dout):
  587. return zeros_like(x), zeros_like(y)
  588. return bprop
  589. @bprop_getters.register(P.Greater)
  590. def get_bprop_greater(self):
  591. """Grad definition for `Greater` operation."""
  592. def bprop(x, y, out, dout):
  593. return zeros_like(x), zeros_like(y)
  594. return bprop
  595. @bprop_getters.register(P.GreaterEqual)
  596. def get_bprop_greater_equal(self):
  597. """Grad definition for `GreaterEqual` operation."""
  598. def bprop(x, y, out, dout):
  599. return zeros_like(x), zeros_like(y)
  600. return bprop
  601. @bprop_getters.register(P.Less)
  602. def get_bprop_less(self):
  603. """Grad definition for `Less` operation."""
  604. def bprop(x, y, out, dout):
  605. return zeros_like(x), zeros_like(y)
  606. return bprop
  607. @bprop_getters.register(P.LessEqual)
  608. def get_bprop_less_equal(self):
  609. """Grad definition for `LessEqual` operation."""
  610. def bprop(x, y, out, dout):
  611. return zeros_like(x), zeros_like(y)
  612. return bprop
  613. @bprop_getters.register(P.LogicalNot)
  614. def get_bprop_logical_not(self):
  615. """Grad definition for `LogicalNot` operation."""
  616. def bprop(x, out, dout):
  617. return (zeros_like(x),)
  618. return bprop
  619. @bprop_getters.register(P.LogicalAnd)
  620. def get_bprop_logical_and(self):
  621. """Grad definition for `LogicalAnd` operation."""
  622. def bprop(x, y, out, dout):
  623. return zeros_like(x), zeros_like(y)
  624. return bprop
  625. @bprop_getters.register(P.LogicalOr)
  626. def get_bprop_logical_or(self):
  627. """Grad definition for `LogicalOr` operation."""
  628. def bprop(x, y, out, dout):
  629. return zeros_like(x), zeros_like(y)
  630. return bprop
@bprop_getters.register(P.NPUAllocFloatStatus)
def get_bprop_npu_alloc_float_status(self):
    """Grad definition for `NPUAllocFloatStatus` operation: no inputs, so no gradients."""

    def bprop(out, dout):
        return ()

    return bprop
  637. @bprop_getters.register(P.NPUGetFloatStatus)
  638. def get_bprop_npu_get_float_status(self):
  639. """Grad definition for `NPUGetFloatStatus` operation."""
  640. def bprop(x, out, dout):
  641. return (zeros_like(x),)
  642. return bprop
  643. @bprop_getters.register(P.NPUClearFloatStatus)
  644. def get_bprop_npu_clear_float_status(self):
  645. """Grad definition for `NPUClearFloatStatus` operation."""
  646. def bprop(x, out, dout):
  647. return (zeros_like(x),)
  648. return bprop
  649. @bprop_getters.register(P.AssignAdd)
  650. def get_bprop_assign_add(self):
  651. """Grad definition for `AssignAdd` operation."""
  652. def bprop(x, y, out, dout):
  653. return zeros_like(x), zeros_like(y)
  654. return bprop
  655. @bprop_getters.register(P.AssignSub)
  656. def get_bprop_assign_sub(self):
  657. """Grad definition for `AssignSub` operation."""
  658. def bprop(x, y, out, dout):
  659. return zeros_like(x), zeros_like(y)
  660. return bprop
  661. @bprop_getters.register(P.Sin)
  662. def get_bprop_sin(self):
  663. """Grad definition for `Sin` operation."""
  664. cos = P.Cos()
  665. def bprop(x, out, dout):
  666. dx = dout * cos(x)
  667. return (dx,)
  668. return bprop
  669. @bprop_getters.register(P.Asin)
  670. def get_bprop_asin(self):
  671. """Grad definition for `Asin` operation."""
  672. input_grad = G.AsinGrad()
  673. def bprop(x, out, dout):
  674. dx = input_grad(x, dout)
  675. return (dx,)
  676. return bprop
  677. @bprop_getters.register(P.Asinh)
  678. def get_bprop_asinh(self):
  679. """Grad definition for `Asinh` operation."""
  680. input_grad = G.AsinhGrad()
  681. def bprop(x, out, dout):
  682. dx = input_grad(out, dout)
  683. return (dx,)
  684. return bprop
  685. @bprop_getters.register(P.Sinh)
  686. def get_bprop_sinh(self):
  687. """Grad definition for `Sinh` operation."""
  688. cosh = P.Cosh()
  689. def bprop(x, out, dout):
  690. dx = cosh(x) * dout
  691. return (dx,)
  692. return bprop
  693. @bprop_getters.register(P.Cos)
  694. def get_bprop_cos(self):
  695. """Grad definition for `Cos` operation."""
  696. sin = P.Sin()
  697. neg = P.Neg()
  698. def bprop(x, out, dout):
  699. dx = dout * neg(sin(x))
  700. return (dx,)
  701. return bprop
  702. @bprop_getters.register(P.ACos)
  703. def get_bprop_acos(self):
  704. """Grad definition for `ACos` operation."""
  705. input_grad = G.ACosGrad()
  706. def bprop(x, out, dout):
  707. dx = input_grad(x, dout)
  708. return (dx,)
  709. return bprop
  710. @bprop_getters.register(P.Acosh)
  711. def get_bprop_acosh(self):
  712. """Grad definition for `Acosh` operation."""
  713. input_grad = G.AcoshGrad()
  714. def bprop(x, out, dout):
  715. dx = input_grad(out, dout)
  716. return (dx,)
  717. return bprop
  718. @bprop_getters.register(P.Cosh)
  719. def get_bprop_cosh(self):
  720. """Grad definition for `Cosh` operation."""
  721. sinh = P.Sinh()
  722. def bprop(x, out, dout):
  723. dx = sinh(x) * dout
  724. return (dx,)
  725. return bprop
  726. @bprop_getters.register(P.Abs)
  727. def get_bprop_abs(self):
  728. """Grad definition for `Abs` operation."""
  729. abs_grad = SG.AbsGrad()
  730. def bprop(x, out, dout):
  731. dx = abs_grad(x, dout)
  732. return (dx,)
  733. return bprop
  734. @bprop_getters.register(P.ScalarCast)
  735. def get_bprop_scalar_cast(self):
  736. """Generate bprop for ScalarCast"""
  737. def bprop(x, t, out, dout):
  738. return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)
  739. return bprop
  740. @bprop_getters.register(P.AccumulateNV2)
  741. def get_bprop_scalar_accumulatenv2(self):
  742. """Generate bprop for AccumulateNV2"""
  743. def bprop(x, out, dout):
  744. dx = ()
  745. for _ in range(len(x)):
  746. dx = dx + (dout,)
  747. return (dx,)
  748. return bprop
  749. @bprop_getters.register(P.AddN)
  750. def get_bprop_scalar_addn(self):
  751. """Generate bprop for AddN"""
  752. def bprop(x, out, dout):
  753. if is_sub_class(F.typeof(x), ms.list_):
  754. dx = []
  755. for _ in range(len(x)):
  756. dx.append(dout)
  757. return (dx,)
  758. dx = ()
  759. for _ in range(len(x)):
  760. dx = dx + (dout,)
  761. return (dx,)
  762. return bprop
  763. @bprop_getters.register(P.Sign)
  764. def get_bprop_sign(self):
  765. """Generate bprop for Sign"""
  766. def bprop(x, out, dout):
  767. return (zeros_like(x),)
  768. return bprop
  769. @bprop_getters.register(P.Round)
  770. def get_bprop_round(self):
  771. """Generate bprop for Round"""
  772. def bprop(x, out, dout):
  773. return (zeros_like(x),)
  774. return bprop
  775. @bprop_getters.register(P.Atan2)
  776. def get_bprop_atan2(self):
  777. """Generate bprop for Atan2"""
  778. square = P.Square()
  779. def bprop(x, y, out, dout):
  780. tmp = dout / (square(x) + square(y))
  781. bc_dx = tmp * y
  782. bc_dy = tmp * (-x)
  783. return binop_grad_common(x, y, bc_dx, bc_dy)
  784. return bprop
  785. @bprop_getters.register(P.BesselI0e)
  786. def get_bprop_bessel_i0e(self):
  787. """Generate bprop for BesselI0e"""
  788. sign = P.Sign()
  789. bessel_i1e = P.BesselI1e()
  790. def bprop(x, out, dout):
  791. dx = dout * (bessel_i1e(x) - sign(x) * out)
  792. return (dx,)
  793. return bprop
  794. @bprop_getters.register(P.Atan)
  795. def get_bprop_atan(self):
  796. """Grad definition for `Atan` operation."""
  797. input_grad = G.AtanGrad()
  798. def bprop(x, out, dout):
  799. dx = input_grad(x, dout)
  800. return (dx,)
  801. return bprop
  802. @bprop_getters.register(P.Tan)
  803. def get_bprop_tan(self):
  804. """Grad definition for `Tan` operation."""
  805. reciprocal = P.Reciprocal()
  806. square = P.Square()
  807. cos = P.Cos()
  808. def bprop(x, out, dout):
  809. cosx = cos(x)
  810. secx2 = square(reciprocal(cosx))
  811. dx = secx2 * dout
  812. return (dx,)
  813. return bprop
  814. @bprop_getters.register(P.BesselI1e)
  815. def get_bprop_bessel_i1e(self):
  816. """Generate bprop for BesselI1e"""
  817. sign = P.Sign()
  818. bessel_i0e = P.BesselI0e()
  819. less = P.Less()
  820. select = P.Select()
  821. reciprocal = P.Reciprocal()
  822. cast = P.Cast()
  823. dtype = P.DType()
  824. abs_ops = P.Abs()
  825. def bprop(x, out, dout):
  826. zeros = zeros_like(x)
  827. np_eps = const_utils.get_np_eps(dtype(x))
  828. eps = cast(np_eps, dtype(x))
  829. x_is_valid = less(eps, abs_ops(x))
  830. x_safe = select(x_is_valid, x, eps + zeros)
  831. tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
  832. dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
  833. return (dx,)
  834. return bprop
  835. @bprop_getters.register(P.Atanh)
  836. def get_bprop_atanh(self):
  837. """Grad definition for `Atanh` operation."""
  838. power = P.Pow()
  839. div = P.Div()
  840. def bprop(x, out, dout):
  841. tmp = 1 - power(x, 2)
  842. dx = div(1, tmp) * dout
  843. return (dx,)
  844. return bprop
  845. @bprop_getters.register(P.Inv)
  846. def get_bprop_inv(self):
  847. """Grad definition for 'Inv' operation"""
  848. inv_grad = G.InvGrad()
  849. def bprop(x, out, dout):
  850. dx = inv_grad(out, dout)
  851. return (dx,)
  852. return bprop
  853. @bprop_getters.register(inner.LinSpace)
  854. def get_bprop_lin_space(self):
  855. """Grad definition for `LinSpace` operation."""
  856. def bprop(assist, start, stop, num, out, dout):
  857. return zeros_like(assist), zeros_like(start), zeros_like(stop), zeros_like(num)
  858. return bprop