You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

grad_math_ops.py 22 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the grad rules of math related operations."""
  16. from functools import reduce
  17. import numpy as np
  18. from .. import functional as F
  19. from .. import operations as P
  20. from ..operations import _grad_ops as G
  21. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  22. from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
  23. from .grad_base import bprop_getters
  24. from ..primitive import constexpr
# Primitive instances shared by the bprop rules below.
shape_op = P.Shape()
reduce_sum = P.ReduceSum()
reshape = P.Reshape()
tile = P.Tile()
  29. def binop_grad_common(x, y, dx, dy):
  30. """
  31. Common grad definition for binary operations.
  32. The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
  33. """
  34. shape_of_x = shape_op(x)
  35. shape_of_y = shape_op(y)
  36. rx = broadcast_gradient_args(shape_of_x, shape_of_y)
  37. # if input shape is the same as dout shape, do not need to reduce
  38. reduce_dx = dx
  39. reduce_dy = dy
  40. if rx[0]:
  41. # if dx is scalar whose shape is (), do not need reduce
  42. if shape_op(dx):
  43. dx = reduce_sum(dx, rx[0])
  44. reduce_dx = reshape(dx, shape_of_x)
  45. if rx[1]:
  46. # if dy is scalar whose shape is (), do not need reduce
  47. if shape_op(dy):
  48. dy = reduce_sum(dy, rx[1])
  49. reduce_dy = reshape(dy, shape_of_y)
  50. return reduce_dx, reduce_dy
  51. def _sum_grad(x, axis, dout):
  52. """Grad definition for `Sum` operation."""
  53. # input_shape = [2, 3] axis = [1]
  54. input_shape = shape_op(x)
  55. # output_shape_kept_dims = [2, 1]
  56. output_shape_kept_dims = reduced_shape(input_shape, axis)
  57. # tile_scaling = [1, 3]
  58. tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
  59. grad = reshape(dout, output_shape_kept_dims)
  60. return tile(grad, tile_scaling)
  61. def _min_or_max_grad(x, axis, out, dout):
  62. """Grad definition for `Min` and `Max` operations."""
  63. # input_shape = [2, 3] axis = [1]
  64. input_shape = shape_op(x)
  65. # output_shape_kept_dims = [2, 1]
  66. output_shape_kept_dims = reduced_shape(input_shape, axis)
  67. y = reshape(out, output_shape_kept_dims)
  68. grad = reshape(dout, output_shape_kept_dims)
  69. indicators = F.cast(F.equal(y, x), F.dtype(grad))
  70. min_num = F.cast(F.scalar_to_array(1e-24), F.dtype(grad))
  71. num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
  72. return indicators / num_selected * grad
def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
    """ArgMinWithValue and ArgMaxWithValue grad: scatter the value-gradient dout[1] onto the winning indices."""
    expand = P.ExpandDims()
    x_shape = F.shape(x)
    x_dim = len(x_shape)
    x_axis = axis
    # Normalize a negative axis to its non-negative equivalent.
    if x_axis < 0:
        x_axis = axis + x_dim
    onehot_axis = x_axis
    depth = x_shape[x_axis]
    if keep_dims:
        dout_expand = dout[1]
        # Recompute with `op` (built without keep_dims by the callers) so out[0]
        # has the reduced shape expected by the one-hot below — TODO confirm
        # the op's keep_dims default against P.ArgMaxWithValue/ArgMinWithValue.
        out = op(x)
    else:
        dout_expand = expand(dout[1], onehot_axis)
    # If the axis falls past the reduced output's rank, one-hot on the last axis.
    if onehot_axis >= len(shape_op(out[0])):
        onehot_axis = -1
    onehot = P.OneHot(onehot_axis)
    type_x = F.dtype(x)
    on_value = F.cast(F.scalar_to_array(1.0), type_x)
    off_value = F.cast(F.scalar_to_array(0.0), type_x)
    # Only the positions selected by the arg-op receive gradient.
    dx = dout_expand * onehot(out[0], depth, on_value, off_value)
    return dx
  96. @bprop_getters.register(P.MatMul)
  97. def bprop_matmul(self):
  98. """Grad definition for `MatMul` operation."""
  99. ta = self.transpose_a
  100. tb = self.transpose_b
  101. mul1 = P.MatMul(transpose_a=(ta and tb),
  102. transpose_b=(ta or (not tb)))
  103. mul2 = P.MatMul(transpose_a=((not ta) or tb),
  104. transpose_b=(ta and tb))
  105. def bprop(x, w, out, dout):
  106. if ta:
  107. dx = mul1(w, dout)
  108. else:
  109. dx = mul1(dout, w)
  110. if tb:
  111. dw = mul2(dout, x)
  112. else:
  113. dw = mul2(x, dout)
  114. return dx, dw
  115. return bprop
  116. @bprop_getters.register(P.BatchMatMul)
  117. def bprop_batchmatmul(self):
  118. """Grad definition for `BatchMatMul` operation."""
  119. ta = self.transpose_a
  120. tb = self.transpose_b
  121. mul1 = P.BatchMatMul(transpose_a=(ta and tb),
  122. transpose_b=(ta or (not tb)))
  123. mul2 = P.BatchMatMul(transpose_a=((not ta) or tb),
  124. transpose_b=(ta and tb))
  125. def bprop(x, w, out, dout):
  126. if ta:
  127. dx = mul1(w, dout)
  128. else:
  129. dx = mul1(dout, w)
  130. if tb:
  131. dw = mul2(dout, x)
  132. else:
  133. dw = mul2(x, dout)
  134. return dx, dw
  135. return bprop
  136. @bprop_getters.register(P.TensorAdd)
  137. def get_bprop_tensor_add(self):
  138. """Grad definition for `TensorAdd` operation."""
  139. def bprop(x, y, out, dout):
  140. return binop_grad_common(x, y, dout, dout)
  141. return bprop
  142. @bprop_getters.register(P.Neg)
  143. def get_bprop_neg(self):
  144. """Grad definition for `Neg` operation."""
  145. neg_grad = P.Neg()
  146. def bprop(x, out, dout):
  147. dx = neg_grad(dout)
  148. return (dx,)
  149. return bprop
  150. @bprop_getters.register(P.Sub)
  151. def get_bprop_sub(self):
  152. """Grad definition for `Sub` operation."""
  153. neg_func = P.Neg()
  154. def bprop(x, y, out, dout):
  155. return binop_grad_common(x, y, dout, neg_func(dout))
  156. return bprop
  157. @bprop_getters.register(P.Mul)
  158. def get_bprop_mul(self):
  159. """Grad definition for `Mul` operation."""
  160. mul_func = P.Mul()
  161. def bprop(x, y, out, dout):
  162. bc_dx = mul_func(dout, y)
  163. bc_dy = mul_func(dout, x)
  164. return binop_grad_common(x, y, bc_dx, bc_dy)
  165. return bprop
  166. @bprop_getters.register(P.RealDiv)
  167. def get_bprop_real_div(self):
  168. """Grad definition for `RealDiv` operation."""
  169. div_op = P.RealDiv()
  170. neg = P.Neg()
  171. mul_op = P.Mul()
  172. def bprop(x, y, out, dout):
  173. bc_x = div_op(dout, y)
  174. bc_y = neg(mul_op(bc_x, out))
  175. return binop_grad_common(x, y, bc_x, bc_y)
  176. return bprop
  177. @bprop_getters.register(P.Div)
  178. def get_bprop_div(self):
  179. """Grad definition for `Div` operation."""
  180. div_op = P.Div()
  181. neg = P.Neg()
  182. mul_op = P.Mul()
  183. def bprop(x, y, out, dout):
  184. bc_x = div_op(dout, y)
  185. bc_y = neg(mul_op(bc_x, out))
  186. return binop_grad_common(x, y, bc_x, bc_y)
  187. return bprop
  188. @bprop_getters.register(P.Floor)
  189. def get_bprop_floor(self):
  190. """Grad definition for `floor` operation."""
  191. fill_ = P.Fill()
  192. shape_ = P.Shape()
  193. dtype_ = P.DType()
  194. def bprop(x, out, dout):
  195. bc_x = fill_(dtype_(x), shape_(x), 0.)
  196. return (bc_x,)
  197. return bprop
  198. @bprop_getters.register(P.FloorDiv)
  199. def get_bprop_floordiv(self):
  200. """Grad definition for `FloorDiv` operation."""
  201. div_op = P.FloorDiv()
  202. neg = P.Neg()
  203. mul_op = P.Mul()
  204. def bprop(x, y, out, dout):
  205. bc_x = div_op(dout, y)
  206. bc_y = neg(mul_op(bc_x, out))
  207. return binop_grad_common(x, y, bc_x, bc_y)
  208. return bprop
  209. @bprop_getters.register(P.FloorMod)
  210. def get_bprop_floormod(self):
  211. """Grad definition for `FloorMod` operation."""
  212. div_op = P.FloorMod()
  213. neg = P.Neg()
  214. mul_op = P.Mul()
  215. def bprop(x, y, out, dout):
  216. bc_x = div_op(dout, y)
  217. bc_y = neg(mul_op(bc_x, out))
  218. return binop_grad_common(x, y, bc_x, bc_y)
  219. return bprop
  220. @bprop_getters.register(P.Square)
  221. def get_bprop_square(self):
  222. """Grad definition for `Square` operation."""
  223. mul_func = P.Mul()
  224. fill_func = P.Fill()
  225. dtype = P.DType()
  226. def bprop(x, out, dout):
  227. temp = mul_func(dout, x)
  228. dx = mul_func(fill_func(dtype(temp), shape_op(x), 2.0), temp)
  229. return (dx,)
  230. return bprop
  231. @bprop_getters.register(P.Sqrt)
  232. def get_bprop_sqrt(self):
  233. """Grad definition for `Sqrt` operation."""
  234. mul_func = P.Mul()
  235. fill_func = P.Fill()
  236. div_op = P.RealDiv()
  237. sqrt = P.Sqrt()
  238. dtype = P.DType()
  239. def bprop(x, out, dout):
  240. temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x))
  241. dx = mul_func(dout, temp)
  242. return (dx,)
  243. return bprop
  244. @bprop_getters.register(P.Rsqrt)
  245. def get_bprop_rsqrt(self):
  246. """Grad definition for `Rsqrt` operation."""
  247. def bprop(x, out, dout):
  248. grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x)*x)
  249. dx = dout * grad
  250. return (dx,)
  251. return bprop
  252. @bprop_getters.register(P.Reciprocal)
  253. def get_bprop_reciprocal(self):
  254. """Grad definition for `Reciprocal` operation."""
  255. neg = P.Neg()
  256. mul = P.Mul()
  257. square = P.Square()
  258. reciprocal = P.Reciprocal()
  259. def bprop(x, out, dout):
  260. g = neg(reciprocal(square(x)))
  261. dx = mul(dout, g)
  262. return (dx,)
  263. return bprop
  264. @bprop_getters.register(P.Log)
  265. def get_bprop_log(self):
  266. """Grad definition for `Log` operation."""
  267. reciprocal = P.Reciprocal()
  268. def bprop(x, out, dout):
  269. g = reciprocal(x)
  270. dx = g * dout
  271. return dx, 0
  272. return bprop
  273. @bprop_getters.register(P.Erf)
  274. def get_bprop_erf(self):
  275. """Grad definition for `Erf` operation."""
  276. exp = P.Exp()
  277. square = P.Square()
  278. sqrt = P.Sqrt()
  279. cast = P.Cast()
  280. dtype = P.DType()
  281. def bprop(x, out, dout):
  282. half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
  283. x_square = square(x)
  284. dx = dout * half_root_pi * exp(-x_square)
  285. return (dx,)
  286. return bprop
  287. @bprop_getters.register(P.Pow)
  288. def get_bprop_pow(self):
  289. """Grad definition for `Pow` operation."""
  290. pow_op = P.Pow()
  291. ln = P.Log()
  292. def bprop(x, power, out, dout):
  293. bc_dx = power * pow_op(x, power - 1.0) * dout
  294. bc_dpower = out * ln(x) * dout
  295. return binop_grad_common(x, power, bc_dx, bc_dpower)
  296. return bprop
  297. @bprop_getters.register(P.Exp)
  298. def get_bprop_exp(self):
  299. """Grad definition for `Exp` operation."""
  300. exp_ = P.Exp()
  301. def bprop(x, out, dout):
  302. g = exp_(x)
  303. dx = g * dout
  304. return (dx,)
  305. return bprop
  306. @bprop_getters.register(P.Minimum)
  307. def get_bprop_minimum(self):
  308. """Grad definition for `Minimum` operation."""
  309. input_grad = G.MinimumGrad()
  310. def bprop(x, y, out, dout):
  311. dx, dy = input_grad(x, y, dout)
  312. return dx, dy
  313. return bprop
  314. @bprop_getters.register(P.Maximum)
  315. def get_bprop_maximum(self):
  316. """Grad definition for `Maximum` operation."""
  317. input_grad = G.MaximumGrad()
  318. def bprop(x, y, out, dout):
  319. dx, dy = input_grad(x, y, dout)
  320. return dx, dy
  321. return bprop
  322. @bprop_getters.register(P.ReduceSum)
  323. def get_bprop_reducesum(self):
  324. """Grad definition for `ReduceSum` operation."""
  325. def bprop(x, axis, out, dout):
  326. dx = _sum_grad(x, axis, dout)
  327. return dx, zeros_like(axis)
  328. return bprop
  329. @bprop_getters.register(P.CumSum)
  330. def get_bprop_cumsum(self):
  331. """Grad definition for `CumSum` operation."""
  332. cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
  333. def bprop(x, axis, out, dout):
  334. return cumsum(dout, axis), zeros_like(axis)
  335. return bprop
  336. @constexpr
  337. def _split_shape_index(input_shape, axis):
  338. """Calculate reduce_prod grad transpose indices and perm shape."""
  339. rank = len(input_shape)
  340. if isinstance(axis, int):
  341. axis = tuple([axis])
  342. reduction_indices = tuple([(i + rank) % rank for i in axis])
  343. other_indices = tuple(set(range(rank)) - set(reduction_indices))
  344. reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
  345. other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
  346. perm = reduction_indices + other_indices
  347. return tuple([reduced_num, other_num]), perm
  348. @constexpr
  349. def _invert_permutation(perm):
  350. """Calculate invert permutation."""
  351. out = [0] * len(perm)
  352. for i, value in enumerate(perm):
  353. out[value] = i
  354. return tuple(out)
@bprop_getters.register(P.ReduceProd)
def get_bprop_reduceprod(self):
    """Grad definition for `ReduceProd` operation."""
    transpose = P.Transpose()
    # Exclusive cumulative products from each side, used to compute the
    # product of all other entries along the packed reduction axis.
    left_cumprod = P.CumProd(exclusive=True)
    right_cumprod = P.CumProd(exclusive=True, reverse=True)
    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # Expand dout to full input shape
        input_shape = shape_op(x)
        output_shape_kept_dims = reduced_shape(input_shape, axis)
        dout = reshape(dout, output_shape_kept_dims)
        tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
        grad = tile(dout, tile_scaling)
        # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
        pack_shape, perm = _split_shape_index(input_shape, axis)
        permuted = transpose(x, perm)
        permuted_shape = shape_op(permuted)
        reshaped = reshape(permuted, pack_shape)
        # Calculate product, leaving out the current entry
        # (left[i] * right[i] == product of all entries except i).
        left = left_cumprod(reshaped, 0)
        right = right_cumprod(reshaped, 0)
        y = reshape(left * right, permuted_shape)
        # Invert the transpose and reshape operations.
        # Make sure to set the statically known shape information through a reshape.
        out = transpose(y, _invert_permutation(perm)) * grad
        dx = reshape(out, input_shape)
        return dx, zeros_like(axis)
    return bprop
@bprop_getters.register(P.CumProd)
def get_bprop_cumprod(self):
    """Grad definition for `CumProd` operation."""
    cumprod = P.CumProd(exclusive=self.exclusive, reverse=self.reverse)
    # Adjoint accumulates in the opposite direction of the forward pass.
    cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # NOTE: this will fail when x contains 0 (division by x below).
        prod = cumprod(x, axis)
        out = cumsum(prod * dout, axis)
        return out / x, zeros_like(axis)
    return bprop
  396. @bprop_getters.register(P.ReduceAll)
  397. def get_bprop_reduceall(self):
  398. """Grad definition for `ReduceAll` operation."""
  399. def bprop(x, axis, out, dout):
  400. return zeros_like(x), zeros_like(axis)
  401. return bprop
  402. @bprop_getters.register(P.ReduceMax)
  403. def get_bprop_reducemax(self):
  404. """Grad definition for `Max` operation."""
  405. def bprop(x, axis, out, dout):
  406. dx = _min_or_max_grad(x, axis, out, dout)
  407. return (dx, zeros_like(axis))
  408. return bprop
  409. @bprop_getters.register(P.ArgMaxWithValue)
  410. def get_bprop_argmaxwithvalue(self):
  411. """Grad definition for `ArgMaxWithValue` operation."""
  412. axis = self.axis
  413. keep_dims = self.keep_dims
  414. op = P.ArgMaxWithValue(axis)
  415. def bprop(x, out, dout):
  416. dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
  417. return (dx,)
  418. return bprop
  419. @bprop_getters.register(P.ReduceMin)
  420. def get_bprop_reducemin(self):
  421. """Grad definition for `ReduceMin` operation."""
  422. def bprop(x, axis, out, dout):
  423. dx = _min_or_max_grad(x, axis, out, dout)
  424. return (dx, zeros_like(axis))
  425. return bprop
  426. @bprop_getters.register(P.ArgMinWithValue)
  427. def get_bprop_argminwithvalue(self):
  428. """Generate bprop for ArgMinWithValue"""
  429. axis = self.axis
  430. keep_dims = self.keep_dims
  431. op = P.ArgMinWithValue(axis)
  432. def bprop(x, out, dout):
  433. dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
  434. return (dx,)
  435. return bprop
  436. @bprop_getters.register(P.ReduceMean)
  437. def get_bprop_reduce_mean(self):
  438. """Grad definition for `ReduceMean` operation."""
  439. div_op = P.RealDiv()
  440. cast = P.Cast()
  441. dtype = P.DType()
  442. def bprop(x, axis, out, dout):
  443. grad = _sum_grad(x, axis, dout)
  444. div_shape = F.shape_mul(shape_op(x)) / F.shape_mul(shape_op(out))
  445. dx = div_op(grad, cast(F.scalar_to_array(div_shape), dtype(grad)))
  446. return dx, zeros_like(axis)
  447. return bprop
  448. @bprop_getters.register(P.Equal)
  449. def get_bprop_equal(self):
  450. """Grad definition for `Equal` operation."""
  451. def bprop(x, y, out, dout):
  452. return zeros_like(x), zeros_like(y)
  453. return bprop
  454. @bprop_getters.register(P.NotEqual)
  455. def get_bprop_not_equal(self):
  456. """Grad definition for `NotEqual` operation."""
  457. def bprop(x, y, out, dout):
  458. return zeros_like(x), zeros_like(y)
  459. return bprop
  460. @bprop_getters.register(P.Greater)
  461. def get_bprop_greater(self):
  462. """Grad definition for `Greater` operation."""
  463. def bprop(x, y, out, dout):
  464. return zeros_like(x), zeros_like(y)
  465. return bprop
  466. @bprop_getters.register(P.GreaterEqual)
  467. def get_bprop_greater_equal(self):
  468. """Grad definition for `GreaterEqual` operation."""
  469. def bprop(x, y, out, dout):
  470. return zeros_like(x), zeros_like(y)
  471. return bprop
  472. @bprop_getters.register(P.Less)
  473. def get_bprop_less(self):
  474. """Grad definition for `Less` operation."""
  475. def bprop(x, y, out, dout):
  476. return zeros_like(x), zeros_like(y)
  477. return bprop
  478. @bprop_getters.register(P.LessEqual)
  479. def get_bprop_less_equal(self):
  480. """Grad definition for `LessEqual` operation."""
  481. def bprop(x, y, out, dout):
  482. return zeros_like(x), zeros_like(y)
  483. return bprop
  484. @bprop_getters.register(P.LogicalNot)
  485. def get_bprop_logical_not(self):
  486. """Grad definition for `LogicalNot` operation."""
  487. def bprop(x, out, dout):
  488. return (zeros_like(x),)
  489. return bprop
  490. @bprop_getters.register(P.LogicalAnd)
  491. def get_bprop_logical_and(self):
  492. """Grad definition for `LogicalAnd` operation."""
  493. def bprop(x, y, out, dout):
  494. return zeros_like(x), zeros_like(y)
  495. return bprop
  496. @bprop_getters.register(P.LogicalOr)
  497. def get_bprop_logical_or(self):
  498. """Grad definition for `LogicalOr` operation."""
  499. def bprop(x, y, out, dout):
  500. return zeros_like(x), zeros_like(y)
  501. return bprop
  502. @bprop_getters.register(P.NPUAllocFloatStatus)
  503. def get_bprop_npu_alloc_float_status(self):
  504. """Grad definition for `NPUAllocFloatStatus` operation."""
  505. def bprop(out, dout):
  506. return ()
  507. return bprop
  508. @bprop_getters.register(P.NPUGetFloatStatus)
  509. def get_bprop_npu_get_float_status(self):
  510. """Grad definition for `NPUGetFloatStatus` operation."""
  511. def bprop(x, out, dout):
  512. return (zeros_like(x),)
  513. return bprop
  514. @bprop_getters.register(P.NPUClearFloatStatus)
  515. def get_bprop_npu_clear_float_status(self):
  516. """Grad definition for `NPUClearFloatStatus` operation."""
  517. def bprop(x, out, dout):
  518. return (zeros_like(x),)
  519. return bprop
  520. @bprop_getters.register(P.AssignAdd)
  521. def get_bprop_assign_add(self):
  522. """Grad definition for `AssignAdd` operation."""
  523. def bprop(x, y, out, dout):
  524. return zeros_like(x), zeros_like(y)
  525. return bprop
  526. @bprop_getters.register(P.AssignSub)
  527. def get_bprop_assign_sub(self):
  528. """Grad definition for `AssignSub` operation."""
  529. def bprop(x, y, out, dout):
  530. return zeros_like(x), zeros_like(y)
  531. return bprop
  532. @bprop_getters.register(P.Sin)
  533. def get_bprop_sin(self):
  534. """Grad definition for `Sin` operation."""
  535. cos = P.Cos()
  536. def bprop(x, out, dout):
  537. dx = dout*cos(x)
  538. return (dx,)
  539. return bprop
  540. @bprop_getters.register(P.Cos)
  541. def get_bprop_cos(self):
  542. """Grad definition for `Cos` operation."""
  543. sin = P.Sin()
  544. neg = P.Neg()
  545. def bprop(x, out, dout):
  546. dx = dout*neg(sin(x))
  547. return (dx,)
  548. return bprop
  549. @bprop_getters.register(P.ACos)
  550. def get_bprop_acos(self):
  551. """Grad definition for `ACos` operation."""
  552. input_grad = G.ACosGrad()
  553. def bprop(x, out, dout):
  554. dx = input_grad(x, dout)
  555. return (dx,)
  556. return bprop
  557. @bprop_getters.register(P.Acosh)
  558. def get_bprop_acosh(self):
  559. """Grad definition for `Acosh` operation."""
  560. input_grad = G.AcoshGrad()
  561. def bprop(x, out, dout):
  562. dx = input_grad(x, dout)
  563. return (dx,)
  564. return bprop
  565. @bprop_getters.register(P.Abs)
  566. def get_bprop_abs(self):
  567. """Grad definition for `Abs` operation."""
  568. abs_grad = G.AbsGrad()
  569. def bprop(x, out, dout):
  570. dx = abs_grad(x, dout)
  571. return (dx,)
  572. return bprop
  573. @bprop_getters.register(P.ScalarCast)
  574. def get_bprop_scalar_cast(self):
  575. """Generate bprop for ScalarCast"""
  576. def bprop(x, t, out, dout):
  577. return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)
  578. return bprop
  579. @bprop_getters.register(P.AddN)
  580. def get_bprop_scalar_addn(self):
  581. """Generate bprop for AddN"""
  582. def bprop(x, out, dout):
  583. dx = ()
  584. for _ in range(len(x)):
  585. dx = dx + (dout,)
  586. return dx
  587. return bprop
  588. @bprop_getters.register(P.Sign)
  589. def get_bprop_sign(self):
  590. """Generate bprop for Sign"""
  591. def bprop(x, out, dout):
  592. return (zeros_like(x),)
  593. return bprop
  594. @bprop_getters.register(P.Round)
  595. def get_bprop_round(self):
  596. """Generate bprop for Round"""
  597. def bprop(x, out, dout):
  598. return (zeros_like(x),)
  599. return bprop
  600. @bprop_getters.register(P.Atan2)
  601. def get_bprop_atan2(self):
  602. """Generate bprop for Atan2"""
  603. square = P.Square()
  604. def bprop(x, y, out, dout):
  605. tmp = dout / (square(x) + square(y))
  606. dx = tmp * y
  607. dy = tmp * (-x)
  608. return (dx, dy)
  609. return bprop