You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_array_ops.py 29 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """array_ops"""
  16. import numpy as np
  17. import mindspore as ms
  18. from mindspore.ops import composite as C
  19. from mindspore.common.tensor import Tensor
  20. from .. import operations as P
  21. from ..operations import _grad_ops as G
  22. from ..operations import _inner_ops as inner
  23. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  24. from ..functional import broadcast_gradient_args
  25. from .. import functional as F
  26. from .grad_base import bprop_getters
  27. from ..primitive import constexpr
  28. from ... import context
  29. from ...common import dtype as mstype
  30. from ...common.tensor import RowTensor
# Primitive instances shared by the bprop definitions in this module. They are
# created once at import time and reused inside the bprop closures below.
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
# Used on the dynamic-shape path of the Gather/GatherV2 bprop.
dyn_shape_op = P.DynamicShape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
is_sub_class = P.IsSubClass()
@bprop_getters.register(P.Fill)
def get_bprop_fill(self):
    """Generate bprop for Fill.

    Returns:
        ``bprop(dtype, dims, x, out, dout)`` returning zero gradients for
        ``dims`` and ``x``; no gradient entry is returned for ``dtype``.
    """
    def bprop(dtype, dims, x, out, dout):
        # NOTE(review): Fill is treated as non-differentiable here (the fill
        # value gets a zero gradient rather than a reduction of dout) -- confirm
        # this is intended.
        return zeros_like(dims), zeros_like(x)
    return bprop
@bprop_getters.register(P.Ones)
def get_bprop_ones(self):
    """Generate bprop for Ones.

    Returns:
        ``bprop(dims, dtype, out, dout)`` returning a zero gradient for
        ``dims`` only; no gradient entry is returned for ``dtype``.
    """
    def bprop(dims, dtype, out, dout):
        return zeros_like(dims)
    return bprop
@bprop_getters.register(P.Zeros)
def get_bprop_zeros(self):
    """Generate bprop for Zeros.

    Returns:
        ``bprop(dims, dtype, out, dout)`` returning a zero gradient for
        ``dims`` only; no gradient entry is returned for ``dtype``.
    """
    def bprop(dims, dtype, out, dout):
        return zeros_like(dims)
    return bprop
@bprop_getters.register(P.DType)
def get_bprop_dtype(self):
    """Generate bprop for DType.

    Returns:
        ``bprop(x, out, dout)`` returning a one-element tuple holding a zero
        gradient for ``x``.
    """
    def bprop(x, out, dout):
        return (zeros_like(x),)
    return bprop
# Multitype dispatcher used by the sparse Cast bprop: it casts the incoming
# gradient `dout` to the dtype of the forward input `x`. Overloads are
# registered below for dense tensors, numbers and RowTensor gradients.
dout_cast = C.MultitypeFuncGraph("dout_cast")


@dout_cast.register("Tensor", "Tensor")
def dout_cast_tensor(dout, x):
    """Cast a dense tensor gradient ``dout`` to the dtype of ``x``."""
    cast = P.Cast()
    get_dtype = P.DType()
    dx = cast(dout, get_dtype(x))
    return dx


@dout_cast.register("Number", "Number")
def dout_cast_number(dout, x):
    """Cast a scalar gradient ``dout`` to the dtype of ``x``."""
    cast = P.Cast()
    get_dtype = P.DType()
    dx = cast(dout, get_dtype(x))
    return dx


@dout_cast.register("RowTensor", "Tensor")
def dout_cast_row_tensor(dout, x):
    """Cast the values of a RowTensor gradient to the dtype of ``x``.

    The indices and dense shape of ``dout`` are passed through unchanged.
    """
    cast = P.Cast()
    get_dtype = P.DType()
    values = cast(dout.values, get_dtype(x))
    return RowTensor(dout.indices, values, dout.dense_shape)
@bprop_getters.register(P.Cast)
def get_bprop_cast(self):
    """Generate bprop for Cast.

    The gradient is cast back to the dtype of the forward input ``x``. The
    target-type argument ``t`` gets a zero gradient.

    Returns:
        ``bprop_sparse`` when the ``enable_sparse`` context flag is set at
        the time the bprop is generated. It dispatches through ``dout_cast``
        and so also accepts RowTensor gradients. Otherwise, returns the dense
        ``bprop``.
    """
    cast = P.Cast()
    get_dtype = P.DType()
    def bprop(x, t, out, dout):
        dx = cast(dout, get_dtype(x))
        return dx, zeros_like(t)
    def bprop_sparse(x, t, out, dout):
        dx = dout_cast(dout, x)
        return dx, zeros_like(t)
    if context.get_context('enable_sparse'):
        return bprop_sparse
    return bprop
@bprop_getters.register(P.Shape)
def get_bprop_shape(self):
    """Generate bprop for Shape.

    Returns:
        ``bprop(x, out, dout)`` returning a one-element tuple holding a zero
        gradient for ``x``.
    """
    def bprop(x, out, dout):
        return (zeros_like(x),)
    return bprop
  104. @bprop_getters.register(P.Split)
  105. def get_bprop_split(self):
  106. """Generate bprop for Split"""
  107. axis = self.axis
  108. def bprop(x, out, dout):
  109. concat_op = P.Concat(axis)
  110. dx = concat_op(dout)
  111. return (dx,)
  112. return bprop
@bprop_getters.register(P.Rank)
def get_bprop_rank(self):
    """Generate bprop for Rank.

    Returns:
        ``bprop(x, out, dout)`` returning a one-element tuple holding a zero
        gradient for ``x``.
    """
    def bprop(x, out, dout):
        return (zeros_like(x),)
    return bprop
  119. @bprop_getters.register(P.Reshape)
  120. def get_bprop_reshape(self):
  121. """Generate bprop for Reshape"""
  122. def bprop(x, shp, out, dout):
  123. shapex = shape_op(x)
  124. return reshape(dout, shapex), zeros_like(shp)
  125. return bprop
  126. @bprop_getters.register(P.ExpandDims)
  127. def get_bprop_expand_dims(self):
  128. """Generate bprop for ExpandDims"""
  129. def bprop(x, axis, out, dout):
  130. shapex = shape_op(x)
  131. return reshape(dout, shapex), zeros_like(axis)
  132. return bprop
  133. @bprop_getters.register(P.Squeeze)
  134. def get_bprop_squeeze(self):
  135. """Generate bprop for Squeeze"""
  136. def bprop(x, out, dout):
  137. shapex = shape_op(x)
  138. return (reshape(dout, shapex),)
  139. return bprop
@bprop_getters.register(P.Flatten)
def get_bprop_flatten(self):
    """Generate bprop for Flatten.

    Returns:
        ``bprop(x, out, dout)`` that maps ``dout`` back to the shape of ``x``
        via ``FlattenGrad``.
    """
    flatten_grad = G.FlattenGrad()
    def bprop(x, out, dout):
        dx = flatten_grad(dout, shape_op(x))
        return (dx,)
    return bprop
  148. @constexpr
  149. def _tile_shape(multiples, shapex):
  150. """Calculate [1,2], [3, 4] -> [1,3,2,4]."""
  151. len_muli = len(multiples)
  152. rank = len(shapex)
  153. len_cmp = len_muli - rank
  154. max_len = max(len_muli, rank)
  155. i = 0
  156. j = 0
  157. ret = []
  158. while (i < max_len) and (j < max_len):
  159. if len_cmp == 0:
  160. ret.append(multiples[i])
  161. ret.append(shapex[j])
  162. i += 1
  163. j += 1
  164. elif len_cmp > 0:
  165. ret.append(multiples[i])
  166. ret.append(1)
  167. i += 1
  168. len_cmp -= 1
  169. else:
  170. ret.append(1)
  171. ret.append(shapex[j])
  172. len_cmp += 1
  173. return tuple(ret)
@bprop_getters.register(P.Tile)
def get_bprop_tile(self):
    """Generate bprop for Tile.

    ``dout`` is reshaped to the interleaved ``(multiple, dim, ...)`` shape
    built by ``_tile_shape``. It is then summed over the even positions (the
    multiples), which accumulates every tiled copy back onto ``x``.
    ``multiples`` gets a zero gradient.
    """
    def bprop(x, multiples, out, dout):
        shapex = shape_op(x)
        r_shape = _tile_shape(multiples, shapex)
        # 0 represents the start index, and 2 represents the step, i.e. the
        # even positions of r_shape, which hold the multiples.
        axis = F.make_range(0, len(r_shape), 2)
        dx = reduce_sum(reshape(dout, r_shape), axis)
        dx = reshape(dx, shapex)
        return dx, zeros_like(multiples)
    return bprop
@bprop_getters.register(P.EmbeddingLookup)
def get_bprop_embedding_lookup(self):
    """Generate bprop for EmbeddingLookup.

    The gradient w.r.t. ``x`` is a sparse RowTensor with these parts:

    * row indices: ``indices - offset``, flattened to 1-D;
    * values: ``dout`` reshaped to ``(num_indices,) + x.shape[1:]``;
    * dense shape: the shape of ``x``.

    ``indices`` and ``offset`` get zero gradients.
    """
    sub_op = P.Sub()
    reshape_op = P.Reshape()
    def bprop_sparse(x, indices, offset, out, dout):
        x_shp = shape_op(x)
        new_indices = sub_op(indices, offset)
        # Reshape the 'new_indices'
        new_indices_shape_changed = (size_op(new_indices),)
        new_indices = reshape_op(new_indices, new_indices_shape_changed)
        x_shp_tail = x_shp[1:]
        actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
        # Reshape the 'actual_dout' on device
        actual_dout = reshape_op(dout, actual_dout_shape_changed)
        return RowTensor(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
    return bprop_sparse
  203. @constexpr
  204. def make_begin(shp):
  205. begin = tuple([0 for _ in shp])
  206. return begin
  207. @bprop_getters.register(P.Padding)
  208. def get_bprop_padding(self):
  209. """Grad definition for `Padding` operation."""
  210. def bprop(x, out, dout):
  211. shp = shape_op(x)
  212. begin = make_begin(shp)
  213. dx = P.Slice()(dout, begin, shp)
  214. return (dx,)
  215. return bprop
@bprop_getters.register(P.Transpose)
def get_bprop_transpose(self):
    """Generate bprop for Transpose.

    ``dout`` is transposed with the inverse of ``perm`` so that it returns to
    the layout of ``x``. ``perm`` gets a zero gradient.
    """
    def bprop(x, perm, out, dout):
        return transpose(dout, invert_permutation(perm)), zeros_like(perm)
    return bprop
  222. @constexpr
  223. def _concat_grad_uniform(input_shapes, input_nums):
  224. """Helper function for bprop of Concat"""
  225. is_uniform = True
  226. for i in range(1, input_nums):
  227. if input_shapes[i - 1] != input_shapes[i]:
  228. is_uniform = False
  229. break
  230. return is_uniform
@bprop_getters.register(P.Concat)
def get_bprop_concat(self):
    """Generate bprop for Concat.

    Splits ``dout`` back into one gradient per input along ``self.axis``:

    * if all input shapes are equal, an even ``Split`` is used;
    * otherwise, one ``Slice`` per input is taken at the offsets computed
      by ``ConcatOffset``.
    """
    axis = self.axis
    def bprop(x, out, dout):
        dx = ()
        # Start offset of each input inside the concatenated output; only used
        # on the non-uniform path below.
        out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
        input_nums = F.tuple_len(x)
        input_shapes = ()
        for i in range(input_nums):
            input_shapes = input_shapes + (shape_op(x[i]),)
        is_uniform = _concat_grad_uniform(input_shapes, input_nums)
        if is_uniform:
            dx = P.Split(axis, input_nums)(dout)
        else:
            for i in range(input_nums):
                slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
                dx = dx + (slice_out,)
        # `x` is a single tuple argument, so its gradient is the tuple `dx`.
        return (dx,)
    return bprop
  251. @constexpr
  252. def _slice_grad_pad(begins, sizes, shapes):
  253. pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
  254. return pads
@bprop_getters.register(P.Slice)
def get_bprop_slice(self):
    """Generate bprop for Slice.

    The gradient w.r.t. ``x`` is computed by ``SliceGrad`` from ``dout``,
    ``x``, ``begin`` and ``size``. ``begin`` and ``size`` get zero gradients.
    """
    def bprop(x, begin, size, out, dout):
        dx = G.SliceGrad()(dout, x, begin, size)
        return (dx, zeros_like(begin), zeros_like(size))
    return bprop
  262. @constexpr
  263. def _generate_shape_index(out_shape, indices_shape, axis):
  264. out_rank = len(out_shape)
  265. ind_rank = len(indices_shape)
  266. if axis < 0:
  267. axis += out_rank - ind_rank + 1
  268. perm_part1 = tuple(range(axis, axis + ind_rank))
  269. index = tuple(range(out_rank))
  270. perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
  271. return perm
  272. @constexpr
  273. def _generate_inverse_index(x_shape, axis):
  274. x_rank = len(x_shape)
  275. index = tuple(range(x_rank))
  276. if axis < 0:
  277. axis += x_rank
  278. perm = index[1:1 + axis] + (0,) + index[1 + axis:]
  279. return perm
  280. @constexpr
  281. def _regenerate_output_shape(x_shp, ind_shp, axis):
  282. rank = len(x_shp)
  283. if axis < 0:
  284. axis += rank
  285. out_shape = x_shp[:axis] + ind_shp + x_shp[axis + 1:]
  286. return out_shape
@bprop_getters.register(P.Gather)
@bprop_getters.register(P.GatherV2)
def get_bprop_gather_v2(self):
    """Generate bprop for Gather / GatherV2.

    The slices of ``dout`` selected by ``indices`` are accumulated back onto
    ``x`` in three steps:

    1. the index dims of ``dout`` are transposed to the front;
    2. they are summed per index with ``unsorted_segment_sum``;
    3. the result is transposed back to the layout of ``x``.

    ``indices`` and ``axis`` get zero gradients.
    """
    def bprop(x, indices, axis, out, dout):
        # Keep the caller's indices so their zero gradient has the original
        # (possibly scalar) shape.
        orig_indices = indices
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        x_shp = shape_op(x)
        ind_shp = shape_op(indices)
        # Re-insert the gathered axis (dropped when indices were scalar) so dout
        # has shape x.shape[:axis] + indices.shape + x.shape[axis + 1:].
        out_shp = _regenerate_output_shape(x_shp, ind_shp, axis)
        dout = reshape(dout, out_shp)
        # NOTE(review): x_shp and ind_shp are recomputed unchanged here, and
        # out_shp is re-read from the just-reshaped dout.
        x_shp = shape_op(x)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        # Presumably a -1 entry in the static shape marks a dynamic dimension;
        # in that case the segment count is read from DynamicShape instead.
        if -1 in shape_op(x):
            params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis])
        else:
            params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(orig_indices), zeros_like(axis)
    return bprop
@bprop_getters.register(P.GatherD)
def get_bprop_gather_d(self):
    """Generate bprop for GatherD.

    The gradient w.r.t. ``x`` is built by ``GatherDGrad(dim, x.shape)`` from
    ``index`` and ``dout``. ``dim`` and ``index`` get zero gradients.
    """
    def bprop(x, dim, index, out, dout):
        x_shp = shape_op(x)
        dx = G.GatherDGrad(dim, x_shp)(index, dout)
        return dx, zeros_like(dim), zeros_like(index)
    return bprop
  324. @bprop_getters.register(P.SparseGatherV2)
  325. def get_bprop_sparse_gather_v2(self):
  326. """Generate bprop for SparseGatherV2"""
  327. def bprop(x, indices, axis, out, dout):
  328. x_shp = shape_op(x)
  329. if axis == 0:
  330. indices_size = (size_op(indices),)
  331. if len(x_shp) <= 1:
  332. x_tail_shp = ()
  333. else:
  334. x_tail_shp = x_shp[1:]
  335. values_shape = indices_size + x_tail_shp
  336. values = reshape(dout, values_shape)
  337. indices_new = reshape(indices, indices_size)
  338. return RowTensor(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis)
  339. if F.rank(dout) == 0:
  340. dout = P.ExpandDims()(dout, -1)
  341. if F.rank(indices) == 0:
  342. indices = P.ExpandDims()(indices, -1)
  343. out_shp = shape_op(dout)
  344. ind_shp = shape_op(indices)
  345. # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
  346. perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
  347. values_transpose = transpose(dout, perm_1)
  348. params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
  349. # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
  350. perm_2 = _generate_inverse_index(x_shp, axis)
  351. params_grad = transpose(params_grad, perm_2)
  352. return params_grad, zeros_like(indices), zeros_like(axis)
  353. return bprop
  354. @constexpr
  355. def _range_op(start, limit, delta, dtype):
  356. """helper function for grad of Sort"""
  357. output_tensor = Tensor(list(range(start, limit, delta)), dtype)
  358. return output_tensor
  359. @constexpr
  360. def _get_1d_shape(in_shape):
  361. """helper function for grad of Sort"""
  362. out_shape = 1
  363. for i in in_shape:
  364. out_shape *= i
  365. return (out_shape,)
  366. @constexpr
  367. def _get_transposition(axis, rank):
  368. """helper function for grad of Sort"""
  369. if axis < 0:
  370. axis += rank
  371. transposition = np.r_[np.arange(axis), [rank - 1], np.arange(axis + 1, rank - 1), [axis]]
  372. trans = tuple(transposition.tolist())
  373. return trans
@bprop_getters.register(P.Sort)
def get_bprop_sort(self):
    """Grad definition for `Sort` operation.

    The sort permutation is recomputed with ``TopK`` (k = size of the sorted
    axis), and the value gradient ``dout[0]`` is scattered back to the
    positions the values came from. For an ascending sort
    (``descending=False``), the input and gradient are negated before TopK
    and the result is negated back.
    """
    axis = self.axis
    descending = self.descending
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    reshape_op = P.Reshape()
    dtype = P.DType()
    topk = P.TopK()
    neg = P.Neg()
    tranpose = P.Transpose()
    def bprop(input_x, out, dout):
        x_shape = input_x.shape
        k = x_shape[axis]
        rank = F.rank(input_x)
        # Only dout[0] (presumably the gradient of the sorted values) is used.
        dvalue = dout[0]
        if not descending:
            input_x = neg(input_x)
            dvalue = neg(dvalue)
        # TopK is applied over the last dim, so a non-last sort axis is first
        # swapped to the end (and swapped back at the end).
        if axis == -1 or (axis + 1) == rank:
            transposition = None
            top_k_input = input_x
        else:
            transposition = _get_transposition(axis, rank)
            top_k_input = tranpose(input_x, transposition)
        _, indices = topk(top_k_input, k)
        ind_shape = indices.shape
        top_k_input_shape = top_k_input.shape
        in_lastdim = top_k_input_shape[-1]
        ind_lastdim = ind_shape[-1]
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outer_dim = ind_2d.shape[0]
        # Flat start offset of each row: [0, in_lastdim, 2*in_lastdim, ...,
        # (outer_dim-1)*in_lastdim]
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outer_dim * in_lastdim, in_lastdim, indices_dtype)
        # expand_dims to (outer_dim, 1), then broadcast against ind_2d to get
        # flat positions into the 1-D view of top_k_input
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        x_shape_1d = _get_1d_shape(top_k_input_shape)
        if transposition is not None:
            dvalue = tranpose(dvalue, invert_permutation(transposition))
            out_grad = reshape_op(
                scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
            dx = tranpose(out_grad, invert_permutation(transposition))
        else:
            dx = reshape_op(scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
        if not descending:
            dx = neg(dx)
        return (dx,)
    return bprop
@bprop_getters.register(P.Identity)
def get_bprop_identity(self):
    """Generate bprop for Identity.

    Returns:
        ``bprop(x, out, dout)`` that passes ``dout`` through unchanged.
    """
    def bprop(x, out, dout):
        return (dout,)
    return bprop
@bprop_getters.register(inner.Range)
def get_bprop_range(self):
    """Generate bprop for the inner Range primitive.

    Returns:
        ``bprop(x, out, dout)`` returning a zero gradient for ``x``.
    """
    def bprop(x, out, dout):
        return (zeros_like(x),)
    return bprop
@bprop_getters.register(P.Pack)
@bprop_getters.register(P.Stack)
def get_bprop_stack(self):
    """Generate bprop for Stack (also registered for Pack).

    ``dout`` is unstacked along ``self.axis`` into one gradient per stacked
    input. If the forward input ``x`` was a list, the gradients are returned
    as a list; otherwise they are returned as the tuple produced by Unstack.
    """
    axis = self.axis
    def bprop(x, out, dout):
        stack_grad = P.Unstack(axis)
        # NOTE: rebinds the `out` parameter to the unstacked gradients.
        out = stack_grad(dout)
        # Match the container type of the forward input.
        if is_sub_class(F.typeof(x), ms.list_):
            ret = []
            for item in out:
                ret.append(item)
            return (ret,)
        return (out,)
    return bprop
  451. @bprop_getters.register(P.ReverseV2)
  452. def get_bprop_reverse_v2(self):
  453. """Generate bprop for ReverseV2"""
  454. axis = self.axis
  455. def bprop(x, out, dout):
  456. reverse_grad = P.ReverseV2(axis)
  457. dx = reverse_grad(dout)
  458. return (dx,)
  459. return bprop
  460. @bprop_getters.register(P.Unstack)
  461. def get_bprop_unstack(self):
  462. """Generate bprop for Unstack"""
  463. axis = self.axis
  464. def bprop(x, out, dout):
  465. unstack_grad = P.Stack(axis)
  466. out = unstack_grad(dout)
  467. return (out,)
  468. return bprop
@bprop_getters.register(P.StridedSlice)
def get_bprop_strided_slice(self):
    """Generate bprop for StridedSlice.

    A ``StridedSliceGrad`` is configured with the same five masks as the
    forward primitive. It maps ``dout`` back to the shape of ``x``.
    ``begin``, ``end`` and ``strides`` get zero gradients.
    """
    input_grad = G.StridedSliceGrad(self.begin_mask,
                                    self.end_mask,
                                    self.ellipsis_mask,
                                    self.new_axis_mask,
                                    self.shrink_axis_mask)
    def bprop(x, begin, end, strides, out, dout):
        dx = input_grad(dout, shape_op(x), begin, end, strides)
        return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
    return bprop
@bprop_getters.register(P.Eye)
def get_bprop_eye(self):
    """Generate bprop for Eye.

    Returns:
        ``bprop(n, m, t, out, dout)`` returning zero gradients for all three
        inputs.
    """
    def bprop(n, m, t, out, dout):
        return zeros_like(n), zeros_like(m), zeros_like(t)
    return bprop
@bprop_getters.register(P.Select)
def get_bprop_select(self):
    """Generate bprop for Select.

    ``dout`` is routed to ``x`` where ``cond`` is true and to ``y`` where it
    is false. At the other positions each operand gets zeros. ``cond`` gets a
    zero gradient.
    """
    select = P.Select()
    def bprop(cond, x, y, out, dout):
        return zeros_like(cond), select(cond, dout, zeros_like(x)), select(cond, zeros_like(y), dout)
    return bprop
@bprop_getters.register(P.OnesLike)
def get_bprop_oneslike(self):
    """Generate bprop for OnesLike.

    Returns:
        ``bprop(x, out, dout)`` returning a zero gradient for ``x``.
    """
    def bprop(x, out, dout):
        return (zeros_like(x),)
    return bprop
  500. @bprop_getters.register(P.ZerosLike)
  501. def get_bprop_zeroslike(self):
  502. """Generate bprop for OnesLike"""
  503. def bprop(x, out, dout):
  504. return (zeros_like(x),)
  505. return bprop
@bprop_getters.register(P.ResizeNearestNeighbor)
def get_bprop_resize_nearest_neighbor(self):
    """Generate bprop for ResizeNearestNeighbor.

    ``ResizeNearestNeighborGrad`` (with the forward ``align_corners``) maps
    ``dout`` back to the spatial size of ``inputs``.
    """
    op = G.ResizeNearestNeighborGrad(self.align_corners)
    def bprop(inputs, out, dout):
        shp = shape_op(inputs)
        # 2 and 3 represent the height and width
        # (assumes a 4-D input with spatial dims last -- TODO confirm layout)
        shp = (shp[2], shp[3])
        return (op(dout, shp),)
    return bprop
@bprop_getters.register(P.GatherNd)
def get_bprop_gather_nd(self):
    """Generate bprop for GatherNd.

    ``dout`` is scattered with ``ScatterNd`` into a tensor of ``x``'s shape
    at ``indices``. ``indices`` gets a zero gradient.
    """
    op = P.ScatterNd()
    def bprop(x, indices, out, dout):
        shp = shape_op(x)
        return op(indices, dout, shp), zeros_like(indices)
    return bprop
@bprop_getters.register(P.ScatterNd)
def get_bprop_scatter_nd(self):
    """Generate bprop for ScatterNd.

    The gradient for the scattered values ``x`` is ``dout`` gathered at
    ``indices`` with ``GatherNd``. ``indices`` and ``shape`` get zero
    gradients.
    """
    op = P.GatherNd()
    def bprop(indices, x, shape, out, dout):
        return zeros_like(indices), op(dout, indices), zeros_like(shape)
    return bprop
@bprop_getters.register(P.ScatterNdUpdate)
def get_bprop_scatter_nd_update(self):
    """Generate bprop for ScatterNdUpdate.

    ``x`` receives ``dout`` unchanged, and ``update`` receives ``dout``
    gathered at ``indices``. ``indices`` gets a zero gradient.
    """
    op = P.GatherNd()
    def bprop(x, indices, update, out, dout):
        # NOTE(review): positions overwritten by `update` still pass dout to x
        # here, whereas the TensorScatterUpdate bprop zeros them -- confirm
        # this is intended for this op.
        return dout, zeros_like(indices), op(dout, indices)
    return bprop
@bprop_getters.register(P.ScatterNonAliasingAdd)
def get_bprop_scatter_non_aliasing_add_update(self):
    """Generate bprop for ScatterNonAliasingAdd.

    ``x`` receives ``dout`` unchanged, and ``update`` receives ``dout``
    gathered at ``indices``. ``indices`` gets a zero gradient.
    """
    op = P.GatherNd()
    def bprop(x, indices, update, out, dout):
        return dout, zeros_like(indices), op(dout, indices)
    return bprop
@bprop_getters.register(P.TensorScatterUpdate)
def get_bprop_tensor_scatter_update(self):
    """Generate bprop for TensorScatterUpdate.

    ``x`` receives ``dout`` with the updated positions zeroed; ``update``
    receives ``dout`` gathered at ``indices``. ``indices`` gets a zero
    gradient.
    """
    gather_nd = P.GatherNd()
    tensor_scatter_update = P.TensorScatterUpdate()
    def bprop(x, indices, update, out, dout):
        # Positions overwritten in the forward pass contribute nothing to x.
        x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
        update_grad = gather_nd(dout, indices)
        return x_grad, zeros_like(indices), update_grad
    return bprop
@bprop_getters.register(P.ScatterMax)
def get_bprop_scatter_max(self):
    """Generate bprop for ScatterMax.

    ``x`` receives ``dout`` unchanged, and ``update`` receives ``dout``
    gathered at ``indices`` along axis 0. ``indices`` gets a zero gradient.
    """
    gather = P.Gather()
    def bprop(x, indices, update, out, dout):
        # NOTE(review): no masking by which operand actually produced the max;
        # confirm this approximation is intended.
        return dout, zeros_like(indices), gather(dout, indices, 0)
    return bprop
@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
    """Generate bprop for Argmax.

    Returns:
        ``bprop(x, out, dout)`` returning a zero gradient for ``x``.
    """
    def bprop(x, out, dout):
        return (zeros_like(x),)
    return bprop
@bprop_getters.register(P.Argmin)
def get_bprop_argmin(self):
    """Generate bprop for Argmin.

    Returns:
        ``bprop(x, out, dout)`` returning a zero gradient for ``x``.
    """
    def bprop(x, out, dout):
        return (zeros_like(x),)
    return bprop
@bprop_getters.register(P.SpaceToDepth)
def get_bprop_space_to_depth(self):
    """Generate bprop for SpaceToDepth.

    ``dout`` is passed through ``DepthToSpace`` with the same
    ``block_size``.
    """
    op = P.DepthToSpace(self.block_size)
    def bprop(x, out, dout):
        return (op(dout),)
    return bprop
@bprop_getters.register(P.DepthToSpace)
def get_bprop_depth_to_space(self):
    """Generate bprop for DepthToSpace.

    ``dout`` is passed through ``SpaceToDepth`` with the same
    ``block_size``.
    """
    op = P.SpaceToDepth(self.block_size)
    def bprop(x, out, dout):
        return (op(dout),)
    return bprop
@bprop_getters.register(P.Diag)
def get_bprop_diag(self):
    """Generate bprop for Diag.

    The gradient w.r.t. ``x`` is ``DiagPart(dout)``.
    """
    op = P.DiagPart()
    def bprop(x, out, dout):
        return (op(dout),)
    return bprop
@bprop_getters.register(P.DiagPart)
def get_bprop_diag_part(self):
    """Generate bprop for DiagPart.

    The gradient w.r.t. ``x`` is ``Diag(dout)``.
    """
    op = P.Diag()
    def bprop(x, out, dout):
        return (op(dout),)
    return bprop
def _gather_drop_negatives(params,
                           ids,
                           zero_clipped_indices=None,
                           is_positive=None):
    """Helper function for unsorted segment ops.

    Gathers rows of ``params`` along axis 0 at ``ids``. Rows that belong to
    negative ids are replaced by zeros.

    Args:
        params: Tensor to gather from along axis 0.
        ids: Segment ids. May be None when both optional arguments are given,
            because it is then never read.
        zero_clipped_indices: Precomputed ``maximum(ids, 0)``; computed from
            ``ids`` when None.
        is_positive: Precomputed ``ids >= 0`` mask (possibly already broadcast
            by an earlier call); computed from ``ids`` when None.

    Returns:
        tuple: ``(gathered_or_zero, zero_clipped_indices, is_positive)``. Here
        ``is_positive`` has been broadcast to the shape of the gathered
        tensor, so it can be reused by a later call.
    """
    maximum = P.Maximum()
    gather = P.Gather()
    greater_equal = P.GreaterEqual()
    rank = P.Rank()
    fill = P.Fill()
    select = P.Select()
    if zero_clipped_indices is None:
        # Clip negative ids to 0 so the gather itself stays in range.
        zero_clipped_indices = maximum(ids, zeros_like(ids))
    gathered = gather(params, zero_clipped_indices, 0)
    if is_positive is None:
        # Despite the name, this is a non-negative (>= 0) mask.
        is_positive = greater_equal(ids, 0)
    is_positive_shape = shape_op(is_positive)
    broadcastable_shape = is_positive_shape
    # Append trailing 1-dims until the mask has the rank of `gathered`.
    for _ in range(rank(gathered) - rank(is_positive)):
        broadcastable_shape += (1,)
    is_positive = reshape(is_positive, broadcastable_shape)
    gathered_shape = shape_op(gathered)
    # AND with an all-true tensor to materialize the mask at gathered's shape.
    is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
    zero_slice = zeros_like(gathered)
    return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
def _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout):
    """Gradient for UnsortedSegmentMin or UnsortedSegmentMax.

    An element of ``x`` is selected when it equals its segment's output and
    has a non-negative segment id. Each selected element receives its
    segment's ``dout`` divided by the number of selected (tied) elements in
    that segment. All other elements get zero.

    Returns:
        tuple: ``(dx, zeros_like(segment_ids), zeros_like(num_segments))``.
    """
    equal = P.Equal()
    cast = P.Cast()
    divide = P.RealDiv()
    get_dtype = P.DType()
    select = P.Select()
    # Broadcast each segment's result back to its members.
    gathered_outputs, zero_clipped_indices, is_positive = _gather_drop_negatives(out, segment_ids, None, None)
    is_selected = equal(x, gathered_outputs)
    is_selected = logical_and(is_selected, is_positive)
    # Count ties per segment so the gradient is split evenly among them.
    num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
                                        segment_ids, num_segments)
    weighted_grads = divide(dout, num_selected)
    gathered_grads, _, _ = _gather_drop_negatives(weighted_grads, None,
                                                  zero_clipped_indices, is_positive)
    zeros = zeros_like(gathered_grads)
    return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
    """Generate bprop for UnsortedSegmentSum.

    Each element of ``x`` receives the ``dout`` row of its segment, or zero
    when its segment id is negative. ``segment_ids`` and ``num_segments``
    get zero gradients.
    """
    def bprop(x, segment_ids, num_segments, out, dout):
        return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_like(segment_ids), \
               zeros_like(num_segments)
    return bprop
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
    """Generate bprop for UnsortedSegmentMin.

    Delegates to ``_unsorted_segment_min_or_max_grad``.
    """
    def bprop(x, segment_ids, num_segments, out, dout):
        return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
    return bprop
@bprop_getters.register(P.UnsortedSegmentMax)
def get_bprop_unsorted_segment_max(self):
    """Generate bprop for UnsortedSegmentMax.

    Delegates to ``_unsorted_segment_min_or_max_grad``.
    """
    def bprop(x, segment_ids, num_segments, out, dout):
        return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
    return bprop
@bprop_getters.register(P.UnsortedSegmentProd)
def get_bprop_unsorted_segment_prod(self):
    """Generate bprop for UnsortedSegmentProd.

    For an element ``x_i`` in segment ``s``, the gradient is ``dout[s]``
    times a partial derivative chosen as follows:

    * ``x_i`` is zero: the product of the non-zero elements of the segment;
    * ``x_i`` is non-zero: ``prod_s / x_i``.

    Segments with more than one zero get a zero gradient throughout.
    Negative segment ids get zero. ``segment_ids`` and ``num_segments`` get
    zero gradients.
    """
    equal = P.Equal()
    cast = P.Cast()
    select = P.Select()
    gather = P.Gather()
    greater = P.Greater()
    ones_like = P.OnesLike()
    maximum = P.Maximum()
    unsorted_segment_prod = P.UnsortedSegmentProd()
    def bprop(x, segment_ids, num_segments, out, dout):
        is_zero = equal(x, 0)
        num_zero = unsorted_segment_sum(cast(is_zero, mstype.int32), segment_ids, num_segments)
        # With two or more zeros every partial derivative in the segment is 0.
        grad = select(greater(num_zero, 1), zeros_like(dout), dout)
        # Product of the segment with zeros replaced by ones.
        non_zero_data = select(is_zero, ones_like(x), x)
        non_zero_prod = unsorted_segment_prod(non_zero_data, segment_ids, num_segments)
        zero_clipped_indices = maximum(segment_ids, zeros_like(segment_ids))
        gathered_prod = gather(out, zero_clipped_indices, 0)
        gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
        # Division by zero entries is discarded by the select below.
        prod_divided_by_x = gathered_prod / x
        partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x)
        gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices, None)
        dx = gathered_grad * partial_derivative
        return dx, zeros_like(segment_ids), zeros_like(num_segments)
    return bprop
@bprop_getters.register(P.SpaceToBatch)
def get_bprop_space_to_batch(self):
    """Generate bprop for SpaceToBatch.

    ``dout`` is passed through ``BatchToSpace``, configured with the forward
    ``block_size`` and ``paddings``.
    """
    space_to_batch_grad = P.BatchToSpace(self.block_size, self.paddings)
    def bprop(x, out, dout):
        dx = space_to_batch_grad(dout)
        return (dx,)
    return bprop
@bprop_getters.register(P.BatchToSpace)
def get_bprop_batch_to_space(self):
    """Generate bprop for BatchToSpace.

    ``dout`` is passed through ``SpaceToBatch``, configured with the forward
    ``block_size`` and ``crops``.
    """
    batch_to_space_grad = P.SpaceToBatch(self.block_size, self.crops)
    def bprop(x, out, dout):
        dx = batch_to_space_grad(dout)
        return (dx,)
    return bprop
@bprop_getters.register(P.SpaceToBatchND)
def get_bprop_space_to_batch_nd(self):
    """Generate bprop for SpaceToBatchND.

    ``dout`` is passed through ``BatchToSpaceND``, configured with the
    forward ``block_shape`` and ``paddings``.
    """
    space_to_batch_nd_grad = P.BatchToSpaceND(self.block_shape, self.paddings)
    def bprop(x, out, dout):
        dx = space_to_batch_nd_grad(dout)
        return (dx,)
    return bprop
@bprop_getters.register(P.BatchToSpaceND)
def get_bprop_batch_to_space_nd(self):
    """Generate bprop for BatchToSpaceND.

    ``dout`` is passed through ``SpaceToBatchND``, configured with the
    forward ``block_shape`` and ``crops``.
    """
    batch_to_space_nd_grad = P.SpaceToBatchND(self.block_shape, self.crops)
    def bprop(x, out, dout):
        dx = batch_to_space_nd_grad(dout)
        return (dx,)
    return bprop
@bprop_getters.register(P.BroadcastTo)
def get_bprop_broadcast_to(self):
    """Generate bprop for BroadcastTo.

    ``dout`` is summed (keeping dims) over the axes reported by
    ``broadcast_gradient_args`` and reshaped back to the shape of ``x``.
    When ``dout`` already has ``x``'s shape, it is returned as is.
    """
    reduce_keep_dim = P.ReduceSum(keep_dims=True)
    def bprop(x, out, dout):
        x_shape = shape_op(x)
        dout_shape = shape_op(dout)
        broadcast_shape = shape_op(out)
        if x_shape == dout_shape:
            return (dout,)
        # Second result: the axes along which x was broadcast.
        _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
        reduced_grad = reduce_keep_dim(dout, reduction_axes)
        dx = reshape(reduced_grad, x_shape)
        return (dx,)
    return bprop
@bprop_getters.register(P.ReverseSequence)
def get_bprop_reverse_sequence(self):
    """Generate bprop for ReverseSequence.

    ``dout`` is reversed again with the same ``batch_dim``, ``seq_dim`` and
    ``seq_lengths``, which undoes the forward reversal. ``seq_lengths`` gets
    a zero gradient.
    """
    reverse_sequence_grad = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)
    def bprop(x, seq_lengths, out, dout):
        dx = reverse_sequence_grad(dout, seq_lengths)
        return dx, zeros_like(seq_lengths)
    return bprop
@bprop_getters.register(P.TransShape)
def get_bprop_trans_shape(self):
    """Generate bprop for TransShape.

    ``dout`` is passed through ``TransShape`` again, targeting the shape of
    ``x``. ``shape`` gets a zero gradient.
    """
    op = P.TransShape()
    def bprop(x, shape, out, dout):
        dx = op(dout, shape_op(x))
        return (dx, zeros_like(shape))
    return bprop
@bprop_getters.register(P.Unique)
def get_bprop_unique(self):
    """Generate bprop for Unique.

    The gradient w.r.t. ``x`` is computed by ``UniqueGrad`` from ``dout`` and
    the forward output ``out`` (presumably the unique values together with
    the inverse indices -- confirm against UniqueGrad).
    """
    op = G.UniqueGrad()
    def bprop(x, out, dout):
        dx = op(dout, out)
        return (dx,)
    return bprop