You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_array_ops.py 24 kB

5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """array_ops"""
  16. import mindspore as ms
  17. from mindspore.ops import composite as C
  18. from .. import operations as P
  19. from ..operations import _grad_ops as G
  20. from ..operations import _inner_ops as inner
  21. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  22. from ..functional import broadcast_gradient_args
  23. from .. import functional as F
  24. from .grad_base import bprop_getters
  25. from ..primitive import constexpr
  26. from ... import context
  27. from ...common import dtype as mstype
  28. from ...common.tensor import RowTensor
  29. reduce_sum = P.ReduceSum()
  30. unsorted_segment_sum = P.UnsortedSegmentSum()
  31. transpose = P.Transpose()
  32. shape_op = P.Shape()
  33. reshape = P.Reshape()
  34. size_op = P.Size()
  35. invert_permutation = P.InvertPermutation()
  36. logical_and = P.LogicalAnd()
  37. is_sub_class = P.IsSubClass()
  38. @bprop_getters.register(P.Fill)
  39. def get_bprop_fill(self):
  40. """Generate bprop for Fill"""
  41. def bprop(dtype, dims, x, out, dout):
  42. return zeros_like(dims), zeros_like(x)
  43. return bprop
  44. @bprop_getters.register(P.DType)
  45. def get_bprop_dtype(self):
  46. """Generate bprop for DType"""
  47. def bprop(x, out, dout):
  48. return (zeros_like(x),)
  49. return bprop
  50. dout_cast = C.MultitypeFuncGraph("dout_cast")
  51. @dout_cast.register("Tensor", "Tensor")
  52. def dout_cast_tensor(dout, x):
  53. cast = P.Cast()
  54. get_dtype = P.DType()
  55. dx = cast(dout, get_dtype(x))
  56. return dx
  57. @dout_cast.register("Number", "Number")
  58. def dout_cast_number(dout, x):
  59. cast = P.Cast()
  60. get_dtype = P.DType()
  61. dx = cast(dout, get_dtype(x))
  62. return dx
  63. @dout_cast.register("RowTensor", "Tensor")
  64. def dout_cast_row_tensor(dout, x):
  65. cast = P.Cast()
  66. get_dtype = P.DType()
  67. values = cast(dout.values, get_dtype(x))
  68. return RowTensor(dout.indices, values, dout.dense_shape)
  69. @bprop_getters.register(P.Cast)
  70. def get_bprop_cast(self):
  71. """Generate bprop for Cast"""
  72. cast = P.Cast()
  73. get_dtype = P.DType()
  74. def bprop(x, t, out, dout):
  75. dx = cast(dout, get_dtype(x))
  76. return dx, zeros_like(t)
  77. def bprop_sparse(x, t, out, dout):
  78. dx = dout_cast(dout, x)
  79. return dx, zeros_like(t)
  80. if context.get_context('enable_sparse'):
  81. return bprop_sparse
  82. return bprop
  83. @bprop_getters.register(P.Shape)
  84. def get_bprop_shape(self):
  85. """Generate bprop for Shape"""
  86. def bprop(x, out, dout):
  87. return (zeros_like(x),)
  88. return bprop
  89. @bprop_getters.register(P.Split)
  90. def get_bprop_split(self):
  91. """Generate bprop for Split"""
  92. axis = self.axis
  93. def bprop(x, out, dout):
  94. concat_op = P.Concat(axis)
  95. dx = concat_op(dout)
  96. return (dx,)
  97. return bprop
  98. @bprop_getters.register(P.Rank)
  99. def get_bprop_rank(self):
  100. """Generate bprop for Rank"""
  101. def bprop(x, out, dout):
  102. return (zeros_like(x),)
  103. return bprop
  104. @bprop_getters.register(P.Reshape)
  105. def get_bprop_reshape(self):
  106. """Generate bprop for Reshape"""
  107. def bprop(x, shp, out, dout):
  108. shapex = shape_op(x)
  109. return reshape(dout, shapex), zeros_like(shp)
  110. return bprop
  111. @bprop_getters.register(P.ExpandDims)
  112. def get_bprop_expand_dims(self):
  113. """Generate bprop for ExpandDims"""
  114. def bprop(x, axis, out, dout):
  115. shapex = shape_op(x)
  116. return reshape(dout, shapex), zeros_like(axis)
  117. return bprop
  118. @bprop_getters.register(P.Squeeze)
  119. def get_bprop_squeeze(self):
  120. """Generate bprop for Squeeze"""
  121. def bprop(x, out, dout):
  122. shapex = shape_op(x)
  123. return (reshape(dout, shapex),)
  124. return bprop
  125. @bprop_getters.register(P.Flatten)
  126. def get_bprop_flatten(self):
  127. """Generate bprop for Flatten"""
  128. flatten_grad = G.FlattenGrad()
  129. def bprop(x, out, dout):
  130. dx = flatten_grad(dout, shape_op(x))
  131. return (dx,)
  132. return bprop
  133. @constexpr
  134. def _tile_shape(multiples, shapex):
  135. """Calculate [1,2], [3, 4] -> [1,3,2,4]."""
  136. len_muli = len(multiples)
  137. rank = len(shapex)
  138. len_cmp = len_muli - rank
  139. max_len = max(len_muli, rank)
  140. i = 0
  141. j = 0
  142. ret = []
  143. while (i < max_len) and (j < max_len):
  144. if len_cmp == 0:
  145. ret.append(multiples[i])
  146. ret.append(shapex[j])
  147. i += 1
  148. j += 1
  149. elif len_cmp > 0:
  150. ret.append(multiples[i])
  151. ret.append(1)
  152. i += 1
  153. len_cmp -= 1
  154. else:
  155. ret.append(1)
  156. ret.append(shapex[j])
  157. len_cmp += 1
  158. return tuple(ret)
  159. @bprop_getters.register(P.Tile)
  160. def get_bprop_tile(self):
  161. """Generate bprop for Tile"""
  162. def bprop(x, multiples, out, dout):
  163. shapex = shape_op(x)
  164. r_shape = _tile_shape(multiples, shapex)
  165. # 0 represents the start index, and 2 represents the step
  166. axis = F.make_range(0, len(r_shape), 2)
  167. dx = reduce_sum(reshape(dout, r_shape), axis)
  168. dx = reshape(dx, shapex)
  169. return dx, zeros_like(multiples)
  170. return bprop
  171. @bprop_getters.register(P.EmbeddingLookup)
  172. def get_bprop_embedding_lookup(self):
  173. """Generate bprop for EmbeddingLookup"""
  174. sub_op = P.Sub()
  175. reshape_op = P.Reshape()
  176. def bprop_sparse(x, indices, offset, out, dout):
  177. x_shp = shape_op(x)
  178. new_indices = sub_op(indices, offset)
  179. # Reshape the 'new_indices'
  180. new_indices_shape_changed = (size_op(new_indices),)
  181. new_indices = reshape_op(new_indices, new_indices_shape_changed)
  182. x_shp_tail = x_shp[1:]
  183. actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
  184. # Reshape the 'actual_dout' on device
  185. actual_dout = reshape_op(dout, actual_dout_shape_changed)
  186. return RowTensor(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
  187. return bprop_sparse
  188. @constexpr
  189. def make_begin(shp):
  190. begin = tuple([0 for _ in shp])
  191. return begin
  192. @bprop_getters.register(P.Padding)
  193. def get_bprop_padding(self):
  194. """Grad definition for `Padding` operation."""
  195. def bprop(x, out, dout):
  196. shp = shape_op(x)
  197. begin = make_begin(shp)
  198. dx = P.Slice()(dout, begin, shp)
  199. return (dx,)
  200. return bprop
  201. @bprop_getters.register(P.Transpose)
  202. def get_bprop_transpose(self):
  203. """Generate bprop for Transpose"""
  204. def bprop(x, perm, out, dout):
  205. return transpose(dout, invert_permutation(perm)), zeros_like(perm)
  206. return bprop
  207. @constexpr
  208. def _concat_grad_uniform(input_shapes, input_nums):
  209. """Helper function for bprop of Concat"""
  210. is_uniform = True
  211. for i in range(1, input_nums):
  212. if input_shapes[i-1] != input_shapes[i]:
  213. is_uniform = False
  214. break
  215. return is_uniform
  216. @bprop_getters.register(P.Concat)
  217. def get_bprop_concat(self):
  218. """Generate bprop for Concat"""
  219. axis = self.axis
  220. is_ascend = context.get_context('device_target') == "Ascend"
  221. def bprop(x, out, dout):
  222. dx = ()
  223. out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
  224. input_nums = F.tuple_len(x)
  225. input_shapes = ()
  226. for i in range(input_nums):
  227. input_shapes = input_shapes + (shape_op(x[i]),)
  228. is_uniform = _concat_grad_uniform(input_shapes, input_nums)
  229. if is_uniform and is_ascend:
  230. dx = P.Split(axis, input_nums)(dout)
  231. else:
  232. for i in range(input_nums):
  233. slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
  234. dx = dx + (slice_out,)
  235. return (dx,)
  236. return bprop
  237. @constexpr
  238. def _slice_grad_pad(begins, sizes, shapes):
  239. pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
  240. return pads
  241. @bprop_getters.register(P.Slice)
  242. def get_bprop_slice(self):
  243. """Generate bprop for Slice"""
  244. def bprop(x, begin, size, out, dout):
  245. dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout)
  246. return (dx, zeros_like(begin), zeros_like(size))
  247. def bprop_grad(x, begin, size, out, dout):
  248. dx = dx = G.SliceGrad()(dout, x, begin, size)
  249. return (dx, zeros_like(begin), zeros_like(size))
  250. if context.get_context('device_target') == "GPU" or context.get_context('device_target') == "CPU":
  251. return bprop_grad
  252. return bprop
  253. @constexpr
  254. def _generate_shape_index(out_shape, indices_shape, axis):
  255. out_rank = len(out_shape)
  256. ind_rank = len(indices_shape)
  257. if axis < 0:
  258. axis += out_rank - ind_rank + 1
  259. perm_part1 = tuple(range(axis, axis + ind_rank))
  260. index = tuple(range(out_rank))
  261. perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
  262. return perm
  263. @constexpr
  264. def _generate_inverse_index(x_shape, axis):
  265. x_rank = len(x_shape)
  266. index = tuple(range(x_rank))
  267. if axis < 0:
  268. axis += x_rank
  269. perm = index[1:1 + axis] + (0,) + index[1 + axis:]
  270. return perm
  271. @bprop_getters.register(P.GatherV2)
  272. def get_bprop_gather_v2(self):
  273. """Generate bprop for GatherV2"""
  274. def bprop(x, indices, axis, out, dout):
  275. if F.rank(dout) == 0:
  276. dout = P.ExpandDims()(dout, -1)
  277. if F.rank(indices) == 0:
  278. indices = P.ExpandDims()(indices, -1)
  279. x_shp = shape_op(x)
  280. out_shp = shape_op(dout)
  281. ind_shp = shape_op(indices)
  282. # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
  283. perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
  284. values_transpose = transpose(dout, perm_1)
  285. params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
  286. # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
  287. perm_2 = _generate_inverse_index(x_shp, axis)
  288. params_grad = transpose(params_grad, perm_2)
  289. return params_grad, zeros_like(indices), zeros_like(axis)
  290. return bprop
  291. @bprop_getters.register(P.SparseGatherV2)
  292. def get_bprop_sparse_gather_v2(self):
  293. """Generate bprop for SparseGatherV2"""
  294. def bprop(x, indices, axis, out, dout):
  295. x_shp = shape_op(x)
  296. if axis == 0:
  297. indices_size = (size_op(indices),)
  298. x_tail_shp = x_shp[1:]
  299. values_shape = indices_size + x_tail_shp
  300. values = reshape(dout, values_shape)
  301. indices_new = reshape(indices, indices_size)
  302. return RowTensor(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis)
  303. if F.rank(dout) == 0:
  304. dout = P.ExpandDims()(dout, -1)
  305. if F.rank(indices) == 0:
  306. indices = P.ExpandDims()(indices, -1)
  307. out_shp = shape_op(dout)
  308. ind_shp = shape_op(indices)
  309. # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
  310. perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
  311. values_transpose = transpose(dout, perm_1)
  312. params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
  313. # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
  314. perm_2 = _generate_inverse_index(x_shp, axis)
  315. params_grad = transpose(params_grad, perm_2)
  316. return params_grad, zeros_like(indices), zeros_like(axis)
  317. return bprop
  318. @bprop_getters.register(inner.Range)
  319. def get_bprop_range(self):
  320. """Generate bprop for Range"""
  321. def bprop(x, out, dout):
  322. return (zeros_like(x),)
  323. return bprop
  324. @bprop_getters.register(P.Pack)
  325. def get_bprop_pack(self):
  326. """Generate bprop for Pack"""
  327. axis = self.axis
  328. def bprop(x, out, dout):
  329. pack_grad = P.Unpack(axis)
  330. out = pack_grad(dout)
  331. if is_sub_class(F.typeof(x), ms.list_):
  332. ret = []
  333. for item in out:
  334. ret.append(item)
  335. return (ret,)
  336. return (out,)
  337. return bprop
  338. @bprop_getters.register(P.ReverseV2)
  339. def get_bprop_reverse_v2(self):
  340. """Generate bprop for ReverseV2"""
  341. axis = self.axis
  342. def bprop(x, out, dout):
  343. reverse_grad = P.ReverseV2(axis)
  344. dx = reverse_grad(dout)
  345. return (dx,)
  346. return bprop
  347. @bprop_getters.register(P.Unpack)
  348. def get_bprop_unpack(self):
  349. """Generate bprop for Unpack"""
  350. axis = self.axis
  351. def bprop(x, out, dout):
  352. unpack_grad = P.Pack(axis)
  353. out = unpack_grad(dout)
  354. return (out,)
  355. return bprop
  356. @bprop_getters.register(P.StridedSlice)
  357. def get_bprop_strided_slice(self):
  358. """Generate bprop for StridedSlice"""
  359. input_grad = G.StridedSliceGrad(self.begin_mask,
  360. self.end_mask,
  361. self.ellipsis_mask,
  362. self.new_axis_mask,
  363. self.shrink_axis_mask)
  364. def bprop(x, begin, end, strides, out, dout):
  365. dx = input_grad(dout, shape_op(x), begin, end, strides)
  366. return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
  367. return bprop
  368. @bprop_getters.register(P.Eye)
  369. def get_bprop_eye(self):
  370. """Generate bprop for Eye"""
  371. def bprop(n, m, t, out, dout):
  372. return zeros_like(n), zeros_like(m), zeros_like(t)
  373. return bprop
  374. @bprop_getters.register(P.Select)
  375. def get_bprop_select(self):
  376. """Generate bprop for Select"""
  377. select = P.Select()
  378. def bprop(cond, x, y, out, dout):
  379. return zeros_like(cond), select(cond, dout, zeros_like(x)), select(cond, zeros_like(y), dout)
  380. return bprop
  381. @bprop_getters.register(P.OnesLike)
  382. def get_bprop_oneslike(self):
  383. """Generate bprop for OnesLike"""
  384. def bprop(x, out, dout):
  385. return (zeros_like(x),)
  386. return bprop
  387. @bprop_getters.register(P.ZerosLike)
  388. def get_bprop_zeroslike(self):
  389. """Generate bprop for OnesLike"""
  390. def bprop(x, out, dout):
  391. return (zeros_like(x),)
  392. return bprop
  393. @bprop_getters.register(P.ResizeNearestNeighbor)
  394. def get_bprop_resize_nearest_neighbor(self):
  395. """Generate bprop for ResizeNearestNeighbor"""
  396. op = G.ResizeNearestNeighborGrad(self.align_corners)
  397. def bprop(inputs, out, dout):
  398. shp = shape_op(inputs)
  399. # 2 and 3 represent the height and width
  400. shp = (shp[2], shp[3])
  401. return (op(dout, shp),)
  402. return bprop
  403. @bprop_getters.register(P.GatherNd)
  404. def get_bprop_gather_nd(self):
  405. """Generate bprop for GatherNd"""
  406. op = P.ScatterNd()
  407. def bprop(x, indices, out, dout):
  408. shp = shape_op(x)
  409. return op(indices, dout, shp), zeros_like(indices)
  410. return bprop
  411. @bprop_getters.register(P.ScatterNd)
  412. def get_bprop_scatter_nd(self):
  413. """Generate bprop for ScatterNd"""
  414. op = P.GatherNd()
  415. def bprop(indices, x, shape, out, dout):
  416. return zeros_like(indices), op(dout, indices), zeros_like(shape)
  417. return bprop
  418. @bprop_getters.register(P.ScatterNdUpdate)
  419. def get_bprop_scatter_nd_update(self):
  420. """Generate bprop for ScatterNdUpdate"""
  421. op = P.GatherNd()
  422. def bprop(x, indices, update, out, dout):
  423. return dout, zeros_like(indices), op(dout, indices)
  424. return bprop
  425. @bprop_getters.register(P.ScatterNonAliasingAdd)
  426. def get_bprop_scatter_non_aliasing_add_update(self):
  427. """Generate bprop for ScatterNonAliasingAdd"""
  428. op = P.GatherNd()
  429. def bprop(x, indices, update, out, dout):
  430. return dout, zeros_like(indices), op(dout, indices)
  431. return bprop
  432. @bprop_getters.register(P.TensorScatterUpdate)
  433. def get_bprop_tensor_scatter_update(self):
  434. """Generate bprop for TensorScatterUpdate"""
  435. gather_nd = P.GatherNd()
  436. tensor_scatter_update = P.TensorScatterUpdate()
  437. def bprop(x, indices, update, out, dout):
  438. x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
  439. update_grad = gather_nd(dout, indices)
  440. return x_grad, zeros_like(indices), update_grad
  441. return bprop
  442. @bprop_getters.register(P.ScatterMax)
  443. def get_bprop_scatter_max(self):
  444. """Generate bprop for ScatterMax"""
  445. gather = P.GatherV2()
  446. def bprop(x, indices, update, out, dout):
  447. return dout, zeros_like(indices), gather(dout, indices, 0)
  448. return bprop
  449. @bprop_getters.register(P.Argmax)
  450. def get_bprop_argmax(self):
  451. """Generate bprop for Argmax"""
  452. def bprop(x, out, dout):
  453. return (zeros_like(x),)
  454. return bprop
  455. @bprop_getters.register(P.Argmin)
  456. def get_bprop_argmin(self):
  457. """Generate bprop for Argmin"""
  458. def bprop(x, out, dout):
  459. return (zeros_like(x),)
  460. return bprop
  461. @bprop_getters.register(P.SpaceToDepth)
  462. def get_bprop_space_to_depth(self):
  463. """Generate bprop for SpaceToDepth"""
  464. op = P.DepthToSpace(self.block_size)
  465. def bprop(x, out, dout):
  466. return (op(dout),)
  467. return bprop
  468. @bprop_getters.register(P.DepthToSpace)
  469. def get_bprop_depth_to_space(self):
  470. """Generate bprop for DepthToSpace"""
  471. op = P.SpaceToDepth(self.block_size)
  472. def bprop(x, out, dout):
  473. return (op(dout),)
  474. return bprop
  475. @bprop_getters.register(P.Diag)
  476. def get_bprop_diag(self):
  477. """Generate bprop for Diag"""
  478. op = P.DiagPart()
  479. def bprop(x, out, dout):
  480. return (op(dout),)
  481. return bprop
  482. @bprop_getters.register(P.DiagPart)
  483. def get_bprop_diag_part(self):
  484. """Generate bprop for DiagPart"""
  485. op = P.Diag()
  486. def bprop(x, out, dout):
  487. return (op(dout),)
  488. return bprop
  489. def _GatherDropNegatives(params,
  490. ids,
  491. zero_clipped_indices=None,
  492. is_positive=None):
  493. """Helper function for unsorted segment ops."""
  494. maximum = P.Maximum()
  495. gather = P.GatherV2()
  496. greater_equal = P.GreaterEqual()
  497. rank = P.Rank()
  498. fill = P.Fill()
  499. select = P.Select()
  500. if zero_clipped_indices is None:
  501. zero_clipped_indices = maximum(ids, zeros_like(ids))
  502. gathered = gather(params, zero_clipped_indices, 0)
  503. if is_positive is None:
  504. is_positive = greater_equal(ids, 0)
  505. is_positive_shape = shape_op(is_positive)
  506. broadcastable_shape = is_positive_shape
  507. for _ in range(rank(gathered) - rank(is_positive)):
  508. broadcastable_shape += (1,)
  509. is_positive = reshape(is_positive, broadcastable_shape)
  510. gathered_shape = shape_op(gathered)
  511. is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
  512. zero_slice = zeros_like(gathered)
  513. return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
  514. @bprop_getters.register(P.UnsortedSegmentSum)
  515. def get_bprop_unsorted_segment_sum(self):
  516. """Generate bprop for UnsortedSegmentSum"""
  517. def bprop(x, segment_ids, num_segments, out, dout):
  518. return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)
  519. return bprop
  520. @bprop_getters.register(P.UnsortedSegmentMin)
  521. def get_bprop_unsorted_segment_min(self):
  522. """Generate bprop for UnsortedSegmentMin"""
  523. equal = P.Equal()
  524. cast = P.Cast()
  525. divide = P.RealDiv()
  526. get_dtype = P.DType()
  527. select = P.Select()
  528. def bprop(x, segment_ids, num_segments, out, dout):
  529. gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids, None, None)
  530. is_selected = equal(x, gathered_outputs)
  531. is_selected = logical_and(is_selected, is_positive)
  532. num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
  533. segment_ids, num_segments)
  534. weighted_grads = divide(dout, num_selected)
  535. gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
  536. zero_clipped_indices, is_positive)
  537. zeros = zeros_like(gathered_grads)
  538. return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
  539. return bprop
  540. @bprop_getters.register(P.UnsortedSegmentProd)
  541. def get_bprop_unsorted_segment_prod(self):
  542. """Generate bprop for UnsortedSegmentProd"""
  543. equal = P.Equal()
  544. cast = P.Cast()
  545. select = P.Select()
  546. gather = P.GatherV2()
  547. greater = P.Greater()
  548. ones_like = P.OnesLike()
  549. maximum = P.Maximum()
  550. unsorted_segment_prod = P.UnsortedSegmentProd()
  551. def bprop(x, segment_ids, num_segments, out, dout):
  552. is_zero = equal(x, 0)
  553. num_zero = unsorted_segment_sum(cast(is_zero, mstype.int32), segment_ids, num_segments)
  554. grad = select(greater(num_zero, 1), zeros_like(dout), dout)
  555. non_zero_data = select(is_zero, ones_like(x), x)
  556. non_zero_prod = unsorted_segment_prod(non_zero_data, segment_ids, num_segments)
  557. zero_clipped_indices = maximum(segment_ids, zeros_like(segment_ids))
  558. gathered_prod = gather(out, zero_clipped_indices, 0)
  559. gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
  560. prod_divided_by_x = gathered_prod / x
  561. partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x)
  562. gathered_grad, _, _ = _GatherDropNegatives(grad, segment_ids, zero_clipped_indices)
  563. dx = gathered_grad * partial_derivative
  564. return dx, zeros_like(segment_ids), zeros_like(num_segments)
  565. return bprop
  566. @bprop_getters.register(P.SpaceToBatch)
  567. def get_bprop_space_to_batch(self):
  568. """Generate bprop for SpaceToBatch"""
  569. space_to_batch_grad = P.BatchToSpace(self.block_size, self.paddings)
  570. def bprop(x, out, dout):
  571. dx = space_to_batch_grad(dout)
  572. return (dx,)
  573. return bprop
  574. @bprop_getters.register(P.BatchToSpace)
  575. def get_bprop_batch_to_space(self):
  576. """Generate bprop for BatchToSpace"""
  577. batch_to_space_grad = P.SpaceToBatch(self.block_size, self.crops)
  578. def bprop(x, out, dout):
  579. dx = batch_to_space_grad(dout)
  580. return (dx,)
  581. return bprop
  582. @bprop_getters.register(P.SpaceToBatchND)
  583. def get_bprop_space_to_batch_nd(self):
  584. """Generate bprop for SpaceToBatchND"""
  585. space_to_batch_nd_grad = P.BatchToSpaceND(self.ori_block_shape, self.ori_paddings)
  586. def bprop(x, out, dout):
  587. dx = space_to_batch_nd_grad(dout)
  588. return (dx,)
  589. return bprop
  590. @bprop_getters.register(P.BatchToSpaceND)
  591. def get_bprop_batch_to_space_nd(self):
  592. """Generate bprop for BatchToSpaceND"""
  593. batch_to_space_nd_grad = P.SpaceToBatchND(self.ori_block_shape, self.ori_crops)
  594. def bprop(x, out, dout):
  595. dx = batch_to_space_nd_grad(dout)
  596. return (dx,)
  597. return bprop
  598. @bprop_getters.register(P.BroadcastTo)
  599. def get_bprop_broadcast_to(self):
  600. """Generate bprop for BroadcastTo"""
  601. reduce_keep_dim = P.ReduceSum(keep_dims=True)
  602. broadcast_shape = self.shape
  603. def bprop(x, out, dout):
  604. x_shape = shape_op(x)
  605. dout_shape = shape_op(dout)
  606. if x_shape == dout_shape:
  607. return (dout,)
  608. _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
  609. reduced_grad = reduce_keep_dim(dout, reduction_axes)
  610. dx = reshape(reduced_grad, x_shape)
  611. return (dx,)
  612. return bprop
  613. @bprop_getters.register(P.ReverseSequence)
  614. def get_bprop_reverse_sequence(self):
  615. """Generate bprop for ReverseSequence"""
  616. reverse_sequence_grad = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)
  617. def bprop(x, seq_lengths, out, dout):
  618. dx = reverse_sequence_grad(dout, seq_lengths)
  619. return dx, zeros_like(seq_lengths)
  620. return bprop
  621. @bprop_getters.register(P.TransShape)
  622. def get_bprop_trans_shape(self):
  623. """Generate bprop for TransShape"""
  624. op = P.TransShape()
  625. def bprop(x, shape, out, dout):
  626. dx = op(dout, shape_op(x))
  627. return (dx, zeros_like(shape))
  628. return bprop
  629. @bprop_getters.register(P.Unique)
  630. def get_bprop_unique(self):
  631. """Generate bprop for Unique"""
  632. op = G.UniqueGrad()
  633. def bprop(x, out, dout):
  634. dx = op(dout, out)
  635. return (dx,)
  636. return bprop