You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

grad_math_ops.py 32 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the grad rules of math related operations."""
  16. from functools import reduce
  17. import numpy as np
  18. import mindspore as ms
  19. from mindspore.ops import _selected_grad_ops as SG
  20. from .. import functional as F
  21. from .. import operations as P
  22. from ..operations import _grad_ops as G
  23. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  24. from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
  25. from .grad_base import bprop_getters
  26. from ..primitive import constexpr
  27. from ..composite.multitype_ops import _constexpr_utils as const_utils
# Module-level primitive instances. `shape_op`, `reduce_sum`, `reshape` and
# `tile` are shared by the gradient helpers and bprop rules in this file.
shape_op = P.Shape()
reduce_sum = P.ReduceSum()
reshape = P.Reshape()
tile = P.Tile()
is_sub_class = P.IsSubClass()
  33. def binop_grad_common(x, y, dx, dy):
  34. """
  35. Common grad definition for binary operations.
  36. The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
  37. """
  38. shape_of_x = shape_op(x)
  39. shape_of_y = shape_op(y)
  40. rx = broadcast_gradient_args(shape_of_x, shape_of_y)
  41. # if input shape is the same as dout shape, do not need to reduce
  42. reduce_dx = dx
  43. reduce_dy = dy
  44. if rx[0]:
  45. # if dx is scalar whose shape is (), do not need reduce
  46. if shape_op(dx):
  47. dx = reduce_sum(dx, rx[0])
  48. reduce_dx = reshape(dx, shape_of_x)
  49. if rx[1]:
  50. # if dy is scalar whose shape is (), do not need reduce
  51. if shape_op(dy):
  52. dy = reduce_sum(dy, rx[1])
  53. reduce_dy = reshape(dy, shape_of_y)
  54. return reduce_dx, reduce_dy
  55. def _sum_grad(x, axis, dout):
  56. """Grad definition for `Sum` operation."""
  57. input_shape = shape_op(x)
  58. output_shape_kept_dims = reduced_shape(input_shape, axis)
  59. tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
  60. grad = reshape(dout, output_shape_kept_dims)
  61. return tile(grad, tile_scaling)
  62. def _min_or_max_grad(x, axis, out, dout):
  63. """Grad definition for `Min` and `Max` operations."""
  64. input_shape = shape_op(x)
  65. output_shape_kept_dims = reduced_shape(input_shape, axis)
  66. y = reshape(out, output_shape_kept_dims)
  67. grad = reshape(dout, output_shape_kept_dims)
  68. indicators = F.cast(F.equal(y, x), F.dtype(grad))
  69. min_num = F.cast(F.scalar_to_array(1e-24), F.dtype(grad))
  70. num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
  71. return indicators / num_selected * grad
  72. def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
  73. """ArgMinWiwhValue and ArgMaxWithValue grad."""
  74. expand = P.ExpandDims()
  75. x_shape = F.shape(x)
  76. x_dim = len(x_shape)
  77. x_axis = axis
  78. if x_axis < 0:
  79. x_axis = axis + x_dim
  80. onehot_axis = x_axis
  81. depth = x_shape[x_axis]
  82. if keep_dims:
  83. dout_expand = dout[1]
  84. out = op(x)
  85. else:
  86. dout_expand = expand(dout[1], onehot_axis)
  87. if onehot_axis >= len(shape_op(out[0])):
  88. onehot_axis = -1
  89. onehot = P.OneHot(onehot_axis)
  90. type_x = F.dtype(x)
  91. on_value = F.cast(F.scalar_to_array(1.0), type_x)
  92. off_value = F.cast(F.scalar_to_array(0.0), type_x)
  93. dx = dout_expand * onehot(out[0], depth, on_value, off_value)
  94. return dx
  95. @bprop_getters.register(P.MatMul)
  96. def bprop_matmul(self):
  97. """Grad definition for `MatMul` operation."""
  98. ta = self.transpose_a
  99. tb = self.transpose_b
  100. mul1 = P.MatMul(transpose_a=(ta and tb),
  101. transpose_b=(ta or (not tb)))
  102. mul2 = P.MatMul(transpose_a=((not ta) or tb),
  103. transpose_b=(ta and tb))
  104. def bprop(x, w, out, dout):
  105. if ta:
  106. dx = mul1(w, dout)
  107. else:
  108. dx = mul1(dout, w)
  109. if tb:
  110. dw = mul2(dout, x)
  111. else:
  112. dw = mul2(x, dout)
  113. return dx, dw
  114. return bprop
  115. @bprop_getters.register(P.BatchMatMul)
  116. def bprop_batchmatmul(self):
  117. """Grad definition for `BatchMatMul` operation."""
  118. ta = self.transpose_a
  119. tb = self.transpose_b
  120. mul1 = P.BatchMatMul(transpose_a=(ta and tb),
  121. transpose_b=(ta or (not tb)))
  122. mul2 = P.BatchMatMul(transpose_a=((not ta) or tb),
  123. transpose_b=(ta and tb))
  124. def bprop(x, w, out, dout):
  125. if ta:
  126. dx = mul1(w, dout)
  127. else:
  128. dx = mul1(dout, w)
  129. if tb:
  130. dw = mul2(dout, x)
  131. else:
  132. dw = mul2(x, dout)
  133. return dx, dw
  134. return bprop
  135. @bprop_getters.register(P.Add)
  136. def get_bprop_add(self):
  137. """Grad definition for `Add` operation."""
  138. def bprop(x, y, out, dout):
  139. return binop_grad_common(x, y, dout, dout)
  140. return bprop
  141. @bprop_getters.register(P.TensorAdd)
  142. def get_bprop_tensor_add(self):
  143. """Grad definition for `Add` operation."""
  144. def bprop(x, y, out, dout):
  145. return binop_grad_common(x, y, dout, dout)
  146. return bprop
  147. @bprop_getters.register(P.MatrixInverse)
  148. def get_bprop_matrix_inverse(self):
  149. """Grad definition for `MatrixInverse` operation."""
  150. batchmatmul_a = P.math_ops.BatchMatMul(transpose_a=True)
  151. batchmatmul_b = P.math_ops.BatchMatMul(transpose_b=True)
  152. neg = P.Neg()
  153. def bprop(x, out, dout):
  154. dx = batchmatmul_b(dout, out)
  155. dx = batchmatmul_a(out, dx)
  156. dx = neg(dx)
  157. return dx
  158. return bprop
  159. @bprop_getters.register(P.Neg)
  160. def get_bprop_neg(self):
  161. """Grad definition for `Neg` operation."""
  162. neg_grad = P.Neg()
  163. def bprop(x, out, dout):
  164. dx = neg_grad(dout)
  165. return (dx,)
  166. return bprop
  167. @bprop_getters.register(P.Sub)
  168. def get_bprop_sub(self):
  169. """Grad definition for `Sub` operation."""
  170. neg_func = P.Neg()
  171. def bprop(x, y, out, dout):
  172. return binop_grad_common(x, y, dout, neg_func(dout))
  173. return bprop
  174. @bprop_getters.register(P.Mul)
  175. def get_bprop_mul(self):
  176. """Grad definition for `Mul` operation."""
  177. mul_func = P.Mul()
  178. def bprop(x, y, out, dout):
  179. bc_dx = mul_func(y, dout)
  180. bc_dy = mul_func(x, dout)
  181. return binop_grad_common(x, y, bc_dx, bc_dy)
  182. return bprop
  183. @bprop_getters.register(P.RealDiv)
  184. def get_bprop_real_div(self):
  185. """Grad definition for `RealDiv` operation."""
  186. div_op = P.RealDiv()
  187. neg = P.Neg()
  188. mul_op = P.Mul()
  189. def bprop(x, y, out, dout):
  190. bc_x = div_op(dout, y)
  191. bc_y = neg(mul_op(bc_x, out))
  192. return binop_grad_common(x, y, bc_x, bc_y)
  193. return bprop
  194. @bprop_getters.register(P.Div)
  195. def get_bprop_div(self):
  196. """Grad definition for `Div` operation."""
  197. div_op = P.Div()
  198. neg = P.Neg()
  199. mul_op = P.Mul()
  200. def bprop(x, y, out, dout):
  201. bc_x = div_op(dout, y)
  202. bc_y = neg(mul_op(bc_x, out))
  203. return binop_grad_common(x, y, bc_x, bc_y)
  204. return bprop
  205. @bprop_getters.register(P.DivNoNan)
  206. def get_bprop_div_no_nan(self):
  207. """Grad definition for `DivNoNan` operation."""
  208. div_no_nan_op = P.DivNoNan()
  209. neg = P.Neg()
  210. mul_op = P.Mul()
  211. def bprop(x, y, out, dout):
  212. bc_x = div_no_nan_op(dout, y)
  213. bc_y = neg(mul_op(bc_x, out))
  214. return binop_grad_common(x, y, bc_x, bc_y)
  215. return bprop
  216. @bprop_getters.register(P.Xdivy)
  217. def get_bprop_xdivy(self):
  218. """Grad definition for `Xdivy` operation."""
  219. div_op = P.Xdivy()
  220. def bprop(x, y, out, dout):
  221. x_dtype = F.dtype(x)
  222. not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
  223. bc_x = div_op(not_zero_x, y) * dout
  224. bc_y = div_op(-x, F.square(y)) * dout
  225. return binop_grad_common(x, y, bc_x, bc_y)
  226. return bprop
  227. @bprop_getters.register(P.Floor)
  228. def get_bprop_floor(self):
  229. """Grad definition for `floor` operation."""
  230. fill_ = P.Fill()
  231. shape_ = P.Shape()
  232. dtype_ = P.DType()
  233. def bprop(x, out, dout):
  234. bc_x = fill_(dtype_(x), shape_(x), 0.)
  235. return (bc_x,)
  236. return bprop
  237. @bprop_getters.register(P.Ceil)
  238. def get_bprop_ceil(self):
  239. """Grad definition for `ceil` operation."""
  240. fill_ = P.Fill()
  241. shape_ = P.Shape()
  242. dtype_ = P.DType()
  243. def bprop(x, out, dout):
  244. bc_x = fill_(dtype_(x), shape_(x), 0.)
  245. return (bc_x,)
  246. return bprop
  247. @bprop_getters.register(P.FloorDiv)
  248. def get_bprop_floordiv(self):
  249. """Grad definition for `FloorDiv` operation."""
  250. def bprop(x, y, out, dout):
  251. return zeros_like(x), zeros_like(y)
  252. return bprop
  253. @bprop_getters.register(P.FloorMod)
  254. def get_bprop_floormod(self):
  255. """Grad definition for `FloorMod` operation."""
  256. def bprop(x, y, out, dout):
  257. bc_x = dout
  258. bc_y = -dout * (x // y)
  259. return binop_grad_common(x, y, bc_x, bc_y)
  260. return bprop
  261. @bprop_getters.register(P.TruncateDiv)
  262. def get_bprop_truncate_div(self):
  263. """Grad definition for `TruncateDiv` operation."""
  264. def bprop(x, y, out, dout):
  265. return zeros_like(x), zeros_like(y)
  266. return bprop
  267. @bprop_getters.register(P.TruncateMod)
  268. def get_bprop_truncate_mod(self):
  269. """Grad definition for `TruncateMod` operation."""
  270. div_op = P.TruncateDiv()
  271. def bprop(x, y, out, dout):
  272. bc_x = dout
  273. bc_y = -dout * div_op(x, y)
  274. return binop_grad_common(x, y, bc_x, bc_y)
  275. return bprop
  276. @bprop_getters.register(P.Mod)
  277. def get_bprop_mod(self):
  278. """Grad definition for `Mod` operation."""
  279. def bprop(x, y, out, dout):
  280. bc_x = dout
  281. bc_y = -dout * (x // y)
  282. return binop_grad_common(x, y, bc_x, bc_y)
  283. return bprop
  284. @bprop_getters.register(P.Square)
  285. def get_bprop_square(self):
  286. """Grad definition for `Square` operation."""
  287. mul_func = P.Mul()
  288. fill_func = P.Fill()
  289. dtype = P.DType()
  290. def bprop(x, out, dout):
  291. temp = mul_func(dout, x)
  292. dx = mul_func(fill_func(dtype(temp), shape_op(x), 2.0), temp)
  293. return (dx,)
  294. return bprop
  295. @bprop_getters.register(P.SquaredDifference)
  296. def get_bprop_squared_difference(self):
  297. """Grad definition for `SquaredDifference` operation."""
  298. neg = P.Neg()
  299. def bprop(x, y, out, dout):
  300. x_grad = 2 * dout * (x - y)
  301. bc_x = x_grad
  302. bc_y = neg(x_grad)
  303. return binop_grad_common(x, y, bc_x, bc_y)
  304. return bprop
  305. @bprop_getters.register(P.Xlogy)
  306. def get_bprop_xlogy(self):
  307. """Grad definition for `Xlogy` operation."""
  308. log_op = P.Xlogy()
  309. div_op = P.Xdivy()
  310. def bprop(x, y, out, dout):
  311. x_dtype = F.dtype(x)
  312. not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
  313. bc_x = log_op(not_zero_x, y) * dout
  314. bc_y = div_op(x, y) * dout
  315. return binop_grad_common(x, y, bc_x, bc_y)
  316. return bprop
  317. @bprop_getters.register(P.SquareSumAll)
  318. def get_bprop_square_sum_all(self):
  319. """Grad definition for `Square` operation."""
  320. mul_func = P.Mul()
  321. fill_func = P.Fill()
  322. dtype = P.DType()
  323. def bprop(x, y, out, dout):
  324. temp_x = mul_func(dout[0], x)
  325. temp_y = mul_func(dout[1], y)
  326. dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
  327. dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
  328. return (dx, dy)
  329. return bprop
  330. @bprop_getters.register(P.Sqrt)
  331. def get_bprop_sqrt(self):
  332. """Grad definition for `Sqrt` operation."""
  333. sqrt_grad = G.SqrtGrad()
  334. def bprop(x, out, dout):
  335. dx = sqrt_grad(out, dout)
  336. return (dx,)
  337. return bprop
  338. @bprop_getters.register(G.SqrtGrad)
  339. def get_bprop_sqrt_grad(self):
  340. """Grad definition for `SqrtGrad` operation."""
  341. def bprop(y, grad, out, dout):
  342. gy = dout / y
  343. dy = -gy * out
  344. dgrad = 0.5 * gy
  345. return dy, dgrad
  346. return bprop
  347. @bprop_getters.register(P.Rsqrt)
  348. def get_bprop_rsqrt(self):
  349. """Grad definition for `Rsqrt` operation."""
  350. rsqrt_grad = G.RsqrtGrad()
  351. def bprop(x, out, dout):
  352. dx = rsqrt_grad(out, dout)
  353. return (dx,)
  354. return bprop
  355. @bprop_getters.register(P.Reciprocal)
  356. def get_bprop_reciprocal(self):
  357. """Grad definition for `Reciprocal` operation."""
  358. reciprocal_grad = G.ReciprocalGrad()
  359. def bprop(x, out, dout):
  360. dx = reciprocal_grad(out, dout)
  361. return (dx,)
  362. return bprop
  363. @bprop_getters.register(P.Log)
  364. def get_bprop_log(self):
  365. """Grad definition for `Log` operation."""
  366. reciprocal = P.Reciprocal()
  367. def bprop(x, out, dout):
  368. g = reciprocal(x)
  369. dx = g * dout
  370. return dx, 0
  371. return bprop
  372. @bprop_getters.register(P.Log1p)
  373. def get_bprop_log1p(self):
  374. """Grad definition for `Log1p` operation."""
  375. reciprocal = P.Reciprocal()
  376. def bprop(x, out, dout):
  377. x_1p = x + 1
  378. g = reciprocal(x_1p)
  379. dx = g * dout
  380. return dx, 0
  381. return bprop
  382. @bprop_getters.register(P.Erf)
  383. def get_bprop_erf(self):
  384. """Grad definition for `Erf` operation."""
  385. exp = P.Exp()
  386. square = P.Square()
  387. sqrt = P.Sqrt()
  388. cast = P.Cast()
  389. dtype = P.DType()
  390. def bprop(x, out, dout):
  391. half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
  392. x_square = square(x)
  393. dx = dout * half_root_pi * exp(-x_square)
  394. return (dx,)
  395. return bprop
  396. @bprop_getters.register(P.Erfc)
  397. def get_bprop_erfc(self):
  398. """Grad definition for `Erfc` operation."""
  399. exp = P.Exp()
  400. square = P.Square()
  401. sqrt = P.Sqrt()
  402. cast = P.Cast()
  403. dtype = P.DType()
  404. def bprop(x, out, dout):
  405. half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
  406. x_square = square(x)
  407. dx = dout * (-half_root_pi * exp(-x_square))
  408. return (dx,)
  409. return bprop
  410. @bprop_getters.register(P.Pow)
  411. def get_bprop_pow(self):
  412. """Grad definition for `Pow` operation."""
  413. pow_op = P.Pow()
  414. ln = P.Log()
  415. def bprop(x, power, out, dout):
  416. bc_dx = power * pow_op(x, power - 1.0) * dout
  417. x[x < 0] = 1
  418. bc_dpower = out * ln(x) * dout
  419. return binop_grad_common(x, power, bc_dx, bc_dpower)
  420. return bprop
  421. @bprop_getters.register(P.Exp)
  422. def get_bprop_exp(self):
  423. """Grad definition for `Exp` operation."""
  424. exp_ = P.Exp()
  425. def bprop(x, out, dout):
  426. g = exp_(x)
  427. dx = g * dout
  428. return (dx,)
  429. return bprop
  430. @bprop_getters.register(P.Expm1)
  431. def get_bprop_expm1(self):
  432. """Grad definition for `Expm1` operation."""
  433. exp_ = P.Exp()
  434. def bprop(x, out, dout):
  435. g = exp_(x)
  436. dx = g * dout
  437. return (dx,)
  438. return bprop
  439. @bprop_getters.register(P.Minimum)
  440. def get_bprop_minimum(self):
  441. """Grad definition for `Minimum` operation."""
  442. input_grad = SG.MinimumGrad()
  443. def bprop(x, y, out, dout):
  444. dx, dy = input_grad(x, y, dout)
  445. return dx, dy
  446. return bprop
  447. @bprop_getters.register(P.Maximum)
  448. def get_bprop_maximum(self):
  449. """Grad definition for `Maximum` operation."""
  450. input_grad = SG.MaximumGrad()
  451. def bprop(x, y, out, dout):
  452. dx, dy = input_grad(x, y, dout)
  453. return dx, dy
  454. return bprop
  455. @bprop_getters.register(P.ReduceSum)
  456. def get_bprop_reducesum(self):
  457. """Grad definition for `ReduceSum` operation."""
  458. def bprop(x, axis, out, dout):
  459. dx = _sum_grad(x, axis, dout)
  460. return dx, zeros_like(axis)
  461. return bprop
  462. @bprop_getters.register(P.CumSum)
  463. def get_bprop_cumsum(self):
  464. """Grad definition for `CumSum` operation."""
  465. cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
  466. def bprop(x, axis, out, dout):
  467. return cumsum(dout, axis), zeros_like(axis)
  468. return bprop
  469. @constexpr
  470. def _split_shape_index(input_shape, axis):
  471. """Calculate reduce_prod grad transpose indices and perm shape."""
  472. rank = len(input_shape)
  473. if isinstance(axis, int):
  474. axis = tuple([axis])
  475. reduction_indices = tuple([(i + rank) % rank for i in axis])
  476. other_indices = tuple(set(range(rank)) - set(reduction_indices))
  477. reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
  478. other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
  479. perm = reduction_indices + other_indices
  480. return tuple([reduced_num, other_num]), perm
  481. @constexpr
  482. def _invert_permutation(perm):
  483. """Calculate invert permutation."""
  484. out = [0] * len(perm)
  485. for i, value in enumerate(perm):
  486. out[value] = i
  487. return tuple(out)
  488. @bprop_getters.register(P.ReduceProd)
  489. def get_bprop_reduceprod(self):
  490. """Grad definition for `ReduceProd` operation."""
  491. transpose = P.Transpose()
  492. left_cumprod = P.CumProd(exclusive=True)
  493. right_cumprod = P.CumProd(exclusive=True, reverse=True)
  494. def bprop(x, axis, out, dout):
  495. """Grad definition for `Product` operation."""
  496. # Expand dout to full input shape
  497. input_shape = shape_op(x)
  498. output_shape_kept_dims = reduced_shape(input_shape, axis)
  499. dout = reshape(dout, output_shape_kept_dims)
  500. tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
  501. grad = tile(dout, tile_scaling)
  502. # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
  503. pack_shape, perm = _split_shape_index(input_shape, axis)
  504. permuted = transpose(x, perm)
  505. permuted_shape = shape_op(permuted)
  506. reshaped = reshape(permuted, pack_shape)
  507. # Calculate product, leaving out the current entry
  508. left = left_cumprod(reshaped, 0)
  509. right = right_cumprod(reshaped, 0)
  510. y = reshape(left * right, permuted_shape)
  511. # Invert the transpose and reshape operations.
  512. # Make sure to set the statically known shape information through a reshape.
  513. out = transpose(y, _invert_permutation(perm)) * grad
  514. dx = reshape(out, input_shape)
  515. return dx, zeros_like(axis)
  516. return bprop
  517. @bprop_getters.register(P.CumProd)
  518. def get_bprop_cumprod(self):
  519. """Grad definition for `CumProd` operation."""
  520. cumprod = P.CumProd(exclusive=self.exclusive, reverse=self.reverse)
  521. cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
  522. def bprop(x, axis, out, dout):
  523. """Grad definition for `Product` operation."""
  524. # This will fails when x contains 0
  525. prod = cumprod(x, axis)
  526. out = cumsum(prod * dout, axis)
  527. return out / x, zeros_like(axis)
  528. return bprop
  529. @bprop_getters.register(P.ReduceAll)
  530. def get_bprop_reduceall(self):
  531. """Grad definition for `ReduceAll` operation."""
  532. def bprop(x, axis, out, dout):
  533. return zeros_like(x), zeros_like(axis)
  534. return bprop
  535. @bprop_getters.register(P.ReduceAny)
  536. def get_bprop_reduceany(self):
  537. """Grad definition for `ReduceAny` operation."""
  538. def bprop(x, axis, out, dout):
  539. return zeros_like(x), zeros_like(axis)
  540. return bprop
  541. @bprop_getters.register(P.ReduceMax)
  542. def get_bprop_reducemax(self):
  543. """Grad definition for `Max` operation."""
  544. def bprop(x, axis, out, dout):
  545. dx = _min_or_max_grad(x, axis, out, dout)
  546. return (dx, zeros_like(axis))
  547. return bprop
  548. @bprop_getters.register(P.ArgMaxWithValue)
  549. def get_bprop_argmaxwithvalue(self):
  550. """Grad definition for `ArgMaxWithValue` operation."""
  551. axis = self.axis
  552. keep_dims = self.keep_dims
  553. op = P.ArgMaxWithValue(axis)
  554. def bprop(x, out, dout):
  555. dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
  556. return (dx,)
  557. return bprop
  558. @bprop_getters.register(P.ReduceMin)
  559. def get_bprop_reducemin(self):
  560. """Grad definition for `ReduceMin` operation."""
  561. def bprop(x, axis, out, dout):
  562. dx = _min_or_max_grad(x, axis, out, dout)
  563. return (dx, zeros_like(axis))
  564. return bprop
  565. @bprop_getters.register(P.ArgMinWithValue)
  566. def get_bprop_argminwithvalue(self):
  567. """Generate bprop for ArgMinWithValue"""
  568. axis = self.axis
  569. keep_dims = self.keep_dims
  570. op = P.ArgMinWithValue(axis)
  571. def bprop(x, out, dout):
  572. dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
  573. return (dx,)
  574. return bprop
  575. @bprop_getters.register(P.ReduceMean)
  576. def get_bprop_reduce_mean(self):
  577. """Grad definition for `ReduceMean` operation."""
  578. div_op = P.RealDiv()
  579. cast = P.Cast()
  580. dtype = P.DType()
  581. def bprop(x, axis, out, dout):
  582. grad = _sum_grad(x, axis, dout)
  583. div_shape = F.shape_mul(shape_op(x)) / F.shape_mul(shape_op(out))
  584. dx = div_op(grad, cast(F.scalar_to_array(div_shape), dtype(grad)))
  585. return dx, zeros_like(axis)
  586. return bprop
  587. @bprop_getters.register(P.IsFinite)
  588. def get_bprop_isfinite(self):
  589. """Grad definition for `IsFinite` operation."""
  590. def bprop(x, out, dout):
  591. return (zeros_like(x),)
  592. return bprop
  593. @bprop_getters.register(P.Equal)
  594. def get_bprop_equal(self):
  595. """Grad definition for `Equal` operation."""
  596. def bprop(x, y, out, dout):
  597. return zeros_like(x), zeros_like(y)
  598. return bprop
  599. @bprop_getters.register(P.NotEqual)
  600. def get_bprop_not_equal(self):
  601. """Grad definition for `NotEqual` operation."""
  602. def bprop(x, y, out, dout):
  603. return zeros_like(x), zeros_like(y)
  604. return bprop
  605. @bprop_getters.register(P.ApproximateEqual)
  606. def get_bprop_approximate_equal(self):
  607. """Grad definition for `ApproximateEqual` operation."""
  608. def bprop(x, y, out, dout):
  609. return zeros_like(x), zeros_like(y)
  610. return bprop
  611. @bprop_getters.register(P.Greater)
  612. def get_bprop_greater(self):
  613. """Grad definition for `Greater` operation."""
  614. def bprop(x, y, out, dout):
  615. return zeros_like(x), zeros_like(y)
  616. return bprop
  617. @bprop_getters.register(P.GreaterEqual)
  618. def get_bprop_greater_equal(self):
  619. """Grad definition for `GreaterEqual` operation."""
  620. def bprop(x, y, out, dout):
  621. return zeros_like(x), zeros_like(y)
  622. return bprop
  623. @bprop_getters.register(P.Less)
  624. def get_bprop_less(self):
  625. """Grad definition for `Less` operation."""
  626. def bprop(x, y, out, dout):
  627. return zeros_like(x), zeros_like(y)
  628. return bprop
  629. @bprop_getters.register(P.LessEqual)
  630. def get_bprop_less_equal(self):
  631. """Grad definition for `LessEqual` operation."""
  632. def bprop(x, y, out, dout):
  633. return zeros_like(x), zeros_like(y)
  634. return bprop
  635. @bprop_getters.register(P.LogicalNot)
  636. def get_bprop_logical_not(self):
  637. """Grad definition for `LogicalNot` operation."""
  638. def bprop(x, out, dout):
  639. return (zeros_like(x),)
  640. return bprop
  641. @bprop_getters.register(P.LogicalAnd)
  642. def get_bprop_logical_and(self):
  643. """Grad definition for `LogicalAnd` operation."""
  644. def bprop(x, y, out, dout):
  645. return zeros_like(x), zeros_like(y)
  646. return bprop
  647. @bprop_getters.register(P.LogicalOr)
  648. def get_bprop_logical_or(self):
  649. """Grad definition for `LogicalOr` operation."""
  650. def bprop(x, y, out, dout):
  651. return zeros_like(x), zeros_like(y)
  652. return bprop
  653. @bprop_getters.register(P.NPUAllocFloatStatus)
  654. def get_bprop_npu_alloc_float_status(self):
  655. """Grad definition for `NPUAllocFloatStatus` operation."""
  656. def bprop(out, dout):
  657. return ()
  658. return bprop
  659. @bprop_getters.register(P.NPUGetFloatStatus)
  660. def get_bprop_npu_get_float_status(self):
  661. """Grad definition for `NPUGetFloatStatus` operation."""
  662. def bprop(x, out, dout):
  663. return (zeros_like(x),)
  664. return bprop
  665. @bprop_getters.register(P.NPUClearFloatStatus)
  666. def get_bprop_npu_clear_float_status(self):
  667. """Grad definition for `NPUClearFloatStatus` operation."""
  668. def bprop(x, out, dout):
  669. return (zeros_like(x),)
  670. return bprop
  671. @bprop_getters.register(P.AssignAdd)
  672. def get_bprop_assign_add(self):
  673. """Grad definition for `AssignAdd` operation."""
  674. def bprop(x, y, out, dout):
  675. return zeros_like(x), zeros_like(y)
  676. return bprop
  677. @bprop_getters.register(P.AssignSub)
  678. def get_bprop_assign_sub(self):
  679. """Grad definition for `AssignSub` operation."""
  680. def bprop(x, y, out, dout):
  681. return zeros_like(x), zeros_like(y)
  682. return bprop
  683. @bprop_getters.register(P.Sin)
  684. def get_bprop_sin(self):
  685. """Grad definition for `Sin` operation."""
  686. cos = P.Cos()
  687. def bprop(x, out, dout):
  688. dx = dout * cos(x)
  689. return (dx,)
  690. return bprop
  691. @bprop_getters.register(P.Asin)
  692. def get_bprop_asin(self):
  693. """Grad definition for `Asin` operation."""
  694. input_grad = G.AsinGrad()
  695. def bprop(x, out, dout):
  696. dx = input_grad(x, dout)
  697. return (dx,)
  698. return bprop
  699. @bprop_getters.register(G.AsinGrad)
  700. def get_bprop_asin_grad(self):
  701. """Grad definition for `AsinGrad` operation."""
  702. input_grad = G.AsinGrad()
  703. p_pow = P.Pow()
  704. def bprop(x, grad, out, dout):
  705. d2x = dout * grad * x * p_pow((1 - x * x), - 1.5)
  706. ddy = input_grad(x, dout)
  707. return (d2x, ddy)
  708. return bprop
  709. @bprop_getters.register(P.Asinh)
  710. def get_bprop_asinh(self):
  711. """Grad definition for `Asinh` operation."""
  712. input_grad = G.AsinhGrad()
  713. def bprop(x, out, dout):
  714. dx = input_grad(out, dout)
  715. return (dx,)
  716. return bprop
  717. @bprop_getters.register(G.AsinhGrad)
  718. def get_bprop_asinh_grad(self):
  719. """Grad definition for `AsinhGrad` operation."""
  720. input_grad = G.AsinhGrad()
  721. tanh = P.Tanh()
  722. def bprop(y, grad, out, dout):
  723. dy = dout * out * -1.0 * tanh(y)
  724. dgrad = input_grad(y, dout)
  725. return dy, dgrad
  726. return bprop
  727. @bprop_getters.register(P.Sinh)
  728. def get_bprop_sinh(self):
  729. """Grad definition for `Sinh` operation."""
  730. cosh = P.Cosh()
  731. def bprop(x, out, dout):
  732. dx = cosh(x) * dout
  733. return (dx,)
  734. return bprop
  735. @bprop_getters.register(P.Cos)
  736. def get_bprop_cos(self):
  737. """Grad definition for `Cos` operation."""
  738. sin = P.Sin()
  739. neg = P.Neg()
  740. def bprop(x, out, dout):
  741. dx = dout * neg(sin(x))
  742. return (dx,)
  743. return bprop
  744. @bprop_getters.register(P.ACos)
  745. def get_bprop_acos(self):
  746. """Grad definition for `ACos` operation."""
  747. input_grad = G.ACosGrad()
  748. def bprop(x, out, dout):
  749. dx = input_grad(x, dout)
  750. return (dx,)
  751. return bprop
  752. @bprop_getters.register(G.ACosGrad)
  753. def get_bprop_acos_grad(self):
  754. """Grad definition for `ACosGrad` operation."""
  755. input_grad = G.ACosGrad()
  756. p_pow = P.Pow()
  757. def bprop(x, grad, out, dout):
  758. d2x = -dout * grad * x * p_pow((1 - x * x), - 1.5)
  759. ddy = input_grad(x, dout)
  760. return (d2x, ddy)
  761. return bprop
  762. @bprop_getters.register(P.Acosh)
  763. def get_bprop_acosh(self):
  764. """Grad definition for `Acosh` operation."""
  765. input_grad = G.AcoshGrad()
  766. def bprop(x, out, dout):
  767. dx = input_grad(out, dout)
  768. return (dx,)
  769. return bprop
  770. @bprop_getters.register(G.AcoshGrad)
  771. def get_bprop_acosh_grad(self):
  772. """Grad definition for `AcoshGrad` operation."""
  773. input_grad = G.AcoshGrad()
  774. tanh = P.Tanh()
  775. def bprop(y, grad, out, dout):
  776. dy = dout * out * -1.0 / tanh(y)
  777. dgrad = input_grad(y, dout)
  778. return dy, dgrad
  779. return bprop
  780. @bprop_getters.register(P.Cosh)
  781. def get_bprop_cosh(self):
  782. """Grad definition for `Cosh` operation."""
  783. sinh = P.Sinh()
  784. def bprop(x, out, dout):
  785. dx = sinh(x) * dout
  786. return (dx,)
  787. return bprop
  788. @bprop_getters.register(P.Abs)
  789. def get_bprop_abs(self):
  790. """Grad definition for `Abs` operation."""
  791. abs_grad = SG.AbsGrad()
  792. def bprop(x, out, dout):
  793. dx = abs_grad(x, dout)
  794. return (dx,)
  795. return bprop
  796. @bprop_getters.register(P.ScalarCast)
  797. def get_bprop_scalar_cast(self):
  798. """Generate bprop for ScalarCast"""
  799. def bprop(x, t, out, dout):
  800. return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)
  801. return bprop
@bprop_getters.register(P.AccumulateNV2)
def get_bprop_scalar_accumulatenv2(self):
    """Generate bprop for AccumulateNV2.

    The returned closure treats `x` as a sized sequence of inputs (it calls
    `len(x)`) and assigns every element the same gradient, `dout`.

    Returns:
        function, `bprop(x, out, dout)` returning a 1-tuple whose only item
        is a tuple holding `len(x)` references to `dout`.
    """
    def bprop(x, out, dout):
        # Start from an empty tuple and append one `dout` per input element.
        dx = ()
        for _ in range(len(x)):
            dx = dx + (dout,)
        # The outer 1-tuple corresponds to the single sequence argument `x`.
        return (dx,)
    return bprop
@bprop_getters.register(P.AddN)
def get_bprop_scalar_addn(self):
    """Generate bprop for AddN.

    The returned closure gives every element of the input sequence `x` the
    same gradient, `dout`. The gradient container mirrors the input: a list
    when `F.typeof(x)` is a subclass of `ms.list_`, otherwise a tuple.

    Returns:
        function, `bprop(x, out, dout)` returning a 1-tuple whose only item
        is the list or tuple of `len(x)` references to `dout`.
    """
    def bprop(x, out, dout):
        # List input: build the gradient as a list of the same length.
        if is_sub_class(F.typeof(x), ms.list_):
            dx = []
            for _ in range(len(x)):
                dx.append(dout)
            return (dx,)

        # Any other input: build the gradient as a tuple, one `dout` per element.
        dx = ()
        for _ in range(len(x)):
            dx = dx + (dout,)
        return (dx,)
    return bprop
  825. @bprop_getters.register(P.Sign)
  826. def get_bprop_sign(self):
  827. """Generate bprop for Sign"""
  828. def bprop(x, out, dout):
  829. return (zeros_like(x),)
  830. return bprop
  831. @bprop_getters.register(P.Round)
  832. def get_bprop_round(self):
  833. """Generate bprop for Round"""
  834. def bprop(x, out, dout):
  835. return (zeros_like(x),)
  836. return bprop
  837. @bprop_getters.register(P.Atan2)
  838. def get_bprop_atan2(self):
  839. """Generate bprop for Atan2"""
  840. square = P.Square()
  841. def bprop(x, y, out, dout):
  842. tmp = dout / (square(x) + square(y))
  843. bc_dx = tmp * y
  844. bc_dy = tmp * (-x)
  845. return binop_grad_common(x, y, bc_dx, bc_dy)
  846. return bprop
  847. @bprop_getters.register(P.BesselI0e)
  848. def get_bprop_bessel_i0e(self):
  849. """Generate bprop for BesselI0e"""
  850. sign = P.Sign()
  851. bessel_i1e = P.BesselI1e()
  852. def bprop(x, out, dout):
  853. dx = dout * (bessel_i1e(x) - sign(x) * out)
  854. return (dx,)
  855. return bprop
  856. @bprop_getters.register(P.Atan)
  857. def get_bprop_atan(self):
  858. """Grad definition for `Atan` operation."""
  859. input_grad = G.AtanGrad()
  860. def bprop(x, out, dout):
  861. dx = input_grad(x, dout)
  862. return (dx,)
  863. return bprop
  864. @bprop_getters.register(G.AtanGrad)
  865. def get_bprop_atan_grad(self):
  866. """Grad definition for `AtanGrad` operation."""
  867. input_grad = G.AtanGrad()
  868. def bprop(x, grad, out, dout):
  869. dgrad = input_grad(x, dout)
  870. dx = out * dgrad * -2.0 * x
  871. return dx, dgrad
  872. return bprop
  873. @bprop_getters.register(P.Tan)
  874. def get_bprop_tan(self):
  875. """Grad definition for `Tan` operation."""
  876. reciprocal = P.Reciprocal()
  877. square = P.Square()
  878. cos = P.Cos()
  879. def bprop(x, out, dout):
  880. cosx = cos(x)
  881. secx2 = square(reciprocal(cosx))
  882. dx = secx2 * dout
  883. return (dx,)
  884. return bprop
@bprop_getters.register(P.BesselI1e)
def get_bprop_bessel_i1e(self):
    """Generate bprop for BesselI1e.

    The gradient factor is `BesselI0e(x) - out * (Sign(x) + 1 / x)`. Elements
    whose magnitude is not greater than `eps` are replaced by `eps` before the
    reciprocal is taken, and their gradient factor is set to 0.5 instead. The
    selected factor is then multiplied by `dout`.

    Returns:
        function, `bprop(x, out, dout)` returning a 1-tuple with the gradient
        for `x`.
    """
    sign = P.Sign()
    bessel_i0e = P.BesselI0e()
    less = P.Less()
    select = P.Select()
    reciprocal = P.Reciprocal()
    cast = P.Cast()
    dtype = P.DType()
    abs_ops = P.Abs()

    def bprop(x, out, dout):
        zeros = zeros_like(x)
        # Presumably the machine epsilon for x's dtype -- confirm in const_utils.
        np_eps = const_utils.get_np_eps(dtype(x))
        eps = cast(np_eps, dtype(x))
        # Mask of elements with |x| > eps.
        x_is_valid = less(eps, abs_ops(x))
        # Masked-out elements are replaced by eps so reciprocal() below never
        # receives them; `eps + zeros` presumably broadcasts eps to x's shape.
        x_safe = select(x_is_valid, x, eps + zeros)
        tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
        # 0.5 is the limit of the expression above as x -> 0.
        dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
        return (dx,)
    return bprop
  906. @bprop_getters.register(P.Atanh)
  907. def get_bprop_atanh(self):
  908. """Grad definition for `Atanh` operation."""
  909. power = P.Pow()
  910. div = P.Div()
  911. def bprop(x, out, dout):
  912. tmp = 1 - power(x, 2)
  913. dx = div(1, tmp) * dout
  914. return (dx,)
  915. return bprop
  916. @bprop_getters.register(P.Inv)
  917. def get_bprop_inv(self):
  918. """Grad definition for 'Inv' operation"""
  919. inv_grad = G.InvGrad()
  920. def bprop(x, out, dout):
  921. dx = inv_grad(out, dout)
  922. return (dx,)
  923. return bprop
  924. @bprop_getters.register(P.LinSpace)
  925. def get_bprop_lin_space(self):
  926. """Grad definition for `LinSpace` operation."""
  927. def bprop(start, stop, num, out, dout):
  928. return zeros_like(start), zeros_like(stop), zeros_like(num)
  929. return bprop