You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_math_ops.py 32 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the grad rules of math related operations."""
  16. from functools import reduce
  17. import numpy as np
  18. import mindspore as ms
  19. from mindspore import nn
  20. from .. import functional as F
  21. from .. import operations as P
  22. from ..operations import _grad_ops as G
  23. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  24. from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
  25. from .grad_base import bprop_getters
  26. from ..primitive import constexpr
  27. from ..composite.multitype_ops import _constexpr_utils as const_utils
  28. shape_op = P.Shape()
  29. reduce_sum = P.ReduceSum()
  30. reshape = P.Reshape()
  31. tile = P.Tile()
  32. is_sub_class = P.IsSubClass()
  33. def binop_grad_common(x, y, dx, dy):
  34. """
  35. Common grad definition for binary operations.
  36. The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
  37. """
  38. shape_of_x = shape_op(x)
  39. shape_of_y = shape_op(y)
  40. rx = broadcast_gradient_args(shape_of_x, shape_of_y)
  41. # if input shape is the same as dout shape, do not need to reduce
  42. reduce_dx = dx
  43. reduce_dy = dy
  44. if rx[0]:
  45. # if dx is scalar whose shape is (), do not need reduce
  46. if shape_op(dx):
  47. dx = reduce_sum(dx, rx[0])
  48. reduce_dx = reshape(dx, shape_of_x)
  49. if rx[1]:
  50. # if dy is scalar whose shape is (), do not need reduce
  51. if shape_op(dy):
  52. dy = reduce_sum(dy, rx[1])
  53. reduce_dy = reshape(dy, shape_of_y)
  54. return reduce_dx, reduce_dy
  55. def _sum_grad(x, axis, dout):
  56. """Grad definition for `Sum` operation."""
  57. input_shape = shape_op(x)
  58. output_shape_kept_dims = reduced_shape(input_shape, axis)
  59. tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
  60. grad = reshape(dout, output_shape_kept_dims)
  61. return tile(grad, tile_scaling)
def _min_or_max_grad(x, axis, out, dout):
    """Grad definition for `Min` and `Max` operations."""
    # Restore the reduced axes as size-1 dims so out/dout broadcast against x.
    input_shape = shape_op(x)
    output_shape_kept_dims = reduced_shape(input_shape, axis)
    y = reshape(out, output_shape_kept_dims)
    grad = reshape(dout, output_shape_kept_dims)
    # 1.0 where an element equals the reduced extremum, 0.0 elsewhere.
    indicators = F.cast(F.equal(y, x), F.dtype(grad))
    # Tiny epsilon guards the division below against a zero selected-count.
    min_num = F.cast(F.scalar_to_array(1e-24), F.dtype(grad))
    num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
    # Ties share the gradient equally among all matching elements.
    return indicators / num_selected * grad
def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
    """ArgMinWithValue and ArgMaxWithValue grad."""
    expand = P.ExpandDims()
    x_shape = F.shape(x)
    x_dim = len(x_shape)
    x_axis = axis
    # Normalize a negative axis to its positive equivalent.
    if x_axis < 0:
        x_axis = axis + x_dim
    onehot_axis = x_axis
    depth = x_shape[x_axis]
    if keep_dims:
        # With keep_dims the value-grad already carries the reduced axis;
        # recompute out with the helper op (built without keep_dims by the
        # getters) so the indices below have the squeezed shape.
        dout_expand = dout[1]
        out = op(x)
    else:
        # dout[1] is the grad w.r.t. the value output; the index output gets none.
        dout_expand = expand(dout[1], onehot_axis)
    if onehot_axis >= len(shape_op(out[0])):
        onehot_axis = -1
    onehot = P.OneHot(onehot_axis)
    type_x = F.dtype(x)
    on_value = F.cast(F.scalar_to_array(1.0), type_x)
    off_value = F.cast(F.scalar_to_array(0.0), type_x)
    # Scatter the value gradient back onto the argmin/argmax positions.
    dx = dout_expand * onehot(out[0], depth, on_value, off_value)
    return dx
  95. @bprop_getters.register(P.MatMul)
  96. def bprop_matmul(self):
  97. """Grad definition for `MatMul` operation."""
  98. ta = self.transpose_a
  99. tb = self.transpose_b
  100. mul1 = P.MatMul(transpose_a=(ta and tb),
  101. transpose_b=(ta or (not tb)))
  102. mul2 = P.MatMul(transpose_a=((not ta) or tb),
  103. transpose_b=(ta and tb))
  104. def bprop(x, w, out, dout):
  105. if ta:
  106. dx = mul1(w, dout)
  107. else:
  108. dx = mul1(dout, w)
  109. if tb:
  110. dw = mul2(dout, x)
  111. else:
  112. dw = mul2(x, dout)
  113. return dx, dw
  114. return bprop
  115. @bprop_getters.register(P.BatchMatMul)
  116. def bprop_batchmatmul(self):
  117. """Grad definition for `BatchMatMul` operation."""
  118. ta = self.transpose_a
  119. tb = self.transpose_b
  120. mul1 = P.BatchMatMul(transpose_a=(ta and tb),
  121. transpose_b=(ta or (not tb)))
  122. mul2 = P.BatchMatMul(transpose_a=((not ta) or tb),
  123. transpose_b=(ta and tb))
  124. def bprop(x, w, out, dout):
  125. if ta:
  126. dx = mul1(w, dout)
  127. else:
  128. dx = mul1(dout, w)
  129. if tb:
  130. dw = mul2(dout, x)
  131. else:
  132. dw = mul2(x, dout)
  133. return dx, dw
  134. return bprop
  135. @bprop_getters.register(P.Add)
  136. def get_bprop_add(self):
  137. """Grad definition for `Add` operation."""
  138. def bprop(x, y, out, dout):
  139. return binop_grad_common(x, y, dout, dout)
  140. return bprop
@bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self):
    """Grad definition for `Add` operation."""
    # Same pass-through gradient as P.Add.
    # NOTE(review): TensorAdd appears to be an alias of Add kept for
    # compatibility — confirm against the ops module.
    def bprop(x, y, out, dout):
        return binop_grad_common(x, y, dout, dout)
    return bprop
@bprop_getters.register(P.MatrixInverse)
def get_bprop_matrix_inverse(self):
    """Grad definition for `MatrixInverse` operation."""
    matmul_x1 = nn.MatMul(transpose_x1=True)
    matmul_x2 = nn.MatMul(transpose_x2=True)
    neg = P.Neg()
    def bprop(x, out, dout):
        # With out = inv(X): dX = -out^T @ dout @ out^T, built in two steps.
        dx = matmul_x2(dout, out)
        dx = matmul_x1(out, dx)
        dx = neg(dx)
        return (dx,)
    return bprop
  159. @bprop_getters.register(P.Neg)
  160. def get_bprop_neg(self):
  161. """Grad definition for `Neg` operation."""
  162. neg_grad = P.Neg()
  163. def bprop(x, out, dout):
  164. dx = neg_grad(dout)
  165. return (dx,)
  166. return bprop
  167. @bprop_getters.register(P.Sub)
  168. def get_bprop_sub(self):
  169. """Grad definition for `Sub` operation."""
  170. neg_func = P.Neg()
  171. def bprop(x, y, out, dout):
  172. return binop_grad_common(x, y, dout, neg_func(dout))
  173. return bprop
  174. @bprop_getters.register(P.Mul)
  175. def get_bprop_mul(self):
  176. """Grad definition for `Mul` operation."""
  177. mul_func = P.Mul()
  178. def bprop(x, y, out, dout):
  179. bc_dx = mul_func(y, dout)
  180. bc_dy = mul_func(x, dout)
  181. return binop_grad_common(x, y, bc_dx, bc_dy)
  182. return bprop
  183. @bprop_getters.register(P.RealDiv)
  184. def get_bprop_real_div(self):
  185. """Grad definition for `RealDiv` operation."""
  186. div_op = P.RealDiv()
  187. neg = P.Neg()
  188. mul_op = P.Mul()
  189. def bprop(x, y, out, dout):
  190. bc_x = div_op(dout, y)
  191. bc_y = neg(mul_op(bc_x, out))
  192. return binop_grad_common(x, y, bc_x, bc_y)
  193. return bprop
@bprop_getters.register(P.Div)
def get_bprop_div(self):
    """Grad definition for `Div` operation."""
    div_op = P.Div()
    neg = P.Neg()
    mul_op = P.Mul()
    def bprop(x, y, out, dout):
        # Same quotient-rule form as RealDiv: bc_y = -(dout / y) * out.
        bc_x = div_op(dout, y)
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)
    return bprop


@bprop_getters.register(P.DivNoNan)
def get_bprop_div_no_nan(self):
    """Grad definition for `DivNoNan` operation."""
    div_no_nan_op = P.DivNoNan()
    neg = P.Neg()
    mul_op = P.Mul()
    def bprop(x, y, out, dout):
        # Quotient rule computed with DivNoNan so positions where y == 0
        # presumably produce zero gradients, matching the forward op — confirm.
        bc_x = div_no_nan_op(dout, y)
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)
    return bprop
@bprop_getters.register(P.Xdivy)
def get_bprop_xdivy(self):
    """Grad definition for `Xdivy` operation."""
    div_op = P.Xdivy()
    def bprop(x, y, out, dout):
        x_dtype = F.dtype(x)
        # Mask: 1 where x != 0, else 0, so positions where the forward op
        # defined the result as 0 also get a zero gradient for x.
        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
        # d/dx (x/y) = 1/y, via xdivy(mask, y).
        bc_x = div_op(not_zero_x, y) * dout
        # d/dy (x/y) = -x / y^2.
        bc_y = div_op(-x, F.square(y)) * dout
        return binop_grad_common(x, y, bc_x, bc_y)
    return bprop
  227. @bprop_getters.register(P.Floor)
  228. def get_bprop_floor(self):
  229. """Grad definition for `floor` operation."""
  230. fill_ = P.Fill()
  231. shape_ = P.Shape()
  232. dtype_ = P.DType()
  233. def bprop(x, out, dout):
  234. bc_x = fill_(dtype_(x), shape_(x), 0.)
  235. return (bc_x,)
  236. return bprop
@bprop_getters.register(P.Ceil)
def get_bprop_ceil(self):
    """Grad definition for `ceil` operation."""
    fill_ = P.Fill()
    shape_ = P.Shape()
    dtype_ = P.DType()
    def bprop(x, out, dout):
        # Ceil is piecewise constant, so its gradient is zero everywhere.
        bc_x = fill_(dtype_(x), shape_(x), 0.)
        return (bc_x,)
    return bprop
  247. @bprop_getters.register(P.FloorDiv)
  248. def get_bprop_floordiv(self):
  249. """Grad definition for `FloorDiv` operation."""
  250. def bprop(x, y, out, dout):
  251. return zeros_like(x), zeros_like(y)
  252. return bprop
@bprop_getters.register(P.FloorMod)
def get_bprop_floormod(self):
    """Grad definition for `FloorMod` operation."""
    def bprop(x, y, out, dout):
        # x mod y = x - (x // y) * y, treating the floor term as constant:
        # d/dx = 1, d/dy = -(x // y).
        bc_x = dout
        bc_y = -dout * (x // y)
        return binop_grad_common(x, y, bc_x, bc_y)
    return bprop
@bprop_getters.register(P.TruncateDiv)
def get_bprop_truncate_div(self):
    """Grad definition for `TruncateDiv` operation."""
    def bprop(x, y, out, dout):
        # Truncating division is piecewise constant: zero grads.
        return zeros_like(x), zeros_like(y)
    return bprop


@bprop_getters.register(P.TruncateMod)
def get_bprop_truncate_mod(self):
    """Grad definition for `TruncateMod` operation."""
    div_op = P.TruncateDiv()
    def bprop(x, y, out, dout):
        # x mod y = x - trunc(x / y) * y, treating the truncation as constant:
        # d/dx = 1, d/dy = -trunc(x / y).
        bc_x = dout
        bc_y = -dout * div_op(x, y)
        return binop_grad_common(x, y, bc_x, bc_y)
    return bprop


@bprop_getters.register(P.Mod)
def get_bprop_mod(self):
    """Grad definition for `Mod` operation."""
    def bprop(x, y, out, dout):
        # Same derivation as FloorMod: d/dx = 1, d/dy = -(x // y).
        bc_x = dout
        bc_y = -dout * (x // y)
        return binop_grad_common(x, y, bc_x, bc_y)
    return bprop
  284. @bprop_getters.register(P.Square)
  285. def get_bprop_square(self):
  286. """Grad definition for `Square` operation."""
  287. mul_func = P.Mul()
  288. fill_func = P.Fill()
  289. dtype = P.DType()
  290. def bprop(x, out, dout):
  291. temp = mul_func(dout, x)
  292. dx = mul_func(fill_func(dtype(temp), shape_op(x), 2.0), temp)
  293. return (dx,)
  294. return bprop
  295. @bprop_getters.register(P.SquaredDifference)
  296. def get_bprop_squared_difference(self):
  297. """Grad definition for `SquaredDifference` operation."""
  298. neg = P.Neg()
  299. def bprop(x, y, out, dout):
  300. x_grad = 2 * dout * (x - y)
  301. bc_x = x_grad
  302. bc_y = neg(x_grad)
  303. return binop_grad_common(x, y, bc_x, bc_y)
  304. return bprop
@bprop_getters.register(P.Xlogy)
def get_bprop_xlogy(self):
    """Grad definition for `Xlogy` operation."""
    log_op = P.Xlogy()
    div_op = P.Xdivy()
    def bprop(x, y, out, dout):
        x_dtype = F.dtype(x)
        # Mask: 1 where x != 0, else 0, so the x == 0 positions (where the
        # forward defines the result as 0) get zero gradients.
        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
        # d/dx = log(y) where x != 0, via xlogy(mask, y).
        bc_x = log_op(not_zero_x, y) * dout
        # d/dy = x / y, with xdivy keeping the x == 0 positions at zero.
        bc_y = div_op(x, y) * dout
        return binop_grad_common(x, y, bc_x, bc_y)
    return bprop
@bprop_getters.register(P.SquareSumAll)
def get_bprop_square_sum_all(self):
    """Grad definition for `SquareSumAll` operation."""
    mul_func = P.Mul()
    fill_func = P.Fill()
    dtype = P.DType()
    def bprop(x, y, out, dout):
        # Each output is a full sum of squares over one input, so every
        # element of that input receives 2 * element * scalar_grad.
        temp_x = mul_func(dout[0], x)
        temp_y = mul_func(dout[1], y)
        dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
        dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
        return (dx, dy)
    return bprop
  330. @bprop_getters.register(P.Sqrt)
  331. def get_bprop_sqrt(self):
  332. """Grad definition for `Sqrt` operation."""
  333. sqrt_grad = G.SqrtGrad()
  334. def bprop(x, out, dout):
  335. dx = sqrt_grad(out, dout)
  336. return (dx,)
  337. return bprop
@bprop_getters.register(G.SqrtGrad)
def get_bprop_sqrt_grad(self):
    """Grad definition for `SqrtGrad` operation (second-order grad of Sqrt)."""
    def bprop(y, grad, out, dout):
        # out = SqrtGrad(y, grad); with gy = dout / y:
        gy = dout / y
        # d/dy: -gy * out.
        dy = -gy * out
        # d/dgrad: 0.5 / y * dout.
        dgrad = 0.5 * gy
        return dy, dgrad
    return bprop
@bprop_getters.register(P.Rsqrt)
def get_bprop_rsqrt(self):
    """Grad definition for `Rsqrt` operation."""
    rsqrt_grad = G.RsqrtGrad()
    def bprop(x, out, dout):
        # Dedicated kernel computes the grad from the forward output.
        dx = rsqrt_grad(out, dout)
        return (dx,)
    return bprop


@bprop_getters.register(P.Reciprocal)
def get_bprop_reciprocal(self):
    """Grad definition for `Reciprocal` operation."""
    reciprocal_grad = G.ReciprocalGrad()
    def bprop(x, out, dout):
        # Dedicated kernel computes the grad from the forward output.
        dx = reciprocal_grad(out, dout)
        return (dx,)
    return bprop
@bprop_getters.register(P.Log)
def get_bprop_log(self):
    """Grad definition for `Log` operation."""
    reciprocal = P.Reciprocal()
    def bprop(x, out, dout):
        # d(ln x) = dout / x.
        g = reciprocal(x)
        dx = g * dout
        # NOTE(review): returns a second value (0) although Log is unary —
        # presumably a convention expected by the grad machinery; confirm
        # before changing the arity.
        return dx, 0
    return bprop


@bprop_getters.register(P.Log1p)
def get_bprop_log1p(self):
    """Grad definition for `Log1p` operation."""
    reciprocal = P.Reciprocal()
    def bprop(x, out, dout):
        # d(ln(1 + x)) = dout / (1 + x).
        x_1p = x + 1
        g = reciprocal(x_1p)
        dx = g * dout
        # NOTE(review): same unary-op-returning-two-values pattern as Log.
        return dx, 0
    return bprop
  382. @bprop_getters.register(P.Erf)
  383. def get_bprop_erf(self):
  384. """Grad definition for `Erf` operation."""
  385. exp = P.Exp()
  386. square = P.Square()
  387. sqrt = P.Sqrt()
  388. cast = P.Cast()
  389. dtype = P.DType()
  390. def bprop(x, out, dout):
  391. half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
  392. x_square = square(x)
  393. dx = dout * half_root_pi * exp(-x_square)
  394. return (dx,)
  395. return bprop
@bprop_getters.register(P.Erfc)
def get_bprop_erfc(self):
    """Grad definition for `Erfc` operation."""
    exp = P.Exp()
    square = P.Square()
    sqrt = P.Sqrt()
    cast = P.Cast()
    dtype = P.DType()
    def bprop(x, out, dout):
        # d(erfc x) = -(2 / sqrt(pi)) * exp(-x^2) * dout (negated erf grad).
        half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
        x_square = square(x)
        dx = dout * (-half_root_pi * exp(-x_square))
        return (dx,)
    return bprop
  410. @bprop_getters.register(P.Pow)
  411. def get_bprop_pow(self):
  412. """Grad definition for `Pow` operation."""
  413. pow_op = P.Pow()
  414. ln = P.Log()
  415. def bprop(x, power, out, dout):
  416. bc_dx = power * pow_op(x, power - 1.0) * dout
  417. x[x < 0] = 1
  418. bc_dpower = out * ln(x) * dout
  419. return binop_grad_common(x, power, bc_dx, bc_dpower)
  420. return bprop
  421. @bprop_getters.register(P.Exp)
  422. def get_bprop_exp(self):
  423. """Grad definition for `Exp` operation."""
  424. exp_ = P.Exp()
  425. def bprop(x, out, dout):
  426. g = exp_(x)
  427. dx = g * dout
  428. return (dx,)
  429. return bprop
@bprop_getters.register(P.Expm1)
def get_bprop_expm1(self):
    """Grad definition for `Expm1` operation."""
    exp_ = P.Exp()
    def bprop(x, out, dout):
        # d(expm1 x) = d(e^x - 1) = e^x, hence the plain Exp here.
        g = exp_(x)
        dx = g * dout
        return (dx,)
    return bprop
@bprop_getters.register(P.Minimum)
def get_bprop_minimum(self):
    """Grad definition for `Minimum` operation."""
    # Fused kernel routes dout to whichever operand was selected elementwise.
    input_grad = G.MinimumGrad()
    def bprop(x, y, out, dout):
        dx, dy = input_grad(x, y, dout)
        return dx, dy
    return bprop


@bprop_getters.register(P.Maximum)
def get_bprop_maximum(self):
    """Grad definition for `Maximum` operation."""
    # Fused kernel routes dout to whichever operand was selected elementwise.
    input_grad = G.MaximumGrad()
    def bprop(x, y, out, dout):
        dx, dy = input_grad(x, y, dout)
        return dx, dy
    return bprop
  455. @bprop_getters.register(P.ReduceSum)
  456. def get_bprop_reducesum(self):
  457. """Grad definition for `ReduceSum` operation."""
  458. def bprop(x, axis, out, dout):
  459. dx = _sum_grad(x, axis, dout)
  460. return dx, zeros_like(axis)
  461. return bprop
@bprop_getters.register(P.CumSum)
def get_bprop_cumsum(self):
    """Grad definition for `CumSum` operation."""
    # The adjoint of a cumulative sum is a cumulative sum in the opposite
    # direction with the same exclusivity.
    cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
    def bprop(x, axis, out, dout):
        return cumsum(dout, axis), zeros_like(axis)
    return bprop
  469. @constexpr
  470. def _split_shape_index(input_shape, axis):
  471. """Calculate reduce_prod grad transpose indices and perm shape."""
  472. rank = len(input_shape)
  473. if isinstance(axis, int):
  474. axis = tuple([axis])
  475. reduction_indices = tuple([(i + rank) % rank for i in axis])
  476. other_indices = tuple(set(range(rank)) - set(reduction_indices))
  477. reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
  478. other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
  479. perm = reduction_indices + other_indices
  480. return tuple([reduced_num, other_num]), perm
  481. @constexpr
  482. def _invert_permutation(perm):
  483. """Calculate invert permutation."""
  484. out = [0] * len(perm)
  485. for i, value in enumerate(perm):
  486. out[value] = i
  487. return tuple(out)
@bprop_getters.register(P.ReduceProd)
def get_bprop_reduceprod(self):
    """Grad definition for `ReduceProd` operation."""
    transpose = P.Transpose()
    left_cumprod = P.CumProd(exclusive=True)
    right_cumprod = P.CumProd(exclusive=True, reverse=True)
    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # Expand dout to full input shape.
        input_shape = shape_op(x)
        output_shape_kept_dims = reduced_shape(input_shape, axis)
        dout = reshape(dout, output_shape_kept_dims)
        tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
        grad = tile(dout, tile_scaling)
        # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
        pack_shape, perm = _split_shape_index(input_shape, axis)
        permuted = transpose(x, perm)
        permuted_shape = shape_op(permuted)
        reshaped = reshape(permuted, pack_shape)
        # Calculate product, leaving out the current entry: the exclusive
        # left and right cumprods multiply to prod(all other elements),
        # which is d(prod)/d(element) without dividing by x (safe for zeros).
        left = left_cumprod(reshaped, 0)
        right = right_cumprod(reshaped, 0)
        y = reshape(left * right, permuted_shape)
        # Invert the transpose and reshape operations.
        # Make sure to set the statically known shape information through a reshape.
        out = transpose(y, _invert_permutation(perm)) * grad
        dx = reshape(out, input_shape)
        return dx, zeros_like(axis)
    return bprop
@bprop_getters.register(P.CumProd)
def get_bprop_cumprod(self):
    """Grad definition for `CumProd` operation."""
    cumprod = P.CumProd(exclusive=self.exclusive, reverse=self.reverse)
    cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # NOTE: this fails when x contains 0 (division by x below).
        prod = cumprod(x, axis)
        out = cumsum(prod * dout, axis)
        return out / x, zeros_like(axis)
    return bprop
  529. @bprop_getters.register(P.ReduceAll)
  530. def get_bprop_reduceall(self):
  531. """Grad definition for `ReduceAll` operation."""
  532. def bprop(x, axis, out, dout):
  533. return zeros_like(x), zeros_like(axis)
  534. return bprop
@bprop_getters.register(P.ReduceAny)
def get_bprop_reduceany(self):
    """Grad definition for `ReduceAny` operation."""
    def bprop(x, axis, out, dout):
        # Boolean reduction is non-differentiable: zero grads.
        return zeros_like(x), zeros_like(axis)
    return bprop
@bprop_getters.register(P.ReduceMax)
def get_bprop_reducemax(self):
    """Grad definition for `ReduceMax` operation."""
    def bprop(x, axis, out, dout):
        # Route the gradient to the (possibly tied) max positions.
        dx = _min_or_max_grad(x, axis, out, dout)
        return (dx, zeros_like(axis))
    return bprop
@bprop_getters.register(P.ArgMaxWithValue)
def get_bprop_argmaxwithvalue(self):
    """Grad definition for `ArgMaxWithValue` operation."""
    axis = self.axis
    keep_dims = self.keep_dims
    # Helper op without keep_dims; the grad helper uses it to recompute
    # squeezed indices when keep_dims is set on the forward op.
    op = P.ArgMaxWithValue(axis)
    def bprop(x, out, dout):
        dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
        return (dx,)
    return bprop
@bprop_getters.register(P.ReduceMin)
def get_bprop_reducemin(self):
    """Grad definition for `ReduceMin` operation."""
    def bprop(x, axis, out, dout):
        # Route the gradient to the (possibly tied) min positions.
        dx = _min_or_max_grad(x, axis, out, dout)
        return (dx, zeros_like(axis))
    return bprop
@bprop_getters.register(P.ArgMinWithValue)
def get_bprop_argminwithvalue(self):
    """Grad definition for `ArgMinWithValue` operation."""
    axis = self.axis
    keep_dims = self.keep_dims
    # Helper op without keep_dims; the grad helper uses it to recompute
    # squeezed indices when keep_dims is set on the forward op.
    op = P.ArgMinWithValue(axis)
    def bprop(x, out, dout):
        dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
        return (dx,)
    return bprop
@bprop_getters.register(P.ReduceMean)
def get_bprop_reduce_mean(self):
    """Grad definition for `ReduceMean` operation."""
    div_op = P.RealDiv()
    cast = P.Cast()
    dtype = P.DType()
    def bprop(x, axis, out, dout):
        # Sum-grad broadcast, then divide by the number of reduced elements
        # (ratio of input element count to output element count).
        grad = _sum_grad(x, axis, dout)
        div_shape = F.shape_mul(shape_op(x)) / F.shape_mul(shape_op(out))
        dx = div_op(grad, cast(F.scalar_to_array(div_shape), dtype(grad)))
        return dx, zeros_like(axis)
    return bprop
@bprop_getters.register(P.IsFinite)
def get_bprop_isfinite(self):
    """Grad definition for `IsFinite` operation."""
    def bprop(x, out, dout):
        # Predicate op: no gradient flows back.
        return (zeros_like(x),)
    return bprop


@bprop_getters.register(P.Equal)
def get_bprop_equal(self):
    """Grad definition for `Equal` operation."""
    def bprop(x, y, out, dout):
        # Comparisons are non-differentiable: zero grads for both inputs.
        return zeros_like(x), zeros_like(y)
    return bprop


@bprop_getters.register(P.NotEqual)
def get_bprop_not_equal(self):
    """Grad definition for `NotEqual` operation."""
    def bprop(x, y, out, dout):
        # Comparisons are non-differentiable: zero grads for both inputs.
        return zeros_like(x), zeros_like(y)
    return bprop


@bprop_getters.register(P.ApproximateEqual)
def get_bprop_approximate_equal(self):
    """Grad definition for `ApproximateEqual` operation."""
    def bprop(x, y, out, dout):
        # Comparisons are non-differentiable: zero grads for both inputs.
        return zeros_like(x), zeros_like(y)
    return bprop


@bprop_getters.register(P.Greater)
def get_bprop_greater(self):
    """Grad definition for `Greater` operation."""
    def bprop(x, y, out, dout):
        # Comparisons are non-differentiable: zero grads for both inputs.
        return zeros_like(x), zeros_like(y)
    return bprop


@bprop_getters.register(P.GreaterEqual)
def get_bprop_greater_equal(self):
    """Grad definition for `GreaterEqual` operation."""
    def bprop(x, y, out, dout):
        # Comparisons are non-differentiable: zero grads for both inputs.
        return zeros_like(x), zeros_like(y)
    return bprop


@bprop_getters.register(P.Less)
def get_bprop_less(self):
    """Grad definition for `Less` operation."""
    def bprop(x, y, out, dout):
        # Comparisons are non-differentiable: zero grads for both inputs.
        return zeros_like(x), zeros_like(y)
    return bprop


@bprop_getters.register(P.LessEqual)
def get_bprop_less_equal(self):
    """Grad definition for `LessEqual` operation."""
    def bprop(x, y, out, dout):
        # Comparisons are non-differentiable: zero grads for both inputs.
        return zeros_like(x), zeros_like(y)
    return bprop
@bprop_getters.register(P.LogicalNot)
def get_bprop_logical_not(self):
    """Grad definition for `LogicalNot` operation."""
    def bprop(x, out, dout):
        # Boolean ops are non-differentiable: zero grad.
        return (zeros_like(x),)
    return bprop


@bprop_getters.register(P.LogicalAnd)
def get_bprop_logical_and(self):
    """Grad definition for `LogicalAnd` operation."""
    def bprop(x, y, out, dout):
        # Boolean ops are non-differentiable: zero grads.
        return zeros_like(x), zeros_like(y)
    return bprop


@bprop_getters.register(P.LogicalOr)
def get_bprop_logical_or(self):
    """Grad definition for `LogicalOr` operation."""
    def bprop(x, y, out, dout):
        # Boolean ops are non-differentiable: zero grads.
        return zeros_like(x), zeros_like(y)
    return bprop
@bprop_getters.register(P.NPUAllocFloatStatus)
def get_bprop_npu_alloc_float_status(self):
    """Grad definition for `NPUAllocFloatStatus` operation."""
    # Side-effect op with no tensor inputs: nothing to propagate.
    def bprop(out, dout):
        return ()
    return bprop


@bprop_getters.register(P.NPUGetFloatStatus)
def get_bprop_npu_get_float_status(self):
    """Grad definition for `NPUGetFloatStatus` operation."""
    # Status-register read: no meaningful gradient, pass zeros.
    def bprop(x, out, dout):
        return (zeros_like(x),)
    return bprop


@bprop_getters.register(P.NPUClearFloatStatus)
def get_bprop_npu_clear_float_status(self):
    """Grad definition for `NPUClearFloatStatus` operation."""
    # Status-register clear: no meaningful gradient, pass zeros.
    def bprop(x, out, dout):
        return (zeros_like(x),)
    return bprop
@bprop_getters.register(P.AssignAdd)
def get_bprop_assign_add(self):
    """Grad definition for `AssignAdd` operation."""
    # In-place parameter update: no gradient flows through it.
    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)
    return bprop


@bprop_getters.register(P.AssignSub)
def get_bprop_assign_sub(self):
    """Grad definition for `AssignSub` operation."""
    # In-place parameter update: no gradient flows through it.
    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)
    return bprop
  683. @bprop_getters.register(P.Sin)
  684. def get_bprop_sin(self):
  685. """Grad definition for `Sin` operation."""
  686. cos = P.Cos()
  687. def bprop(x, out, dout):
  688. dx = dout * cos(x)
  689. return (dx,)
  690. return bprop
@bprop_getters.register(P.Asin)
def get_bprop_asin(self):
    """Grad definition for `Asin` operation."""
    input_grad = G.AsinGrad()
    def bprop(x, out, dout):
        # Dedicated kernel computes the asin gradient from x and dout.
        dx = input_grad(x, dout)
        return (dx,)
    return bprop


@bprop_getters.register(G.AsinGrad)
def get_bprop_asin_grad(self):
    """Grad definition for `AsinGrad` operation (second-order grad of Asin)."""
    input_grad = G.AsinGrad()
    p_pow = P.Pow()
    def bprop(x, grad, out, dout):
        # d/dx of grad/sqrt(1-x^2): grad * x * (1 - x^2)^(-3/2).
        d2x = dout * grad * x * p_pow((1 - x * x), - 1.5)
        ddy = input_grad(x, dout)
        return (d2x, ddy)
    return bprop
@bprop_getters.register(P.Asinh)
def get_bprop_asinh(self):
    """Grad definition for `Asinh` operation."""
    input_grad = G.AsinhGrad()
    def bprop(x, out, dout):
        # Kernel consumes the forward output (out = asinh(x)), not x.
        dx = input_grad(out, dout)
        return (dx,)
    return bprop


@bprop_getters.register(G.AsinhGrad)
def get_bprop_asinh_grad(self):
    """Grad definition for `AsinhGrad` operation (second-order grad of Asinh)."""
    input_grad = G.AsinhGrad()
    tanh = P.Tanh()
    def bprop(y, grad, out, dout):
        # NOTE(review): second-order rule written via tanh of y = asinh(x);
        # presumably derived from d/dy (grad / cosh(y)) — confirm.
        dy = dout * out * -1.0 * tanh(y)
        dgrad = input_grad(y, dout)
        return dy, dgrad
    return bprop
  727. @bprop_getters.register(P.Sinh)
  728. def get_bprop_sinh(self):
  729. """Grad definition for `Sinh` operation."""
  730. cosh = P.Cosh()
  731. def bprop(x, out, dout):
  732. dx = cosh(x) * dout
  733. return (dx,)
  734. return bprop
  735. @bprop_getters.register(P.Cos)
  736. def get_bprop_cos(self):
  737. """Grad definition for `Cos` operation."""
  738. sin = P.Sin()
  739. neg = P.Neg()
  740. def bprop(x, out, dout):
  741. dx = dout * neg(sin(x))
  742. return (dx,)
  743. return bprop
  744. @bprop_getters.register(P.ACos)
  745. def get_bprop_acos(self):
  746. """Grad definition for `ACos` operation."""
  747. input_grad = G.ACosGrad()
  748. def bprop(x, out, dout):
  749. dx = input_grad(x, dout)
  750. return (dx,)
  751. return bprop
  752. @bprop_getters.register(G.ACosGrad)
  753. def get_bprop_acos_grad(self):
  754. """Grad definition for `ACosGrad` operation."""
  755. input_grad = G.ACosGrad()
  756. p_pow = P.Pow()
  757. def bprop(x, grad, out, dout):
  758. d2x = -dout * grad * x * p_pow((1 - x * x), - 1.5)
  759. ddy = input_grad(x, dout)
  760. return (d2x, ddy)
  761. return bprop
  762. @bprop_getters.register(P.Acosh)
  763. def get_bprop_acosh(self):
  764. """Grad definition for `Acosh` operation."""
  765. input_grad = G.AcoshGrad()
  766. def bprop(x, out, dout):
  767. dx = input_grad(out, dout)
  768. return (dx,)
  769. return bprop
  770. @bprop_getters.register(G.AcoshGrad)
  771. def get_bprop_acosh_grad(self):
  772. """Grad definition for `AcoshGrad` operation."""
  773. input_grad = G.AcoshGrad()
  774. tanh = P.Tanh()
  775. def bprop(y, grad, out, dout):
  776. dy = dout * out * -1.0 / tanh(y)
  777. dgrad = input_grad(y, dout)
  778. return dy, dgrad
  779. return bprop
  780. @bprop_getters.register(P.Cosh)
  781. def get_bprop_cosh(self):
  782. """Grad definition for `Cosh` operation."""
  783. sinh = P.Sinh()
  784. def bprop(x, out, dout):
  785. dx = sinh(x) * dout
  786. return (dx,)
  787. return bprop
  788. @bprop_getters.register(P.Abs)
  789. def get_bprop_abs(self):
  790. """Grad definition for `Abs` operation."""
  791. abs_grad = G.AbsGrad()
  792. def bprop(x, out, dout):
  793. dx = abs_grad(x, dout)
  794. return (dx,)
  795. return bprop
  796. @bprop_getters.register(P.ScalarCast)
  797. def get_bprop_scalar_cast(self):
  798. """Generate bprop for ScalarCast"""
  799. def bprop(x, t, out, dout):
  800. return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)
  801. return bprop
  802. @bprop_getters.register(P.AccumulateNV2)
  803. def get_bprop_scalar_accumulatenv2(self):
  804. """Generate bprop for AccumulateNV2"""
  805. def bprop(x, out, dout):
  806. dx = ()
  807. for _ in range(len(x)):
  808. dx = dx + (dout,)
  809. return (dx,)
  810. return bprop
  811. @bprop_getters.register(P.AddN)
  812. def get_bprop_scalar_addn(self):
  813. """Generate bprop for AddN"""
  814. def bprop(x, out, dout):
  815. if is_sub_class(F.typeof(x), ms.list_):
  816. dx = []
  817. for _ in range(len(x)):
  818. dx.append(dout)
  819. return (dx,)
  820. dx = ()
  821. for _ in range(len(x)):
  822. dx = dx + (dout,)
  823. return (dx,)
  824. return bprop
  825. @bprop_getters.register(P.Sign)
  826. def get_bprop_sign(self):
  827. """Generate bprop for Sign"""
  828. def bprop(x, out, dout):
  829. return (zeros_like(x),)
  830. return bprop
  831. @bprop_getters.register(P.Round)
  832. def get_bprop_round(self):
  833. """Generate bprop for Round"""
  834. def bprop(x, out, dout):
  835. return (zeros_like(x),)
  836. return bprop
  837. @bprop_getters.register(P.Atan2)
  838. def get_bprop_atan2(self):
  839. """Generate bprop for Atan2"""
  840. square = P.Square()
  841. def bprop(x, y, out, dout):
  842. tmp = dout / (square(x) + square(y))
  843. bc_dx = tmp * y
  844. bc_dy = tmp * (-x)
  845. return binop_grad_common(x, y, bc_dx, bc_dy)
  846. return bprop
  847. @bprop_getters.register(P.BesselI0e)
  848. def get_bprop_bessel_i0e(self):
  849. """Generate bprop for BesselI0e"""
  850. sign = P.Sign()
  851. bessel_i1e = P.BesselI1e()
  852. def bprop(x, out, dout):
  853. dx = dout * (bessel_i1e(x) - sign(x) * out)
  854. return (dx,)
  855. return bprop
  856. @bprop_getters.register(P.Atan)
  857. def get_bprop_atan(self):
  858. """Grad definition for `Atan` operation."""
  859. input_grad = G.AtanGrad()
  860. def bprop(x, out, dout):
  861. dx = input_grad(x, dout)
  862. return (dx,)
  863. return bprop
  864. @bprop_getters.register(G.AtanGrad)
  865. def get_bprop_atan_grad(self):
  866. """Grad definition for `AtanGrad` operation."""
  867. input_grad = G.AtanGrad()
  868. def bprop(x, grad, out, dout):
  869. dgrad = input_grad(x, dout)
  870. dx = out * dgrad * -2.0 * x
  871. return dx, dgrad
  872. return bprop
  873. @bprop_getters.register(P.Tan)
  874. def get_bprop_tan(self):
  875. """Grad definition for `Tan` operation."""
  876. reciprocal = P.Reciprocal()
  877. square = P.Square()
  878. cos = P.Cos()
  879. def bprop(x, out, dout):
  880. cosx = cos(x)
  881. secx2 = square(reciprocal(cosx))
  882. dx = secx2 * dout
  883. return (dx,)
  884. return bprop
  885. @bprop_getters.register(P.BesselI1e)
  886. def get_bprop_bessel_i1e(self):
  887. """Generate bprop for BesselI1e"""
  888. sign = P.Sign()
  889. bessel_i0e = P.BesselI0e()
  890. less = P.Less()
  891. select = P.Select()
  892. reciprocal = P.Reciprocal()
  893. cast = P.Cast()
  894. dtype = P.DType()
  895. abs_ops = P.Abs()
  896. def bprop(x, out, dout):
  897. zeros = zeros_like(x)
  898. np_eps = const_utils.get_np_eps(dtype(x))
  899. eps = cast(np_eps, dtype(x))
  900. x_is_valid = less(eps, abs_ops(x))
  901. x_safe = select(x_is_valid, x, eps + zeros)
  902. tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
  903. dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
  904. return (dx,)
  905. return bprop
  906. @bprop_getters.register(P.Atanh)
  907. def get_bprop_atanh(self):
  908. """Grad definition for `Atanh` operation."""
  909. power = P.Pow()
  910. div = P.Div()
  911. def bprop(x, out, dout):
  912. tmp = 1 - power(x, 2)
  913. dx = div(1, tmp) * dout
  914. return (dx,)
  915. return bprop
  916. @bprop_getters.register(P.Inv)
  917. def get_bprop_inv(self):
  918. """Grad definition for 'Inv' operation"""
  919. inv_grad = G.InvGrad()
  920. def bprop(x, out, dout):
  921. dx = inv_grad(out, dout)
  922. return (dx,)
  923. return bprop
  924. @bprop_getters.register(P.LinSpace)
  925. def get_bprop_lin_space(self):
  926. """Grad definition for `LinSpace` operation."""
  927. def bprop(start, stop, num, out, dout):
  928. return zeros_like(start), zeros_like(stop), zeros_like(num)
  929. return bprop
  930. @bprop_getters.register(P.IndexAdd)
  931. def get_bprop_index_add(self):
  932. """Generate bprop for IndexAdd"""
  933. gather = P.Gather()
  934. _axis = self.axis
  935. def bprop(input_x, indices, input_y, out, dout):
  936. return dout, zeros_like(indices), gather(dout, indices, _axis)
  937. return bprop