You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_array_ops.py 32 kB

5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """array_ops"""
  16. import numpy as np
  17. import mindspore as ms
  18. from mindspore.ops import composite as C
  19. from .. import operations as P
  20. from ..operations import _grad_ops as G
  21. from ..operations import _inner_ops as inner
  22. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  23. from ..functional import broadcast_gradient_args
  24. from .. import functional as F
  25. from .grad_base import bprop_getters
  26. from ..primitive import constexpr
  27. from ... import context
  28. from ...common import dtype as mstype
  29. from ...common.tensor import RowTensor
  30. from .._utils.utils import range_op, get_1d_shape, generate_shape_index, is_shape_unknown
  31. reduce_sum = P.ReduceSum()
  32. unsorted_segment_sum = P.UnsortedSegmentSum()
  33. transpose = P.Transpose()
  34. shape_op = P.Shape()
  35. dyn_shape_op = P.DynamicShape()
  36. reshape = P.Reshape()
  37. size_op = P.Size()
  38. invert_permutation = P.InvertPermutation()
  39. logical_and = P.LogicalAnd()
  40. is_sub_class = P.IsSubClass()
  41. @bprop_getters.register(P.Fill)
  42. def get_bprop_fill(self):
  43. """Generate bprop for Fill"""
  44. def bprop(dtype, dims, x, out, dout):
  45. return zeros_like(dims), zeros_like(x)
  46. return bprop
  47. @bprop_getters.register(P.Ones)
  48. def get_bprop_ones(self):
  49. """Generate bprop for Ones"""
  50. def bprop(dims, dtype, out, dout):
  51. return zeros_like(dims)
  52. return bprop
  53. @bprop_getters.register(P.Zeros)
  54. def get_bprop_zeros(self):
  55. """Generate bprop for Zeros"""
  56. def bprop(dims, dtype, out, dout):
  57. return zeros_like(dims)
  58. return bprop
  59. @bprop_getters.register(P.DType)
  60. def get_bprop_dtype(self):
  61. """Generate bprop for DType"""
  62. def bprop(x, out, dout):
  63. return (zeros_like(x),)
  64. return bprop
  65. dout_cast = C.MultitypeFuncGraph("dout_cast")
  66. @dout_cast.register("Tensor", "Tensor")
  67. def dout_cast_tensor(dout, x):
  68. """Casts dout to the dtype of x for Tensor."""
  69. cast = P.Cast()
  70. get_dtype = P.DType()
  71. dx = cast(dout, get_dtype(x))
  72. return dx
  73. @dout_cast.register("Number", "Number")
  74. def dout_cast_number(dout, x):
  75. """Casts dout to the dtype of x for Number."""
  76. cast = P.Cast()
  77. get_dtype = P.DType()
  78. dx = cast(dout, get_dtype(x))
  79. return dx
  80. @dout_cast.register("RowTensor", "Tensor")
  81. def dout_cast_row_tensor(dout, x):
  82. """Casts dout values to the dtype of x for RowTensor."""
  83. cast = P.Cast()
  84. get_dtype = P.DType()
  85. values = cast(dout.values, get_dtype(x))
  86. return RowTensor(dout.indices, values, dout.dense_shape)
  87. @bprop_getters.register(P.Cast)
  88. def get_bprop_cast(self):
  89. """Generate bprop for Cast"""
  90. cast = P.Cast()
  91. get_dtype = P.DType()
  92. def bprop(x, t, out, dout):
  93. dx = cast(dout, get_dtype(x))
  94. return dx, zeros_like(t)
  95. def bprop_sparse(x, t, out, dout):
  96. dx = dout_cast(dout, x)
  97. return dx, zeros_like(t)
  98. if context.get_context('enable_sparse'):
  99. return bprop_sparse
  100. return bprop
  101. @bprop_getters.register(P.Shape)
  102. def get_bprop_shape(self):
  103. """Generate bprop for Shape"""
  104. def bprop(x, out, dout):
  105. return (zeros_like(x),)
  106. return bprop
  107. @bprop_getters.register(P.DynamicShape)
  108. def get_bprop_dynamicshape(self):
  109. """Generate bprop for Shape"""
  110. def bprop(x, out, dout):
  111. return (zeros_like(x),)
  112. return bprop
  113. @bprop_getters.register(P.Split)
  114. def get_bprop_split(self):
  115. """Generate bprop for Split"""
  116. axis = self.axis
  117. def bprop(x, out, dout):
  118. concat_op = P.Concat(axis)
  119. dx = concat_op(dout)
  120. return (dx,)
  121. return bprop
  122. @bprop_getters.register(P.Rank)
  123. def get_bprop_rank(self):
  124. """Generate bprop for Rank"""
  125. def bprop(x, out, dout):
  126. return (zeros_like(x),)
  127. return bprop
  128. @bprop_getters.register(P.Reshape)
  129. def get_bprop_reshape(self):
  130. """Generate bprop for Reshape"""
  131. def bprop(x, shp, out, dout):
  132. shapex = shape_op(x)
  133. if -1 in shapex:
  134. shapex = dyn_shape_op(x)
  135. return reshape(dout, shapex), zeros_like(shp)
  136. return bprop
  137. @bprop_getters.register(P.ExpandDims)
  138. def get_bprop_expand_dims(self):
  139. """Generate bprop for ExpandDims"""
  140. def bprop(x, axis, out, dout):
  141. shapex = shape_op(x)
  142. if is_shape_unknown(shapex):
  143. shapex = dyn_shape_op(x)
  144. return reshape(dout, shapex), zeros_like(axis)
  145. return bprop
  146. @bprop_getters.register(P.Squeeze)
  147. def get_bprop_squeeze(self):
  148. """Generate bprop for Squeeze"""
  149. def bprop(x, out, dout):
  150. shapex = shape_op(x)
  151. return (reshape(dout, shapex),)
  152. return bprop
  153. @bprop_getters.register(P.Flatten)
  154. def get_bprop_flatten(self):
  155. """Generate bprop for Flatten"""
  156. flatten_grad = P.Reshape()
  157. def bprop(x, out, dout):
  158. dx = flatten_grad(dout, shape_op(x))
  159. return (dx,)
  160. return bprop
  161. @constexpr
  162. def _tile_shape(multiples, shapex):
  163. """Calculate [1,2], [3, 4] -> [1,3,2,4]."""
  164. len_muli = len(multiples)
  165. rank = len(shapex)
  166. len_cmp = len_muli - rank
  167. max_len = max(len_muli, rank)
  168. i = 0
  169. j = 0
  170. ret = []
  171. while (i < max_len) and (j < max_len):
  172. if len_cmp == 0:
  173. ret.append(multiples[i])
  174. ret.append(shapex[j])
  175. i += 1
  176. j += 1
  177. elif len_cmp > 0:
  178. ret.append(multiples[i])
  179. ret.append(1)
  180. i += 1
  181. len_cmp -= 1
  182. else:
  183. ret.append(1)
  184. ret.append(shapex[j])
  185. j += 1
  186. len_cmp += 1
  187. return tuple(ret)
  188. @bprop_getters.register(P.Tile)
  189. def get_bprop_tile(self):
  190. """Generate bprop for Tile"""
  191. tuple_to_array = P.TupleToArray()
  192. cast = P.Cast()
  193. stack_op = P.Stack(1)
  194. ones = P.Ones()
  195. concat = P.Concat()
  196. def bprop(x, multiples, out, dout):
  197. shapex = shape_op(x)
  198. if -1 in shapex:
  199. shapex = dyn_shape_op(x)
  200. # if shapex or multiples not tuple, it should be dynamic shape.
  201. if isinstance(multiples, tuple) and isinstance(shapex, tuple):
  202. r_shape = _tile_shape(multiples, shapex)
  203. else:
  204. if isinstance(multiples, tuple):
  205. multiples = tuple_to_array(multiples)
  206. multiples = cast(multiples, mstype.int64)
  207. len_multi = size_op(multiples)
  208. rank = len(shapex)
  209. if isinstance(shapex, tuple):
  210. shape_tensor = cast(tuple_to_array(shapex), mstype.int64)
  211. else:
  212. shape_tensor = shapex
  213. if len_multi > rank:
  214. one_tensor = ones((len_multi - rank,), mstype.int64)
  215. shape_tensor = concat((one_tensor, shape_tensor))
  216. elif len_multi < rank:
  217. one_tensor = ones((rank - len_multi,), mstype.int64)
  218. multiples = concat((one_tensor, multiples))
  219. tile_shape = stack_op((multiples, shape_tensor))
  220. r_shape = P.Reshape()(tile_shape, (-1,))
  221. # 0 represents the start index, and 2 represents the step
  222. axis = F.make_range(0, len(r_shape), 2)
  223. dx = reduce_sum(P.Reshape()(dout, r_shape), axis)
  224. dx = reshape(dx, shapex)
  225. return dx, zeros_like(multiples)
  226. return bprop
  227. @bprop_getters.register(P.EmbeddingLookup)
  228. def get_bprop_embedding_lookup(self):
  229. """Generate bprop for EmbeddingLookup"""
  230. sub_op = P.Sub()
  231. reshape_op = P.Reshape()
  232. def bprop_sparse(x, indices, offset, out, dout):
  233. x_shp = shape_op(x)
  234. new_indices = sub_op(indices, offset)
  235. indices_size = size_op(new_indices)
  236. if indices_size > 0:
  237. # Reshape the 'new_indices'
  238. new_indices_shape_changed = (indices_size,)
  239. new_indices = reshape_op(new_indices, new_indices_shape_changed)
  240. else:
  241. new_indices_shape_changed = ()
  242. x_shp_tail = x_shp[1:]
  243. actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
  244. # Reshape the 'actual_dout' on device
  245. actual_dout = reshape_op(dout, actual_dout_shape_changed)
  246. return RowTensor(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
  247. return bprop_sparse
  248. @constexpr
  249. def make_begin(shp):
  250. """Creates a tuple with zero according to the shape."""
  251. begin = tuple([0 for _ in shp])
  252. return begin
  253. @bprop_getters.register(P.Padding)
  254. def get_bprop_padding(self):
  255. """Grad definition for `Padding` operation."""
  256. def bprop(x, out, dout):
  257. shp = shape_op(x)
  258. begin = make_begin(shp)
  259. dx = P.Slice()(dout, begin, shp)
  260. return (dx,)
  261. return bprop
  262. @bprop_getters.register(P.Transpose)
  263. def get_bprop_transpose(self):
  264. """Generate bprop for Transpose"""
  265. def bprop(x, perm, out, dout):
  266. return transpose(dout, invert_permutation(perm)), zeros_like(perm)
  267. return bprop
  268. @constexpr
  269. def _concat_grad_uniform(input_shapes, input_nums):
  270. """Helper function for bprop of Concat"""
  271. is_uniform = True
  272. for i in range(1, input_nums):
  273. if input_shapes[i - 1] != input_shapes[i]:
  274. is_uniform = False
  275. break
  276. return is_uniform
@bprop_getters.register(P.Concat)
def get_bprop_concat(self):
    """Generate bprop for Concat.

    Each input's gradient is the slice of ``dout`` that the input occupied in the output.
    When ConcatOffset returns a tuple and all input shapes are equal, one ``Split`` along
    ``self.axis`` is used. Otherwise each input is cut out with ``Slice`` at its offset.
    Gradients are returned as a list when ``x`` is a list, otherwise as a tuple.
    """
    axis = self.axis

    def bprop(x, out, dout):
        # Start offset of each input inside the concatenated output.
        out_offset = G.ConcatOffset(len(x), axis)(x)
        input_nums = len(x)
        input_shapes = ()
        if isinstance(out_offset, tuple):
            # Offsets available as a tuple: use static shapes and check whether Split applies.
            for i in range(input_nums):
                input_shapes = input_shapes + (shape_op(x[i]),)
            is_uniform = _concat_grad_uniform(input_shapes, input_nums)
        else:
            # for dynamic shape
            # Runtime shapes are used and every input is sliced individually.
            for i in range(input_nums):
                input_shapes = input_shapes + (dyn_shape_op(x[i]),)
            is_uniform = False
        if isinstance(x, list):
            # List inputs: build the gradients as a list.
            dx = []
            if is_uniform:
                dx_tuple = P.Split(axis, input_nums)(dout)
                for _, i in enumerate(dx_tuple):
                    dx = dx + [i,]
            else:
                for i in range(input_nums):
                    slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
                    dx = dx + [slice_out,]
        else:
            # Tuple inputs: build the gradients as a tuple.
            dx = ()
            if is_uniform:
                dx = P.Split(axis, input_nums)(dout)
            else:
                for i in range(input_nums):
                    slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
                    dx = dx + (slice_out,)
        return (dx,)

    return bprop
  314. @constexpr
  315. def _slice_grad_pad(begins, sizes, shapes):
  316. pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
  317. return pads
  318. @bprop_getters.register(P.Slice)
  319. def get_bprop_slice(self):
  320. """Generate bprop for Slice"""
  321. def bprop(x, begin, size, out, dout):
  322. dx = G.SliceGrad()(dout, x, begin, size)
  323. return (dx, zeros_like(begin), zeros_like(size))
  324. return bprop
  325. @constexpr
  326. def _generate_inverse_index(x_shape, axis):
  327. x_rank = len(x_shape)
  328. index = tuple(range(x_rank))
  329. if axis < 0:
  330. axis += x_rank
  331. perm = index[1:1 + axis] + (0,) + index[1 + axis:]
  332. return perm
  333. @constexpr
  334. def _regenerate_output_shape(x_shp, ind_shp, axis):
  335. rank = len(x_shp)
  336. if axis < 0:
  337. axis += rank
  338. out_shape = x_shp[:axis] + ind_shp + x_shp[axis + 1:]
  339. return out_shape
@bprop_getters.register(P.Gather)
@bprop_getters.register(P.GatherV2)
def get_bprop_gather_v2(self):
    """Generate bprop for GatherV2 (also registered for Gather).

    The process:
    1. ``dout`` is transposed with a permutation from ``generate_shape_index``; per the
       example comment, this brings the indexed axis to the front.
    2. The result is scatter-added into ``x.shape[axis]`` rows with UnsortedSegmentSum.
    3. That is permuted back with ``_generate_inverse_index``.

    ``indices`` and ``axis`` receive zero gradients.
    """

    def bprop(x, indices, axis, out, dout):
        # Keep the caller's indices so their zero gradient has the original shape.
        orig_indices = indices
        # Promote scalar dout / indices to rank 1.
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        x_shp = shape_op(x)
        ind_shp = shape_op(indices)
        # Reshape dout to x_shp[:axis] + ind_shp + x_shp[axis + 1:] so it lines up with the
        # (possibly expanded) indices.
        out_shp = _regenerate_output_shape(x_shp, ind_shp, axis)
        dout = reshape(dout, out_shp)
        # NOTE(review): x_shp / ind_shp are recomputed unchanged here; only out_shp is re-read.
        x_shp = shape_op(x)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        # Number of segments = size of x along `axis`; use the runtime shape if static has -1.
        if -1 in shape_op(x):
            params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis])
        else:
            params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(orig_indices), zeros_like(axis)

    return bprop
  369. @bprop_getters.register(P.GatherD)
  370. def get_bprop_gather_d(self):
  371. """Generate bprop for GatherD"""
  372. def bprop(x, dim, index, out, dout):
  373. x_shp = shape_op(x)
  374. dx = G.GatherDGrad(dim, x_shp)(index, dout)
  375. return dx, zeros_like(dim), zeros_like(index)
  376. return bprop
  377. @bprop_getters.register(G.GatherDGrad)
  378. def get_bprop_gather_d_grad(self):
  379. """Generate bprop for GatherDGrad"""
  380. op = P.Gather()
  381. dim = self.dim
  382. x_shp = self.out_shape
  383. def bprop(index, x, out, dout):
  384. index_shp = shape_op(index)
  385. dim_before_axis = 1
  386. for i in range(dim):
  387. dim_before_axis *= x_shp[i]
  388. dim_at_axis_index = index_shp[dim]
  389. dim_at_axis_output = x_shp[dim]
  390. dim_after_axis = 1
  391. for i in range(dim+1, len(x_shp)):
  392. dim_after_axis *= x_shp[i]
  393. element = dim_before_axis * dim_at_axis_index * dim_after_axis
  394. id_ = range_op(0, element, 1, index.dtype)
  395. i = id_ // (dim_at_axis_index * dim_after_axis)
  396. k = id_ % dim_after_axis
  397. j = P.Cast()(index < 0, index.dtype)
  398. j_read = dim_at_axis_index * j + index
  399. j_read = P.Reshape()(j_read, (-1,))
  400. read_id = i*dim_at_axis_output*dim_after_axis + j_read * dim_after_axis + k
  401. dout = P.Reshape()(dout, (-1,))
  402. dx = op(dout, read_id, 0)
  403. dx = P.Reshape()(dx, shape_op(x))
  404. return zeros_like(index), dx
  405. return bprop
  406. @bprop_getters.register(P.SparseGatherV2)
  407. def get_bprop_sparse_gather_v2(self):
  408. """Generate bprop for SparseGatherV2"""
  409. def bprop(x, indices, axis, out, dout):
  410. x_shp = shape_op(x)
  411. if axis == 0:
  412. indices_size = (size_op(indices),)
  413. if len(x_shp) <= 1:
  414. x_tail_shp = ()
  415. else:
  416. x_tail_shp = x_shp[1:]
  417. values_shape = indices_size + x_tail_shp
  418. values = reshape(dout, values_shape)
  419. indices_new = reshape(indices, indices_size)
  420. return RowTensor(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis)
  421. if F.rank(dout) == 0:
  422. dout = P.ExpandDims()(dout, -1)
  423. if F.rank(indices) == 0:
  424. indices = P.ExpandDims()(indices, -1)
  425. out_shp = shape_op(dout)
  426. ind_shp = shape_op(indices)
  427. # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
  428. perm_1 = generate_shape_index(out_shp, ind_shp, axis)
  429. values_transpose = transpose(dout, perm_1)
  430. params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
  431. # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
  432. perm_2 = _generate_inverse_index(x_shp, axis)
  433. params_grad = transpose(params_grad, perm_2)
  434. return params_grad, zeros_like(indices), zeros_like(axis)
  435. return bprop
  436. @constexpr
  437. def _get_transposition(axis, rank):
  438. """helper function for grad of Sort"""
  439. if axis < 0:
  440. axis += rank
  441. transposition = np.r_[np.arange(axis), [rank - 1], np.arange(axis + 1, rank - 1), [axis]]
  442. trans = tuple(transposition.tolist())
  443. return trans
@bprop_getters.register(P.Sort)
def get_bprop_sort(self):
    """Grad definition for `Sort` operation.

    The gradient of the sorted values is scattered back to the positions those values came
    from. The positions are recovered by running TopK with ``k`` equal to the length of the
    sorted axis: on ``-x`` for ascending sorts, and after swapping the sorted axis to the
    last position when it is not already last. Only ``dout[0]`` (the values gradient) is used.
    """
    axis = self.axis
    descending = self.descending
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    reshape_op = P.Reshape()
    dtype = P.DType()
    topk = P.TopK()
    neg = P.Neg()
    tranpose = P.Transpose()

    def bprop(input_x, out, dout):
        x_shape = input_x.shape
        # Full-length TopK along the sorted axis reproduces the sort order.
        k = x_shape[axis]
        rank = F.rank(input_x)
        dvalue = dout[0]
        if not descending:
            # Negate so TopK (largest first) yields the ascending order; dx is negated back below.
            input_x = neg(input_x)
            dvalue = neg(dvalue)
        if axis == -1 or (axis + 1) == rank:
            transposition = None
            top_k_input = input_x
        else:
            # Move the sorted axis last, where TopK operates.
            transposition = _get_transposition(axis, rank)
            top_k_input = tranpose(input_x, transposition)
        _, indices = topk(top_k_input, k)
        ind_shape = indices.shape
        top_k_input_shape = top_k_input.shape
        in_lastdim = top_k_input_shape[-1]
        ind_lastdim = ind_shape[-1]
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outer_dim = ind_2d.shape[0]
        indices_dtype = dtype(indices)
        # Flat start offset of each row: 0, in_lastdim, 2 * in_lastdim, ...
        range_flatten_index = range_op(0, outer_dim * in_lastdim, in_lastdim, indices_dtype)
        # expand_dims to (k, 1), then broadcast
        # Row offset + in-row index = flat position in top_k_input.
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        x_shape_1d = get_1d_shape(top_k_input_shape)
        if transposition is not None:
            # Bring dvalue into top_k_input's layout, scatter, then restore the original layout.
            dvalue = tranpose(dvalue, invert_permutation(transposition))
            out_grad = reshape_op(
                scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
            dx = tranpose(out_grad, invert_permutation(transposition))
        else:
            dx = reshape_op(scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
        if not descending:
            dx = neg(dx)
        return (dx,)

    return bprop
  493. @bprop_getters.register(P.Identity)
  494. def get_bprop_identity(self):
  495. """Generate bprop for Identity"""
  496. def bprop(x, out, dout):
  497. return (dout,)
  498. return bprop
  499. @bprop_getters.register(inner.Range)
  500. def get_bprop_range(self):
  501. """Generate bprop for Range"""
  502. def bprop(x, out, dout):
  503. return (zeros_like(x),)
  504. return bprop
  505. @bprop_getters.register(P.Pack)
  506. @bprop_getters.register(P.Stack)
  507. def get_bprop_stack(self):
  508. """Generate bprop for Stack"""
  509. axis = self.axis
  510. def bprop(x, out, dout):
  511. stack_grad = P.Unstack(axis)
  512. out = stack_grad(dout)
  513. if is_sub_class(F.typeof(x), ms.list_):
  514. ret = []
  515. for item in out:
  516. ret.append(item)
  517. return (ret,)
  518. return (out,)
  519. return bprop
  520. @bprop_getters.register(P.ReverseV2)
  521. def get_bprop_reverse_v2(self):
  522. """Generate bprop for ReverseV2"""
  523. axis = self.axis
  524. def bprop(x, out, dout):
  525. reverse_grad = P.ReverseV2(axis)
  526. dx = reverse_grad(dout)
  527. return (dx,)
  528. return bprop
  529. @bprop_getters.register(P.Unstack)
  530. def get_bprop_unstack(self):
  531. """Generate bprop for Unstack"""
  532. axis = self.axis
  533. def bprop(x, out, dout):
  534. unstack_grad = P.Stack(axis)
  535. out = unstack_grad(dout)
  536. return (out,)
  537. return bprop
  538. @bprop_getters.register(P.StridedSlice)
  539. def get_bprop_strided_slice(self):
  540. """Generate bprop for StridedSlice"""
  541. input_grad = G.StridedSliceGrad(self.begin_mask,
  542. self.end_mask,
  543. self.ellipsis_mask,
  544. self.new_axis_mask,
  545. self.shrink_axis_mask)
  546. def bprop(x, begin, end, strides, out, dout):
  547. x_shape = shape_op(x)
  548. if -1 in x_shape:
  549. x_shape = dyn_shape_op(x)
  550. dx = input_grad(dout, x_shape, begin, end, strides)
  551. return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
  552. return bprop
  553. @bprop_getters.register(P.Eye)
  554. def get_bprop_eye(self):
  555. """Generate bprop for Eye"""
  556. def bprop(n, m, t, out, dout):
  557. return zeros_like(n), zeros_like(m), zeros_like(t)
  558. return bprop
  559. @bprop_getters.register(P.Select)
  560. def get_bprop_select(self):
  561. """Generate bprop for Select"""
  562. select = P.Select()
  563. def bprop(cond, x, y, out, dout):
  564. return zeros_like(cond), select(cond, dout, zeros_like(x)), select(cond, zeros_like(y), dout)
  565. return bprop
  566. @bprop_getters.register(P.OnesLike)
  567. def get_bprop_oneslike(self):
  568. """Generate bprop for OnesLike"""
  569. def bprop(x, out, dout):
  570. return (zeros_like(x),)
  571. return bprop
  572. @bprop_getters.register(P.ZerosLike)
  573. def get_bprop_zeroslike(self):
  574. """Generate bprop for ZerosLike"""
  575. def bprop(x, out, dout):
  576. return (zeros_like(x),)
  577. return bprop
  578. @bprop_getters.register(P.ResizeNearestNeighbor)
  579. def get_bprop_resize_nearest_neighbor(self):
  580. """Generate bprop for ResizeNearestNeighbor"""
  581. op = G.ResizeNearestNeighborGrad(self.align_corners)
  582. def bprop(inputs, out, dout):
  583. shp = shape_op(inputs)
  584. # 2 and 3 represent the height and width
  585. shp = (shp[2], shp[3])
  586. return (op(dout, shp),)
  587. return bprop
  588. @bprop_getters.register(P.GatherNd)
  589. def get_bprop_gather_nd(self):
  590. """Generate bprop for GatherNd"""
  591. op = P.ScatterNd()
  592. def bprop(x, indices, out, dout):
  593. shp = shape_op(x)
  594. return op(indices, dout, shp), zeros_like(indices)
  595. return bprop
  596. @bprop_getters.register(P.ScatterNd)
  597. def get_bprop_scatter_nd(self):
  598. """Generate bprop for ScatterNd"""
  599. op = P.GatherNd()
  600. def bprop(indices, x, shape, out, dout):
  601. return zeros_like(indices), op(dout, indices), zeros_like(shape)
  602. return bprop
  603. @bprop_getters.register(P.ScatterNdUpdate)
  604. def get_bprop_scatter_nd_update(self):
  605. """Generate bprop for ScatterNdUpdate"""
  606. op = P.GatherNd()
  607. def bprop(x, indices, update, out, dout):
  608. return dout, zeros_like(indices), op(dout, indices)
  609. return bprop
  610. @bprop_getters.register(P.ScatterNonAliasingAdd)
  611. def get_bprop_scatter_non_aliasing_add_update(self):
  612. """Generate bprop for ScatterNonAliasingAdd"""
  613. op = P.GatherNd()
  614. def bprop(x, indices, update, out, dout):
  615. return dout, zeros_like(indices), op(dout, indices)
  616. return bprop
  617. @bprop_getters.register(P.TensorScatterUpdate)
  618. def get_bprop_tensor_scatter_update(self):
  619. """Generate bprop for TensorScatterUpdate"""
  620. gather_nd = P.GatherNd()
  621. tensor_scatter_update = P.TensorScatterUpdate()
  622. def bprop(x, indices, update, out, dout):
  623. x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
  624. update_grad = gather_nd(dout, indices)
  625. return x_grad, zeros_like(indices), update_grad
  626. return bprop
  627. @bprop_getters.register(P.TensorScatterAdd)
  628. def get_bprop_tensor_scatter_add(self):
  629. """Generate bprop for TensorScatterAdd"""
  630. gather_nd = P.GatherNd()
  631. def bprop(x, indices, update, out, dout):
  632. update_grad = gather_nd(dout, indices)
  633. return dout, zeros_like(indices), update_grad
  634. return bprop
  635. @bprop_getters.register(P.ScatterMax)
  636. def get_bprop_scatter_max(self):
  637. """Generate bprop for ScatterMax"""
  638. gather = P.Gather()
  639. def bprop(x, indices, update, out, dout):
  640. return dout, zeros_like(indices), gather(dout, indices, 0)
  641. return bprop
  642. @bprop_getters.register(P.Argmax)
  643. def get_bprop_argmax(self):
  644. """Generate bprop for Argmax"""
  645. def bprop(x, out, dout):
  646. return (zeros_like(x),)
  647. return bprop
  648. @bprop_getters.register(P.Argmin)
  649. def get_bprop_argmin(self):
  650. """Generate bprop for Argmin"""
  651. def bprop(x, out, dout):
  652. return (zeros_like(x),)
  653. return bprop
  654. @bprop_getters.register(P.SpaceToDepth)
  655. def get_bprop_space_to_depth(self):
  656. """Generate bprop for SpaceToDepth"""
  657. op = P.DepthToSpace(self.block_size)
  658. def bprop(x, out, dout):
  659. return (op(dout),)
  660. return bprop
  661. @bprop_getters.register(P.DepthToSpace)
  662. def get_bprop_depth_to_space(self):
  663. """Generate bprop for DepthToSpace"""
  664. op = P.SpaceToDepth(self.block_size)
  665. def bprop(x, out, dout):
  666. return (op(dout),)
  667. return bprop
  668. @bprop_getters.register(P.Diag)
  669. def get_bprop_diag(self):
  670. """Generate bprop for Diag"""
  671. op = P.DiagPart()
  672. def bprop(x, out, dout):
  673. return (op(dout),)
  674. return bprop
  675. @bprop_getters.register(P.DiagPart)
  676. def get_bprop_diag_part(self):
  677. """Generate bprop for DiagPart"""
  678. op = P.Diag()
  679. def bprop(x, out, dout):
  680. return (op(dout),)
  681. return bprop
def _gather_drop_negatives(params,
                           ids,
                           zero_clipped_indices=None,
                           is_positive=None):
    """Helper function for unsorted segment ops.

    Gathers rows of ``params`` along axis 0 at ``ids`` and zeroes the rows belonging to
    negative ids. For the gather itself, negative ids are clipped to 0.

    Args:
        params: Tensor to gather from along axis 0.
        ids: Segment ids; only read when ``zero_clipped_indices`` or ``is_positive`` is None.
        zero_clipped_indices: Precomputed ``max(ids, 0)``; derived from ``ids`` if None.
        is_positive: Precomputed mask already broadcast to the gathered shape (as returned
            by a previous call); derived from ``ids >= 0`` if None.

    Returns:
        tuple: ``(masked_gathered, zero_clipped_indices, is_positive)``. ``is_positive`` has
        the gathered tensor's shape, so it can be passed back in on a later call.
    """
    maximum = P.Maximum()
    gather = P.Gather()
    greater_equal = P.GreaterEqual()
    rank = P.Rank()
    fill = P.Fill()
    select = P.Select()
    if zero_clipped_indices is None:
        zero_clipped_indices = maximum(ids, zeros_like(ids))
    gathered = gather(params, zero_clipped_indices, 0)
    if is_positive is None:
        is_positive = greater_equal(ids, 0)
        # Append trailing 1-dims so the mask has the same rank as `gathered`.
        is_positive_shape = shape_op(is_positive)
        broadcastable_shape = is_positive_shape
        for _ in range(rank(gathered) - rank(is_positive)):
            broadcastable_shape += (1,)
        is_positive = reshape(is_positive, broadcastable_shape)
        gathered_shape = shape_op(gathered)
        # AND with an all-True tensor of gathered's shape to broadcast the mask to that shape,
        # as Select below takes the condition together with same-shaped operands.
        is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
    zero_slice = zeros_like(gathered)
    return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
def _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout):
    """Gradient for UnsortedSegmentMin or UnsortedSegmentMax

    Each segment's gradient is divided evenly among the elements of ``x`` equal to that
    segment's result (ties share it). All other elements, and elements with negative segment
    ids, get zero. ``segment_ids`` and ``num_segments`` receive zero gradients.
    """
    equal = P.Equal()
    cast = P.Cast()
    divide = P.RealDiv()
    get_dtype = P.DType()
    select = P.Select()
    # Broadcast each segment's result back to the elements of x (zeros for negative ids).
    gathered_outputs, zero_clipped_indices, is_positive = _gather_drop_negatives(out, segment_ids, None, None)
    # Elements equal to their segment's result are the ones that produced it.
    is_selected = equal(x, gathered_outputs)
    is_selected = logical_and(is_selected, is_positive)
    # Count the selected elements per segment so tied elements share the gradient evenly.
    num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
                                        segment_ids, num_segments)
    weighted_grads = divide(dout, num_selected)
    # Reuse the clipped indices and broadcast mask from the first gather.
    gathered_grads, _, _ = _gather_drop_negatives(weighted_grads, None,
                                                  zero_clipped_indices, is_positive)
    zeros = zeros_like(gathered_grads)
    return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
  724. @bprop_getters.register(P.UnsortedSegmentSum)
  725. def get_bprop_unsorted_segment_sum(self):
  726. """Generate bprop for UnsortedSegmentSum"""
  727. def bprop(x, segment_ids, num_segments, out, dout):
  728. return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_like(segment_ids), \
  729. zeros_like(num_segments)
  730. return bprop
  731. @bprop_getters.register(P.UnsortedSegmentMin)
  732. def get_bprop_unsorted_segment_min(self):
  733. """Generate bprop for UnsortedSegmentMin"""
  734. def bprop(x, segment_ids, num_segments, out, dout):
  735. return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
  736. return bprop
  737. @bprop_getters.register(P.UnsortedSegmentMax)
  738. def get_bprop_unsorted_segment_max(self):
  739. """Generate bprop for UnsortedSegmentMax"""
  740. def bprop(x, segment_ids, num_segments, out, dout):
  741. return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
  742. return bprop
  743. @bprop_getters.register(P.UnsortedSegmentProd)
  744. def get_bprop_unsorted_segment_prod(self):
  745. """Generate bprop for UnsortedSegmentProd"""
  746. equal = P.Equal()
  747. cast = P.Cast()
  748. select = P.Select()
  749. gather = P.Gather()
  750. greater = P.Greater()
  751. ones_like = P.OnesLike()
  752. maximum = P.Maximum()
  753. unsorted_segment_prod = P.UnsortedSegmentProd()
  754. def bprop(x, segment_ids, num_segments, out, dout):
  755. is_zero = equal(x, 0)
  756. num_zero = unsorted_segment_sum(cast(is_zero, mstype.int32), segment_ids, num_segments)
  757. grad = select(greater(num_zero, 1), zeros_like(dout), dout)
  758. non_zero_data = select(is_zero, ones_like(x), x)
  759. non_zero_prod = unsorted_segment_prod(non_zero_data, segment_ids, num_segments)
  760. zero_clipped_indices = maximum(segment_ids, zeros_like(segment_ids))
  761. gathered_prod = gather(out, zero_clipped_indices, 0)
  762. gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
  763. prod_divided_by_x = gathered_prod / x
  764. partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x)
  765. gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices, None)
  766. dx = gathered_grad * partial_derivative
  767. return dx, zeros_like(segment_ids), zeros_like(num_segments)
  768. return bprop
  769. @bprop_getters.register(P.SpaceToBatch)
  770. def get_bprop_space_to_batch(self):
  771. """Generate bprop for SpaceToBatch"""
  772. space_to_batch_grad = P.BatchToSpace(self.block_size, self.paddings)
  773. def bprop(x, out, dout):
  774. dx = space_to_batch_grad(dout)
  775. return (dx,)
  776. return bprop
  777. @bprop_getters.register(P.BatchToSpace)
  778. def get_bprop_batch_to_space(self):
  779. """Generate bprop for BatchToSpace"""
  780. batch_to_space_grad = P.SpaceToBatch(self.block_size, self.crops)
  781. def bprop(x, out, dout):
  782. dx = batch_to_space_grad(dout)
  783. return (dx,)
  784. return bprop
  785. @bprop_getters.register(P.SpaceToBatchND)
  786. def get_bprop_space_to_batch_nd(self):
  787. """Generate bprop for SpaceToBatchND"""
  788. space_to_batch_nd_grad = P.BatchToSpaceND(self.block_shape, self.paddings)
  789. def bprop(x, out, dout):
  790. dx = space_to_batch_nd_grad(dout)
  791. return (dx,)
  792. return bprop
  793. @bprop_getters.register(P.BatchToSpaceND)
  794. def get_bprop_batch_to_space_nd(self):
  795. """Generate bprop for BatchToSpaceND"""
  796. batch_to_space_nd_grad = P.SpaceToBatchND(self.block_shape, self.crops)
  797. def bprop(x, out, dout):
  798. dx = batch_to_space_nd_grad(dout)
  799. return (dx,)
  800. return bprop
  801. @bprop_getters.register(P.BroadcastTo)
  802. def get_bprop_broadcast_to(self):
  803. """Generate bprop for BroadcastTo"""
  804. reduce_keep_dim = P.ReduceSum(keep_dims=True)
  805. def bprop(x, out, dout):
  806. x_shape = shape_op(x)
  807. dout_shape = shape_op(dout)
  808. broadcast_shape = shape_op(out)
  809. if x_shape == dout_shape:
  810. return (dout,)
  811. _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
  812. reduced_grad = reduce_keep_dim(dout, reduction_axes)
  813. dx = reshape(reduced_grad, x_shape)
  814. return (dx,)
  815. return bprop
  816. @bprop_getters.register(P.ReverseSequence)
  817. def get_bprop_reverse_sequence(self):
  818. """Generate bprop for ReverseSequence"""
  819. reverse_sequence_grad = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)
  820. def bprop(x, seq_lengths, out, dout):
  821. dx = reverse_sequence_grad(dout, seq_lengths)
  822. return dx, zeros_like(seq_lengths)
  823. return bprop
  824. @bprop_getters.register(P.TransShape)
  825. def get_bprop_trans_shape(self):
  826. """Generate bprop for TransShape"""
  827. op = P.TransShape()
  828. def bprop(x, shape, out, dout):
  829. dx = op(dout, shape_op(x))
  830. return (dx, zeros_like(shape))
  831. return bprop
  832. @bprop_getters.register(P.Unique)
  833. def get_bprop_unique(self):
  834. """Generate bprop for Unique"""
  835. op = G.UniqueGrad()
  836. def bprop(x, out, dout):
  837. dx = op(dout, out)
  838. return (dx,)
  839. return bprop
  840. @bprop_getters.register(P.MaskedSelect)
  841. def get_bprop_masked_select(self):
  842. """Generate bprop for MaskedSelect"""
  843. op = G.MaskedSelectGrad()
  844. def bprop(x, mask, out, dout):
  845. dx = op(x, mask, dout)
  846. return (dx, zeros_like(mask))
  847. return bprop