You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

grad_math_ops.py 27 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the grad rules of math related operations."""
  16. from functools import reduce
  17. import numpy as np
  18. from mindspore.ops import _selected_grad_ops as SG
  19. from .. import functional as F
  20. from .. import operations as P
  21. from ..operations import _grad_ops as G
  22. from ..operations import _inner_ops as inner
  23. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  24. from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
  25. from .grad_base import bprop_getters
  26. from ..primitive import constexpr
  27. from ..composite.multitype_ops import _constexpr_utils as const_utils
# Primitive instances shared by the bprop rules defined in this module.
shape_op = P.Shape()        # static shape of a tensor
reduce_sum = P.ReduceSum()  # collapses dimensions introduced by broadcasting
reshape = P.Reshape()
tile = P.Tile()
  32. def binop_grad_common(x, y, dx, dy):
  33. """
  34. Common grad definition for binary operations.
  35. The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
  36. """
  37. shape_of_x = shape_op(x)
  38. shape_of_y = shape_op(y)
  39. rx = broadcast_gradient_args(shape_of_x, shape_of_y)
  40. # if input shape is the same as dout shape, do not need to reduce
  41. reduce_dx = dx
  42. reduce_dy = dy
  43. if rx[0]:
  44. # if dx is scalar whose shape is (), do not need reduce
  45. if shape_op(dx):
  46. dx = reduce_sum(dx, rx[0])
  47. reduce_dx = reshape(dx, shape_of_x)
  48. if rx[1]:
  49. # if dy is scalar whose shape is (), do not need reduce
  50. if shape_op(dy):
  51. dy = reduce_sum(dy, rx[1])
  52. reduce_dy = reshape(dy, shape_of_y)
  53. return reduce_dx, reduce_dy
  54. def _sum_grad(x, axis, dout):
  55. """Grad definition for `Sum` operation."""
  56. # input_shape = [2, 3] axis = [1]
  57. input_shape = shape_op(x)
  58. # output_shape_kept_dims = [2, 1]
  59. output_shape_kept_dims = reduced_shape(input_shape, axis)
  60. # tile_scaling = [1, 3]
  61. tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
  62. grad = reshape(dout, output_shape_kept_dims)
  63. return tile(grad, tile_scaling)
  64. def _min_or_max_grad(x, axis, out, dout):
  65. """Grad definition for `Min` and `Max` operations."""
  66. # input_shape = [2, 3] axis = [1]
  67. input_shape = shape_op(x)
  68. # output_shape_kept_dims = [2, 1]
  69. output_shape_kept_dims = reduced_shape(input_shape, axis)
  70. y = reshape(out, output_shape_kept_dims)
  71. grad = reshape(dout, output_shape_kept_dims)
  72. indicators = F.cast(F.equal(y, x), F.dtype(grad))
  73. min_num = F.cast(F.scalar_to_array(1e-24), F.dtype(grad))
  74. num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
  75. return indicators / num_selected * grad
def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
    """ArgMinWithValue and ArgMaxWithValue grad.

    Routes the value-gradient `dout[1]` back to the positions selected by the
    arg-min/arg-max, via a one-hot mask over the reduced axis.
    """
    expand = P.ExpandDims()
    x_shape = F.shape(x)
    x_dim = len(x_shape)
    x_axis = axis
    if x_axis < 0:
        # Normalize a negative axis to its non-negative equivalent.
        x_axis = axis + x_dim
    onehot_axis = x_axis
    depth = x_shape[x_axis]
    if keep_dims:
        dout_expand = dout[1]
        # Recompute without kept dims so out[0] has the rank expected below.
        out = op(x)
    else:
        dout_expand = expand(dout[1], onehot_axis)
    # OneHot cannot insert past the end of the index tensor's shape; append instead.
    if onehot_axis >= len(shape_op(out[0])):
        onehot_axis = -1
    onehot = P.OneHot(onehot_axis)
    type_x = F.dtype(x)
    on_value = F.cast(F.scalar_to_array(1.0), type_x)
    off_value = F.cast(F.scalar_to_array(0.0), type_x)
    dx = dout_expand * onehot(out[0], depth, on_value, off_value)
    return dx
  99. @bprop_getters.register(P.MatMul)
  100. def bprop_matmul(self):
  101. """Grad definition for `MatMul` operation."""
  102. ta = self.transpose_a
  103. tb = self.transpose_b
  104. mul1 = P.MatMul(transpose_a=(ta and tb),
  105. transpose_b=(ta or (not tb)))
  106. mul2 = P.MatMul(transpose_a=((not ta) or tb),
  107. transpose_b=(ta and tb))
  108. def bprop(x, w, out, dout):
  109. if ta:
  110. dx = mul1(w, dout)
  111. else:
  112. dx = mul1(dout, w)
  113. if tb:
  114. dw = mul2(dout, x)
  115. else:
  116. dw = mul2(x, dout)
  117. return dx, dw
  118. return bprop
  119. @bprop_getters.register(P.BatchMatMul)
  120. def bprop_batchmatmul(self):
  121. """Grad definition for `BatchMatMul` operation."""
  122. ta = self.transpose_a
  123. tb = self.transpose_b
  124. mul1 = P.BatchMatMul(transpose_a=(ta and tb),
  125. transpose_b=(ta or (not tb)))
  126. mul2 = P.BatchMatMul(transpose_a=((not ta) or tb),
  127. transpose_b=(ta and tb))
  128. def bprop(x, w, out, dout):
  129. if ta:
  130. dx = mul1(w, dout)
  131. else:
  132. dx = mul1(dout, w)
  133. if tb:
  134. dw = mul2(dout, x)
  135. else:
  136. dw = mul2(x, dout)
  137. return dx, dw
  138. return bprop
  139. @bprop_getters.register(P.TensorAdd)
  140. def get_bprop_tensor_add(self):
  141. """Grad definition for `TensorAdd` operation."""
  142. def bprop(x, y, out, dout):
  143. return binop_grad_common(x, y, dout, dout)
  144. return bprop
  145. @bprop_getters.register(P.Neg)
  146. def get_bprop_neg(self):
  147. """Grad definition for `Neg` operation."""
  148. neg_grad = P.Neg()
  149. def bprop(x, out, dout):
  150. dx = neg_grad(dout)
  151. return (dx,)
  152. return bprop
  153. @bprop_getters.register(P.Sub)
  154. def get_bprop_sub(self):
  155. """Grad definition for `Sub` operation."""
  156. neg_func = P.Neg()
  157. def bprop(x, y, out, dout):
  158. return binop_grad_common(x, y, dout, neg_func(dout))
  159. return bprop
  160. @bprop_getters.register(P.Mul)
  161. def get_bprop_mul(self):
  162. """Grad definition for `Mul` operation."""
  163. mul_func = P.Mul()
  164. def bprop(x, y, out, dout):
  165. bc_dx = mul_func(y, dout)
  166. bc_dy = mul_func(x, dout)
  167. return binop_grad_common(x, y, bc_dx, bc_dy)
  168. return bprop
  169. @bprop_getters.register(P.RealDiv)
  170. def get_bprop_real_div(self):
  171. """Grad definition for `RealDiv` operation."""
  172. div_op = P.RealDiv()
  173. neg = P.Neg()
  174. mul_op = P.Mul()
  175. def bprop(x, y, out, dout):
  176. bc_x = div_op(dout, y)
  177. bc_y = neg(mul_op(bc_x, out))
  178. return binop_grad_common(x, y, bc_x, bc_y)
  179. return bprop
  180. @bprop_getters.register(P.Div)
  181. def get_bprop_div(self):
  182. """Grad definition for `Div` operation."""
  183. div_op = P.Div()
  184. neg = P.Neg()
  185. mul_op = P.Mul()
  186. def bprop(x, y, out, dout):
  187. bc_x = div_op(dout, y)
  188. bc_y = neg(mul_op(bc_x, out))
  189. return binop_grad_common(x, y, bc_x, bc_y)
  190. return bprop
  191. @bprop_getters.register(P.DivNoNan)
  192. def get_bprop_div_no_nan(self):
  193. """Grad definition for `DivNoNan` operation."""
  194. div_no_nan_op = P.DivNoNan()
  195. neg = P.Neg()
  196. mul_op = P.Mul()
  197. def bprop(x, y, out, dout):
  198. bc_x = div_no_nan_op(dout, y)
  199. bc_y = neg(mul_op(bc_x, out))
  200. return binop_grad_common(x, y, bc_x, bc_y)
  201. return bprop
  202. @bprop_getters.register(P.Floor)
  203. def get_bprop_floor(self):
  204. """Grad definition for `floor` operation."""
  205. fill_ = P.Fill()
  206. shape_ = P.Shape()
  207. dtype_ = P.DType()
  208. def bprop(x, out, dout):
  209. bc_x = fill_(dtype_(x), shape_(x), 0.)
  210. return (bc_x,)
  211. return bprop
  212. @bprop_getters.register(P.Ceil)
  213. def get_bprop_ceil(self):
  214. """Grad definition for `ceil` operation."""
  215. fill_ = P.Fill()
  216. shape_ = P.Shape()
  217. dtype_ = P.DType()
  218. def bprop(x, out, dout):
  219. bc_x = fill_(dtype_(x), shape_(x), 0.)
  220. return (bc_x,)
  221. return bprop
  222. @bprop_getters.register(P.FloorDiv)
  223. def get_bprop_floordiv(self):
  224. """Grad definition for `FloorDiv` operation."""
  225. div_op = P.FloorDiv()
  226. neg = P.Neg()
  227. mul_op = P.Mul()
  228. def bprop(x, y, out, dout):
  229. bc_x = div_op(dout, y)
  230. bc_y = neg(mul_op(bc_x, out))
  231. return binop_grad_common(x, y, bc_x, bc_y)
  232. return bprop
  233. @bprop_getters.register(P.FloorMod)
  234. def get_bprop_floormod(self):
  235. """Grad definition for `FloorMod` operation."""
  236. def bprop(x, y, out, dout):
  237. bc_x = dout
  238. bc_y = -dout * (x // y)
  239. return binop_grad_common(x, y, bc_x, bc_y)
  240. return bprop
  241. @bprop_getters.register(P.Mod)
  242. def get_bprop_mod(self):
  243. """Grad definition for `Mod` operation."""
  244. def bprop(x, y, out, dout):
  245. bc_x = dout
  246. bc_y = -dout * (x // y)
  247. return binop_grad_common(x, y, bc_x, bc_y)
  248. return bprop
  249. @bprop_getters.register(P.Square)
  250. def get_bprop_square(self):
  251. """Grad definition for `Square` operation."""
  252. mul_func = P.Mul()
  253. fill_func = P.Fill()
  254. dtype = P.DType()
  255. def bprop(x, out, dout):
  256. temp = mul_func(dout, x)
  257. dx = mul_func(fill_func(dtype(temp), shape_op(x), 2.0), temp)
  258. return (dx,)
  259. return bprop
  260. @bprop_getters.register(P.Sqrt)
  261. def get_bprop_sqrt(self):
  262. """Grad definition for `Sqrt` operation."""
  263. mul_func = P.Mul()
  264. fill_func = P.Fill()
  265. div_op = P.RealDiv()
  266. sqrt = P.Sqrt()
  267. dtype = P.DType()
  268. def bprop(x, out, dout):
  269. temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x))
  270. dx = mul_func(dout, temp)
  271. return (dx,)
  272. return bprop
  273. @bprop_getters.register(P.Rsqrt)
  274. def get_bprop_rsqrt(self):
  275. """Grad definition for `Rsqrt` operation."""
  276. def bprop(x, out, dout):
  277. grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x)
  278. dx = dout * grad
  279. return (dx,)
  280. return bprop
  281. @bprop_getters.register(P.Reciprocal)
  282. def get_bprop_reciprocal(self):
  283. """Grad definition for `Reciprocal` operation."""
  284. neg = P.Neg()
  285. mul = P.Mul()
  286. square = P.Square()
  287. reciprocal = P.Reciprocal()
  288. def bprop(x, out, dout):
  289. g = neg(reciprocal(square(x)))
  290. dx = mul(dout, g)
  291. return (dx,)
  292. return bprop
@bprop_getters.register(P.Log)
def get_bprop_log(self):
    """Grad definition for `Log` operation."""
    reciprocal = P.Reciprocal()
    def bprop(x, out, dout):
        # d(log x) = dout / x
        g = reciprocal(x)
        dx = g * dout
        # NOTE(review): a trailing constant 0 is returned alongside dx even
        # though the op has a single input — verify against the framework's
        # bprop return convention before changing.
        return dx, 0
    return bprop
@bprop_getters.register(P.Log1p)
def get_bprop_log1p(self):
    """Grad definition for `Log1p` operation."""
    reciprocal = P.Reciprocal()
    def bprop(x, out, dout):
        # d(log(1 + x)) = dout / (1 + x)
        x_1p = x + 1
        g = reciprocal(x_1p)
        dx = g * dout
        # NOTE(review): trailing 0 mirrors get_bprop_log above — verify.
        return dx, 0
    return bprop
  312. @bprop_getters.register(P.Erf)
  313. def get_bprop_erf(self):
  314. """Grad definition for `Erf` operation."""
  315. exp = P.Exp()
  316. square = P.Square()
  317. sqrt = P.Sqrt()
  318. cast = P.Cast()
  319. dtype = P.DType()
  320. def bprop(x, out, dout):
  321. half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
  322. x_square = square(x)
  323. dx = dout * half_root_pi * exp(-x_square)
  324. return (dx,)
  325. return bprop
  326. @bprop_getters.register(P.Erfc)
  327. def get_bprop_erfc(self):
  328. """Grad definition for `Erfc` operation."""
  329. exp = P.Exp()
  330. square = P.Square()
  331. sqrt = P.Sqrt()
  332. cast = P.Cast()
  333. dtype = P.DType()
  334. def bprop(x, out, dout):
  335. half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
  336. x_square = square(x)
  337. dx = dout * (-half_root_pi * exp(-x_square))
  338. return (dx,)
  339. return bprop
  340. @bprop_getters.register(P.Pow)
  341. def get_bprop_pow(self):
  342. """Grad definition for `Pow` operation."""
  343. pow_op = P.Pow()
  344. ln = P.Log()
  345. def bprop(x, power, out, dout):
  346. bc_dx = power * pow_op(x, power - 1.0) * dout
  347. bc_dpower = out * ln(x) * dout
  348. return binop_grad_common(x, power, bc_dx, bc_dpower)
  349. return bprop
  350. @bprop_getters.register(P.Exp)
  351. def get_bprop_exp(self):
  352. """Grad definition for `Exp` operation."""
  353. exp_ = P.Exp()
  354. def bprop(x, out, dout):
  355. g = exp_(x)
  356. dx = g * dout
  357. return (dx,)
  358. return bprop
  359. @bprop_getters.register(P.Expm1)
  360. def get_bprop_expm1(self):
  361. """Grad definition for `Expm1` operation."""
  362. exp_ = P.Exp()
  363. def bprop(x, out, dout):
  364. g = exp_(x)
  365. dx = g * dout
  366. return (dx,)
  367. return bprop
  368. @bprop_getters.register(P.Minimum)
  369. def get_bprop_minimum(self):
  370. """Grad definition for `Minimum` operation."""
  371. input_grad = SG.MinimumGrad()
  372. def bprop(x, y, out, dout):
  373. dx, dy = input_grad(x, y, dout)
  374. return dx, dy
  375. return bprop
  376. @bprop_getters.register(P.Maximum)
  377. def get_bprop_maximum(self):
  378. """Grad definition for `Maximum` operation."""
  379. input_grad = SG.MaximumGrad()
  380. def bprop(x, y, out, dout):
  381. dx, dy = input_grad(x, y, dout)
  382. return dx, dy
  383. return bprop
  384. @bprop_getters.register(P.ReduceSum)
  385. def get_bprop_reducesum(self):
  386. """Grad definition for `ReduceSum` operation."""
  387. def bprop(x, axis, out, dout):
  388. dx = _sum_grad(x, axis, dout)
  389. return dx, zeros_like(axis)
  390. return bprop
  391. @bprop_getters.register(P.CumSum)
  392. def get_bprop_cumsum(self):
  393. """Grad definition for `CumSum` operation."""
  394. cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
  395. def bprop(x, axis, out, dout):
  396. return cumsum(dout, axis), zeros_like(axis)
  397. return bprop
  398. @constexpr
  399. def _split_shape_index(input_shape, axis):
  400. """Calculate reduce_prod grad transpose indices and perm shape."""
  401. rank = len(input_shape)
  402. if isinstance(axis, int):
  403. axis = tuple([axis])
  404. reduction_indices = tuple([(i + rank) % rank for i in axis])
  405. other_indices = tuple(set(range(rank)) - set(reduction_indices))
  406. reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
  407. other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
  408. perm = reduction_indices + other_indices
  409. return tuple([reduced_num, other_num]), perm
  410. @constexpr
  411. def _invert_permutation(perm):
  412. """Calculate invert permutation."""
  413. out = [0] * len(perm)
  414. for i, value in enumerate(perm):
  415. out[value] = i
  416. return tuple(out)
  417. @bprop_getters.register(P.ReduceProd)
  418. def get_bprop_reduceprod(self):
  419. """Grad definition for `ReduceProd` operation."""
  420. transpose = P.Transpose()
  421. left_cumprod = P.CumProd(exclusive=True)
  422. right_cumprod = P.CumProd(exclusive=True, reverse=True)
  423. def bprop(x, axis, out, dout):
  424. """Grad definition for `Product` operation."""
  425. # Expand dout to full input shape
  426. input_shape = shape_op(x)
  427. output_shape_kept_dims = reduced_shape(input_shape, axis)
  428. dout = reshape(dout, output_shape_kept_dims)
  429. tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
  430. grad = tile(dout, tile_scaling)
  431. # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
  432. pack_shape, perm = _split_shape_index(input_shape, axis)
  433. permuted = transpose(x, perm)
  434. permuted_shape = shape_op(permuted)
  435. reshaped = reshape(permuted, pack_shape)
  436. # Calculate product, leaving out the current entry
  437. left = left_cumprod(reshaped, 0)
  438. right = right_cumprod(reshaped, 0)
  439. y = reshape(left * right, permuted_shape)
  440. # Invert the transpose and reshape operations.
  441. # Make sure to set the statically known shape information through a reshape.
  442. out = transpose(y, _invert_permutation(perm)) * grad
  443. dx = reshape(out, input_shape)
  444. return dx, zeros_like(axis)
  445. return bprop
  446. @bprop_getters.register(P.CumProd)
  447. def get_bprop_cumprod(self):
  448. """Grad definition for `CumProd` operation."""
  449. cumprod = P.CumProd(exclusive=self.exclusive, reverse=self.reverse)
  450. cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
  451. def bprop(x, axis, out, dout):
  452. """Grad definition for `Product` operation."""
  453. # This will fails when x contains 0
  454. prod = cumprod(x, axis)
  455. out = cumsum(prod * dout, axis)
  456. return out / x, zeros_like(axis)
  457. return bprop
  458. @bprop_getters.register(P.ReduceAll)
  459. def get_bprop_reduceall(self):
  460. """Grad definition for `ReduceAll` operation."""
  461. def bprop(x, axis, out, dout):
  462. return zeros_like(x), zeros_like(axis)
  463. return bprop
  464. @bprop_getters.register(P.ReduceMax)
  465. def get_bprop_reducemax(self):
  466. """Grad definition for `Max` operation."""
  467. def bprop(x, axis, out, dout):
  468. dx = _min_or_max_grad(x, axis, out, dout)
  469. return (dx, zeros_like(axis))
  470. return bprop
  471. @bprop_getters.register(P.ArgMaxWithValue)
  472. def get_bprop_argmaxwithvalue(self):
  473. """Grad definition for `ArgMaxWithValue` operation."""
  474. axis = self.axis
  475. keep_dims = self.keep_dims
  476. op = P.ArgMaxWithValue(axis)
  477. def bprop(x, out, dout):
  478. dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
  479. return (dx,)
  480. return bprop
  481. @bprop_getters.register(P.ReduceMin)
  482. def get_bprop_reducemin(self):
  483. """Grad definition for `ReduceMin` operation."""
  484. def bprop(x, axis, out, dout):
  485. dx = _min_or_max_grad(x, axis, out, dout)
  486. return (dx, zeros_like(axis))
  487. return bprop
  488. @bprop_getters.register(P.ArgMinWithValue)
  489. def get_bprop_argminwithvalue(self):
  490. """Generate bprop for ArgMinWithValue"""
  491. axis = self.axis
  492. keep_dims = self.keep_dims
  493. op = P.ArgMinWithValue(axis)
  494. def bprop(x, out, dout):
  495. dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
  496. return (dx,)
  497. return bprop
  498. @bprop_getters.register(P.ReduceMean)
  499. def get_bprop_reduce_mean(self):
  500. """Grad definition for `ReduceMean` operation."""
  501. div_op = P.RealDiv()
  502. cast = P.Cast()
  503. dtype = P.DType()
  504. def bprop(x, axis, out, dout):
  505. grad = _sum_grad(x, axis, dout)
  506. div_shape = F.shape_mul(shape_op(x)) / F.shape_mul(shape_op(out))
  507. dx = div_op(grad, cast(F.scalar_to_array(div_shape), dtype(grad)))
  508. return dx, zeros_like(axis)
  509. return bprop
  510. @bprop_getters.register(P.Equal)
  511. def get_bprop_equal(self):
  512. """Grad definition for `Equal` operation."""
  513. def bprop(x, y, out, dout):
  514. return zeros_like(x), zeros_like(y)
  515. return bprop
  516. @bprop_getters.register(P.NotEqual)
  517. def get_bprop_not_equal(self):
  518. """Grad definition for `NotEqual` operation."""
  519. def bprop(x, y, out, dout):
  520. return zeros_like(x), zeros_like(y)
  521. return bprop
  522. @bprop_getters.register(P.ApproximateEqual)
  523. def get_bprop_approximate_equal(self):
  524. """Grad definition for `ApproximateEqual` operation."""
  525. def bprop(x, y, out, dout):
  526. return zeros_like(x), zeros_like(y)
  527. return bprop
  528. @bprop_getters.register(P.Greater)
  529. def get_bprop_greater(self):
  530. """Grad definition for `Greater` operation."""
  531. def bprop(x, y, out, dout):
  532. return zeros_like(x), zeros_like(y)
  533. return bprop
  534. @bprop_getters.register(P.GreaterEqual)
  535. def get_bprop_greater_equal(self):
  536. """Grad definition for `GreaterEqual` operation."""
  537. def bprop(x, y, out, dout):
  538. return zeros_like(x), zeros_like(y)
  539. return bprop
  540. @bprop_getters.register(P.Less)
  541. def get_bprop_less(self):
  542. """Grad definition for `Less` operation."""
  543. def bprop(x, y, out, dout):
  544. return zeros_like(x), zeros_like(y)
  545. return bprop
  546. @bprop_getters.register(P.LessEqual)
  547. def get_bprop_less_equal(self):
  548. """Grad definition for `LessEqual` operation."""
  549. def bprop(x, y, out, dout):
  550. return zeros_like(x), zeros_like(y)
  551. return bprop
  552. @bprop_getters.register(P.LogicalNot)
  553. def get_bprop_logical_not(self):
  554. """Grad definition for `LogicalNot` operation."""
  555. def bprop(x, out, dout):
  556. return (zeros_like(x),)
  557. return bprop
  558. @bprop_getters.register(P.LogicalAnd)
  559. def get_bprop_logical_and(self):
  560. """Grad definition for `LogicalAnd` operation."""
  561. def bprop(x, y, out, dout):
  562. return zeros_like(x), zeros_like(y)
  563. return bprop
  564. @bprop_getters.register(P.LogicalOr)
  565. def get_bprop_logical_or(self):
  566. """Grad definition for `LogicalOr` operation."""
  567. def bprop(x, y, out, dout):
  568. return zeros_like(x), zeros_like(y)
  569. return bprop
  570. @bprop_getters.register(P.NPUAllocFloatStatus)
  571. def get_bprop_npu_alloc_float_status(self):
  572. """Grad definition for `NPUAllocFloatStatus` operation."""
  573. def bprop(out, dout):
  574. return ()
  575. return bprop
  576. @bprop_getters.register(P.NPUGetFloatStatus)
  577. def get_bprop_npu_get_float_status(self):
  578. """Grad definition for `NPUGetFloatStatus` operation."""
  579. def bprop(x, out, dout):
  580. return (zeros_like(x),)
  581. return bprop
  582. @bprop_getters.register(P.NPUClearFloatStatus)
  583. def get_bprop_npu_clear_float_status(self):
  584. """Grad definition for `NPUClearFloatStatus` operation."""
  585. def bprop(x, out, dout):
  586. return (zeros_like(x),)
  587. return bprop
  588. @bprop_getters.register(P.AssignAdd)
  589. def get_bprop_assign_add(self):
  590. """Grad definition for `AssignAdd` operation."""
  591. def bprop(x, y, out, dout):
  592. return zeros_like(x), zeros_like(y)
  593. return bprop
  594. @bprop_getters.register(P.AssignSub)
  595. def get_bprop_assign_sub(self):
  596. """Grad definition for `AssignSub` operation."""
  597. def bprop(x, y, out, dout):
  598. return zeros_like(x), zeros_like(y)
  599. return bprop
  600. @bprop_getters.register(P.Sin)
  601. def get_bprop_sin(self):
  602. """Grad definition for `Sin` operation."""
  603. cos = P.Cos()
  604. def bprop(x, out, dout):
  605. dx = dout * cos(x)
  606. return (dx,)
  607. return bprop
  608. @bprop_getters.register(P.Asin)
  609. def get_bprop_asin(self):
  610. """Grad definition for `Asin` operation."""
  611. input_grad = G.AsinGrad()
  612. def bprop(x, out, dout):
  613. dx = input_grad(x, dout)
  614. return (dx,)
  615. return bprop
  616. @bprop_getters.register(P.Asinh)
  617. def get_bprop_asinh(self):
  618. """Grad definition for `Asinh` operation."""
  619. input_grad = G.AsinhGrad()
  620. def bprop(x, out, dout):
  621. dx = input_grad(out, dout)
  622. return (dx,)
  623. return bprop
  624. @bprop_getters.register(P.Sinh)
  625. def get_bprop_sinh(self):
  626. """Grad definition for `Sinh` operation."""
  627. cosh = P.Cosh()
  628. def bprop(x, out, dout):
  629. dx = cosh(x) * dout
  630. return (dx,)
  631. return bprop
  632. @bprop_getters.register(P.Cos)
  633. def get_bprop_cos(self):
  634. """Grad definition for `Cos` operation."""
  635. sin = P.Sin()
  636. neg = P.Neg()
  637. def bprop(x, out, dout):
  638. dx = dout * neg(sin(x))
  639. return (dx,)
  640. return bprop
  641. @bprop_getters.register(P.ACos)
  642. def get_bprop_acos(self):
  643. """Grad definition for `ACos` operation."""
  644. input_grad = G.ACosGrad()
  645. def bprop(x, out, dout):
  646. dx = input_grad(x, dout)
  647. return (dx,)
  648. return bprop
  649. @bprop_getters.register(P.Acosh)
  650. def get_bprop_acosh(self):
  651. """Grad definition for `Acosh` operation."""
  652. input_grad = G.AcoshGrad()
  653. def bprop(x, out, dout):
  654. dx = input_grad(out, dout)
  655. return (dx,)
  656. return bprop
  657. @bprop_getters.register(P.Cosh)
  658. def get_bprop_cosh(self):
  659. """Grad definition for `Cosh` operation."""
  660. sinh = P.Sinh()
  661. def bprop(x, out, dout):
  662. dx = sinh(x) * dout
  663. return (dx,)
  664. return bprop
  665. @bprop_getters.register(P.Abs)
  666. def get_bprop_abs(self):
  667. """Grad definition for `Abs` operation."""
  668. abs_grad = SG.AbsGrad()
  669. def bprop(x, out, dout):
  670. dx = abs_grad(x, dout)
  671. return (dx,)
  672. return bprop
  673. @bprop_getters.register(P.ScalarCast)
  674. def get_bprop_scalar_cast(self):
  675. """Generate bprop for ScalarCast"""
  676. def bprop(x, t, out, dout):
  677. return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)
  678. return bprop
  679. @bprop_getters.register(P.AccumulateNV2)
  680. def get_bprop_scalar_accumulatenv2(self):
  681. """Generate bprop for AccumulateNV2"""
  682. def bprop(x, out, dout):
  683. dx = ()
  684. for _ in range(len(x)):
  685. dx = dx + (dout,)
  686. return dx
  687. return bprop
  688. @bprop_getters.register(P.AddN)
  689. def get_bprop_scalar_addn(self):
  690. """Generate bprop for AddN"""
  691. def bprop(x, out, dout):
  692. dx = ()
  693. for _ in range(len(x)):
  694. dx = dx + (dout,)
  695. return dx
  696. return bprop
  697. @bprop_getters.register(P.Sign)
  698. def get_bprop_sign(self):
  699. """Generate bprop for Sign"""
  700. def bprop(x, out, dout):
  701. return (zeros_like(x),)
  702. return bprop
  703. @bprop_getters.register(P.Round)
  704. def get_bprop_round(self):
  705. """Generate bprop for Round"""
  706. def bprop(x, out, dout):
  707. return (zeros_like(x),)
  708. return bprop
  709. @bprop_getters.register(P.Atan2)
  710. def get_bprop_atan2(self):
  711. """Generate bprop for Atan2"""
  712. square = P.Square()
  713. def bprop(x, y, out, dout):
  714. tmp = dout / (square(x) + square(y))
  715. bc_dx = tmp * y
  716. bc_dy = tmp * (-x)
  717. return binop_grad_common(x, y, bc_dx, bc_dy)
  718. return bprop
  719. @bprop_getters.register(P.BesselI0e)
  720. def get_bprop_bessel_i0e(self):
  721. """Generate bprop for BesselI0e"""
  722. sign = P.Sign()
  723. bessel_i1e = P.BesselI1e()
  724. def bprop(x, out, dout):
  725. dx = dout * (bessel_i1e(x) - sign(x) * out)
  726. return (dx,)
  727. return bprop
  728. @bprop_getters.register(P.Atan)
  729. def get_bprop_atan(self):
  730. """Grad definition for `Atan` operation."""
  731. input_grad = G.AtanGrad()
  732. def bprop(x, out, dout):
  733. dx = input_grad(x, dout)
  734. return (dx,)
  735. return bprop
  736. @bprop_getters.register(P.BesselI1e)
  737. def get_bprop_bessel_i1e(self):
  738. """Generate bprop for BesselI1e"""
  739. sign = P.Sign()
  740. bessel_i0e = P.BesselI0e()
  741. less = P.Less()
  742. select = P.Select()
  743. reciprocal = P.Reciprocal()
  744. cast = P.Cast()
  745. dtype = P.DType()
  746. abs_ops = P.Abs()
  747. def bprop(x, out, dout):
  748. zeros = zeros_like(x)
  749. np_eps = const_utils.get_np_eps(dtype(x))
  750. eps = cast(np_eps, dtype(x))
  751. x_is_valid = less(eps, abs_ops(x))
  752. x_safe = select(x_is_valid, x, eps + zeros)
  753. tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
  754. dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
  755. return (dx,)
  756. return bprop
  757. @bprop_getters.register(P.Atanh)
  758. def get_bprop_atanh(self):
  759. """Grad definition for `Atanh` operation."""
  760. power = P.Pow()
  761. div = P.Div()
  762. def bprop(x, out, dout):
  763. tmp = 1 - power(x, 2)
  764. dx = div(1, tmp) * dout
  765. return (dx,)
  766. return bprop
  767. @bprop_getters.register(P.Inv)
  768. def get_bprop_inv(self):
  769. """Grad definition for 'Inv' operation"""
  770. inv_grad = G.InvGrad()
  771. def bprop(x, out, dout):
  772. dx = inv_grad(out, dout)
  773. return (dx,)
  774. return bprop
  775. @bprop_getters.register(inner.LinSpace)
  776. def get_bprop_lin_space(self):
  777. """Grad definition for `LinSpace` operation."""
  778. def bprop(assist, start, stop, num, out, dout):
  779. return zeros_like(assist), zeros_like(start), zeros_like(stop), zeros_like(num)
  780. return bprop