You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_array_ops.py 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """array_ops"""
  16. from .. import operations as P
  17. from ..operations import _grad_ops as G
  18. from ..operations import _inner_ops as inner
  19. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  20. from ..functional import broadcast_gradient_args
  21. from .. import functional as F
  22. from .grad_base import bprop_getters
  23. from ..primitive import constexpr
  24. from ... import context
  25. from ...common import dtype as mstype
  # Module-level primitive instances reused by several bprop functions in
  # this module (constructed once at import time).
  reduce_sum = P.ReduceSum()
  unsorted_segment_sum = P.UnsortedSegmentSum()
  transpose = P.Transpose()
  shape_op = P.Shape()
  reshape = P.Reshape()
  size_op = P.Size()
  invert_permutation = P.InvertPermutation()
  logical_and = P.LogicalAnd()
  @bprop_getters.register(P.Fill)
  def get_bprop_fill(self):
      """Generate bprop for Fill.

      Fill takes three inputs (dtype, dims, x) and every one of them receives a
      zero gradient.  One gradient is returned per forward input, matching the
      other bprops in this module (e.g. Cast and Eye also return zeros_like for
      their `t` argument).
      """

      def bprop(dtype, dims, x, out, dout):
          # Previously only two gradients were returned for three inputs.
          return zeros_like(dtype), zeros_like(dims), zeros_like(x)

      return bprop
  @bprop_getters.register(P.DType)
  def get_bprop_dtype(self):
      """Generate bprop for DType.

      Querying the dtype of x yields no gradient: the bprop returns zeros_like(x).
      """

      def bprop(x, out, dout):
          dx = zeros_like(x)
          return (dx,)

      return bprop
  @bprop_getters.register(P.Cast)
  def get_bprop_cast(self):
      """Generate bprop for Cast.

      The incoming gradient is cast back to the dtype of the original input x;
      the target-type argument `t` gets a zero gradient.
      """
      cast = P.Cast()
      get_dtype = P.DType()

      def bprop(x, t, out, dout):
          x_dtype = get_dtype(x)
          dx = cast(dout, x_dtype)
          return dx, zeros_like(t)

      return bprop
  @bprop_getters.register(P.Shape)
  def get_bprop_shape(self):
      """Generate bprop for Shape.

      The shape query contributes no gradient; x receives zeros_like(x).
      """

      def bprop(x, out, dout):
          dx = zeros_like(x)
          return (dx,)

      return bprop
  @bprop_getters.register(P.Split)
  def get_bprop_split(self):
      """Generate bprop for Split.

      The gradient of splitting along `self.axis` is the concatenation of the
      per-output gradients (the tuple `dout`) along that same axis.
      """
      # Build the Concat primitive once per getter call instead of on every
      # bprop invocation (same pattern as get_bprop_flatten).
      concat_op = P.Concat(self.axis)

      def bprop(x, out, dout):
          dx = concat_op(dout)
          return (dx,)

      return bprop
  @bprop_getters.register(P.Rank)
  def get_bprop_rank(self):
      """Generate bprop for Rank.

      The rank query contributes no gradient; x receives zeros_like(x).
      """

      def bprop(x, out, dout):
          dx = zeros_like(x)
          return (dx,)

      return bprop
  @bprop_getters.register(P.Reshape)
  def get_bprop_reshape(self):
      """Generate bprop for Reshape.

      dout is reshaped back to the shape of x; the target shape `shp` gets a
      zero gradient.
      """

      def bprop(x, shp, out, dout):
          dx = reshape(dout, shape_op(x))
          return dx, zeros_like(shp)

      return bprop
  @bprop_getters.register(P.ExpandDims)
  def get_bprop_expand_dims(self):
      """Generate bprop for ExpandDims.

      The inserted dimension is removed by reshaping dout to the shape of x;
      `axis` gets a zero gradient.
      """

      def bprop(x, axis, out, dout):
          dx = reshape(dout, shape_op(x))
          return dx, zeros_like(axis)

      return bprop
  @bprop_getters.register(P.Squeeze)
  def get_bprop_squeeze(self):
      """Generate bprop for Squeeze.

      The squeezed dimensions are restored by reshaping dout to the shape of x.
      """

      def bprop(x, out, dout):
          dx = reshape(dout, shape_op(x))
          return (dx,)

      return bprop
  @bprop_getters.register(P.Flatten)
  def get_bprop_flatten(self):
      """Generate bprop for Flatten.

      Delegates to FlattenGrad with dout and the original shape of x.
      """
      flatten_grad = G.FlattenGrad()

      def bprop(x, out, dout):
          x_shape = shape_op(x)
          return (flatten_grad(dout, x_shape),)

      return bprop
  @constexpr
  def _tile_shape(multiples, shapex):
      """Interleave tile multiples with the input shape: [1,2], [3, 4] -> [1,3,2,4].

      The shorter sequence is left-padded with 1s so both have the same length
      before interleaving, e.g.
          (2, 3), (4,)   -> (2, 1, 3, 4)
          (2,),   (3, 4) -> (1, 3, 2, 4)

      Args:
          multiples: tile multiples.
          shapex: shape of the tiled input.

      Returns:
          tuple: (m0, s0, m1, s1, ...) of the padded sequences.
      """
      max_len = max(len(multiples), len(shapex))
      # Left-padding covers both "more multiples than dims" and "more dims than
      # multiples"; the earlier hand-rolled loop never advanced through shapex
      # in the latter case (it reused shapex[0] and then over-indexed multiples).
      padded_multiples = (1,) * (max_len - len(multiples)) + tuple(multiples)
      padded_shape = (1,) * (max_len - len(shapex)) + tuple(shapex)
      ret = ()
      for multiple, dim in zip(padded_multiples, padded_shape):
          ret += (multiple, dim)
      return ret
  @bprop_getters.register(P.Tile)
  def get_bprop_tile(self):
      """Generate bprop for Tile.

      dout is reshaped to the interleaved (multiple, dim) shape from
      _tile_shape, summed over the "multiple" positions, then reshaped to the
      shape of x.
      """

      def bprop(x, multiples, out, dout):
          shapex = shape_op(x)
          r_shape = _tile_shape(multiples, shapex)
          # Even positions of r_shape (start 0, step 2) hold the multiples.
          multiple_axes = F.make_range(0, len(r_shape), 2)
          summed = reduce_sum(reshape(dout, r_shape), multiple_axes)
          dx = reshape(summed, shapex)
          return dx, zeros_like(multiples)

      return bprop
  @bprop_getters.register(inner.EmbeddingLookup)
  def get_bprop_embedding_lookup(self):
      """Generate bprop for EmbeddingLookup.

      The gradient w.r.t. x is returned as the triple
      (flattened indices - offset, per-index values, shape of x); all other
      inputs receive zero gradients.  When `reduce_scatter_flag` is True, dout
      first goes through EmbeddingLookupCommGrad and the reshape runs with the
      CPU-targeted Reshape primitive.
      """
      sub_op = P.Sub()
      reshape_op = P.Reshape()
      host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU')

      def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout):
          x_shp = shape_op(x)
          # Shift indices by offset, then flatten them to 1-D.
          shifted = sub_op(indices, offset)
          flat_shape = (size_op(shifted),)
          shifted = reshape_op(shifted, flat_shape)
          # One row of shape x_shp[1:] per flattened index.
          values_shape = flat_shape + x_shp[1:]
          if reduce_scatter_flag is True:
              # On host
              comm_grad = G.EmbeddingLookupCommGrad()
              values = host_reshape(comm_grad(dout, split_num), values_shape)
          else:
              # On device
              values = reshape_op(dout, values_shape)
          return (shifted, values, x_shp), zeros_like(indices), zeros_like(offset), \
              zeros_like(reduce_scatter_flag), zeros_like(split_num)

      return bprop_sparse
  @bprop_getters.register(P.Transpose)
  def get_bprop_transpose(self):
      """Generate bprop for Transpose.

      dout is transposed by the inverse of `perm`; `perm` gets a zero gradient.
      """

      def bprop(x, perm, out, dout):
          inverse_perm = invert_permutation(perm)
          dx = transpose(dout, inverse_perm)
          return dx, zeros_like(perm)

      return bprop
  @bprop_getters.register(P.Concat)
  def get_bprop_concat(self):
      """Generate bprop for Concat.

      dout is sliced back into one piece per input tensor: ConcatOffset gives
      each piece's start offset and the shape of each input is the slice size.
      The result is a tuple of gradients, one per element of the input tuple x.
      """
      axis = self.axis
      # Loop-invariant: one Slice primitive serves every input.
      slice_op = P.Slice()

      def bprop(x, out, dout):
          dx = ()
          out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
          for i in range(F.tuple_len(x)):
              slice_out = slice_op(dout, out_offset[i], shape_op(x[i]))
              dx = dx + (slice_out,)
          return (dx,)

      return bprop
  @constexpr
  def _slice_grad_pad(begins, sizes, shapes):
      """Compute Pad paddings that place a Slice result back into its input shape.

      Each dimension yields (before, after) = (begin, shape - begin - size).
      """
      paddings = []
      for begin, size, shape in zip(begins, sizes, shapes):
          after = shape - begin - size
          paddings.append((begin, after))
      return tuple(paddings)
  @bprop_getters.register(P.Slice)
  def get_bprop_slice(self):
      """Generate bprop for Slice.

      When the context's device_target is "GPU" or "CPU", the SliceGrad
      primitive computes dx.  Otherwise dout is zero-padded back to the shape
      of x with Pad.  `begin` and `size` get zero gradients either way.
      """
      if context.get_context('device_target') in ("GPU", "CPU"):
          slice_grad = G.SliceGrad()

          def bprop_grad(x, begin, size, out, dout):
              dx = slice_grad(dout, x, begin, size)
              return (dx, zeros_like(begin), zeros_like(size))

          return bprop_grad

      def bprop(x, begin, size, out, dout):
          # Pad is built per call because its paddings depend on begin/size/x.
          dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout)
          return (dx, zeros_like(begin), zeros_like(size))

      return bprop
  @constexpr
  def _generate_shape_index(out_shape, indices_shape, axis):
      """Permutation moving the indices dimensions of a gather output to the front.

      A negative `axis` is shifted by out_rank - ind_rank + 1 first.
      Example: out_shape (3, 2, 3), one indices dim, axis 1 -> (1, 0, 2).
      """
      out_rank = len(out_shape)
      ind_rank = len(indices_shape)
      if axis < 0:
          axis += out_rank - ind_rank + 1
      dims = tuple(range(out_rank))
      leading = tuple(range(axis, axis + ind_rank))
      return leading + dims[:axis] + dims[axis + ind_rank:]
  @constexpr
  def _generate_inverse_index(x_shape, axis):
      """Permutation moving dimension 0 back to position `axis` (rank of x_shape).

      A negative `axis` is shifted by the rank first.
      Example: rank 3, axis 2 -> (1, 2, 0).
      """
      x_rank = len(x_shape)
      if axis < 0:
          axis += x_rank
      dims = tuple(range(x_rank))
      before = dims[1:1 + axis]
      after = dims[1 + axis:]
      return before + (0,) + after
  @bprop_getters.register(P.GatherV2)
  def get_bprop_gather_v2(self):
      """Generate bprop for GatherV2.

      dout is transposed so the indices dimensions lead, segment-summed by
      `indices` into x.shape[axis] rows, then transposed back to x's layout.
      `indices` and `axis` get zero gradients shaped like the original inputs.
      """
      expand_dims = P.ExpandDims()

      def bprop(x, indices, axis, out, dout):
          # Keep the caller's indices for the zero gradient; `indices` may be
          # expanded below and its gradient must match the original shape.
          orig_indices = indices
          if F.rank(indices) == 0:
              # A scalar index drops `axis` from the output.  Re-insert it as a
              # size-1 dim of dout (ExpandDims at `axis` lands on x's position
              # for both signs) so dout lines up with the 1-element indices.
              # This also covers the rank-0 dout case (x of rank 1).
              indices = expand_dims(indices, -1)
              dout = expand_dims(dout, axis)
          x_shp = shape_op(x)
          out_shp = shape_op(dout)
          ind_shp = shape_op(indices)
          # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
          perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
          values_transpose = transpose(dout, perm_1)
          params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
          # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
          perm_2 = _generate_inverse_index(x_shp, axis)
          params_grad = transpose(params_grad, perm_2)
          return params_grad, zeros_like(orig_indices), zeros_like(axis)

      return bprop
  @bprop_getters.register(P.SparseGatherV2)
  def get_bprop_sparse_gather_v2(self):
      """Generate bprop for SparseGatherV2.

      For axis == 0 the gradient of x is returned as the triple
      (flattened indices, values, shape of x).  For other axes a dense
      gradient is built: dout is transposed, segment-summed by indices and
      transposed back.  `indices` and `axis` get zero gradients shaped like
      the original inputs.
      """
      expand_dims = P.ExpandDims()

      def bprop(x, indices, axis, out, dout):
          x_shp = shape_op(x)
          if axis == 0:
              indices_size = (size_op(indices),)
              x_tail_shp = x_shp[1:]
              values_shape = indices_size + x_tail_shp
              values = reshape(dout, values_shape)
              flat_indices = reshape(indices, indices_size)
              return (flat_indices, values, x_shp), zeros_like(indices), zeros_like(axis)
          orig_indices = indices
          if F.rank(indices) == 0:
              # A scalar index drops `axis` from the output; re-insert it as a
              # size-1 dim so dout lines up with the 1-element indices.
              indices = expand_dims(indices, -1)
              dout = expand_dims(dout, axis)
          out_shp = shape_op(dout)
          ind_shp = shape_op(indices)
          # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
          perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
          values_transpose = transpose(dout, perm_1)
          params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
          # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
          perm_2 = _generate_inverse_index(x_shp, axis)
          params_grad = transpose(params_grad, perm_2)
          return params_grad, zeros_like(orig_indices), zeros_like(axis)

      return bprop
  @bprop_getters.register(inner.Range)
  def get_bprop_range(self):
      """Generate bprop for Range.

      No gradient flows back; x receives zeros_like(x).
      """

      def bprop(x, out, dout):
          dx = zeros_like(x)
          return (dx,)

      return bprop
  @bprop_getters.register(P.Pack)
  def get_bprop_pack(self):
      """Generate bprop for Pack.

      dout is unpacked along the packing axis, giving one gradient per
      packed input.
      """
      # Built once per getter call rather than on every bprop invocation.
      pack_grad = P.Unpack(self.axis)

      def bprop(x, out, dout):
          # Use a fresh name instead of rebinding the `out` parameter.
          dx = pack_grad(dout)
          return (dx,)

      return bprop
  @bprop_getters.register(P.Unpack)
  def get_bprop_unpack(self):
      """Generate bprop for Unpack.

      The per-output gradients in dout are packed back along the same axis.
      """
      # Built once per getter call rather than on every bprop invocation.
      unpack_grad = P.Pack(self.axis)

      def bprop(x, out, dout):
          # Use a fresh name instead of rebinding the `out` parameter.
          dx = unpack_grad(dout)
          return (dx,)

      return bprop
  @bprop_getters.register(P.StridedSlice)
  def get_bprop_strided_slice(self):
      """Generate bprop for StridedSlice.

      StridedSliceGrad, configured with the forward op's five masks, scatters
      dout back into the shape of x.  begin/end/strides get zero gradients.
      """
      input_grad = G.StridedSliceGrad(self.begin_mask, self.end_mask, self.ellipsis_mask,
                                      self.new_axis_mask, self.shrink_axis_mask)

      def bprop(x, begin, end, strides, out, dout):
          x_shape = shape_op(x)
          dx = input_grad(dout, x_shape, begin, end, strides)
          return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)

      return bprop
  @bprop_getters.register(P.Eye)
  def get_bprop_eye(self):
      """Generate bprop for Eye.

      All three inputs (n, m, t) receive zero gradients.
      """

      def bprop(n, m, t, out, dout):
          dn = zeros_like(n)
          dm = zeros_like(m)
          dt = zeros_like(t)
          return dn, dm, dt

      return bprop
  @bprop_getters.register(P.Select)
  def get_bprop_select(self):
      """Generate bprop for Select.

      dout is routed to x where `cond` is True and to y where it is False;
      `cond` gets a zero gradient.
      """
      select = P.Select()

      def bprop(cond, x, y, out, dout):
          dcond = zeros_like(cond)
          dx = select(cond, dout, zeros_like(x))
          dy = select(cond, zeros_like(y), dout)
          return dcond, dx, dy

      return bprop
  @bprop_getters.register(P.OnesLike)
  def get_bprop_oneslike(self):
      """Generate bprop for OnesLike.

      The output does not depend on the values of x; x receives zeros_like(x).
      """

      def bprop(x, out, dout):
          dx = zeros_like(x)
          return (dx,)

      return bprop
  @bprop_getters.register(P.ZerosLike)
  def get_bprop_zeroslike(self):
      """Generate bprop for ZerosLike.

      The output does not depend on the values of x; x receives zeros_like(x).
      """

      def bprop(x, out, dout):
          return (zeros_like(x),)

      return bprop
  @bprop_getters.register(P.ResizeNearestNeighbor)
  def get_bprop_resize_nearest_neighbor(self):
      """Generate bprop for ResizeNearestNeighbor.

      ResizeNearestNeighborGrad (same align_corners) maps dout back onto the
      input's spatial size, taken from dims 2 and 3 of the input shape.
      """
      op = G.ResizeNearestNeighborGrad(self.align_corners)

      def bprop(inputs, out, dout):
          input_shape = shape_op(inputs)
          # 2 and 3 represent the height and width
          spatial_size = (input_shape[2], input_shape[3])
          return (op(dout, spatial_size),)

      return bprop
  @bprop_getters.register(P.GatherNd)
  def get_bprop_gather_nd(self):
      """Generate bprop for GatherNd.

      dout is scattered back with ScatterNd into a tensor of x's shape;
      `indices` gets a zero gradient.
      """
      op = P.ScatterNd()

      def bprop(x, indices, out, dout):
          dx = op(indices, dout, shape_op(x))
          return dx, zeros_like(indices)

      return bprop
  @bprop_getters.register(P.ScatterNd)
  def get_bprop_scatter_nd(self):
      """Generate bprop for ScatterNd.

      The scattered values' gradient is gathered from dout at `indices`;
      `indices` and `shape` get zero gradients.
      """
      op = P.GatherNd()

      def bprop(indices, x, shape, out, dout):
          dx = op(dout, indices)
          return zeros_like(indices), dx, zeros_like(shape)

      return bprop
  @bprop_getters.register(P.ScatterNdUpdate)
  def get_bprop_scatter_nd_update(self):
      """Generate bprop for ScatterNdUpdate.

      x receives dout unchanged (including positions overwritten by
      `update`), `update` receives dout gathered at `indices`, and `indices`
      gets a zero gradient.
      """
      op = P.GatherNd()

      def bprop(x, indices, update, out, dout):
          dupdate = op(dout, indices)
          return dout, zeros_like(indices), dupdate

      return bprop
  @bprop_getters.register(P.TensorScatterUpdate)
  def get_bprop_tensor_scatter_update(self):
      """Generate bprop for TensorScatterUpdate.

      x's gradient is dout with the updated positions zeroed (via a
      TensorScatterUpdate of zeros); `update` gets dout gathered at `indices`.
      """
      gather_nd = P.GatherNd()
      tensor_scatter_update = P.TensorScatterUpdate()

      def bprop(x, indices, update, out, dout):
          zeroed_updates = zeros_like(update)
          return (tensor_scatter_update(dout, indices, zeroed_updates),
                  zeros_like(indices),
                  gather_nd(dout, indices))

      return bprop
  @bprop_getters.register(P.ScatterMax)
  def get_bprop_scatter_max(self):
      """Generate bprop for ScatterMax.

      x receives dout unchanged and `update` receives dout gathered along
      axis 0 at `indices`.
      NOTE(review): this does not route the gradient by which operand won
      the max; confirm that is intended.
      """
      gather = P.GatherV2()

      def bprop(x, indices, update, out, dout):
          dupdate = gather(dout, indices, 0)
          return dout, zeros_like(indices), dupdate

      return bprop
  @bprop_getters.register(P.Argmax)
  def get_bprop_argmax(self):
      """Generate bprop for Argmax.

      An index output carries no gradient; x receives zeros_like(x).
      """

      def bprop(x, out, dout):
          dx = zeros_like(x)
          return (dx,)

      return bprop
  @bprop_getters.register(P.Argmin)
  def get_bprop_argmin(self):
      """Generate bprop for Argmin.

      An index output carries no gradient; x receives zeros_like(x).
      """

      def bprop(x, out, dout):
          dx = zeros_like(x)
          return (dx,)

      return bprop
  @bprop_getters.register(P.SpaceToDepth)
  def get_bprop_space_to_depth(self):
      """Generate bprop for SpaceToDepth: apply DepthToSpace with the same block size to dout."""
      inverse_op = P.DepthToSpace(self.block_size)

      def bprop(x, out, dout):
          dx = inverse_op(dout)
          return (dx,)

      return bprop
  @bprop_getters.register(P.DepthToSpace)
  def get_bprop_depth_to_space(self):
      """Generate bprop for DepthToSpace: apply SpaceToDepth with the same block size to dout."""
      inverse_op = P.SpaceToDepth(self.block_size)

      def bprop(x, out, dout):
          dx = inverse_op(dout)
          return (dx,)

      return bprop
  @bprop_getters.register(P.Diag)
  def get_bprop_diag(self):
      """Generate bprop for Diag: take DiagPart of dout."""
      diag_part = P.DiagPart()

      def bprop(x, out, dout):
          dx = diag_part(dout)
          return (dx,)

      return bprop
  @bprop_getters.register(P.DiagPart)
  def get_bprop_diag_part(self):
      """Generate bprop for DiagPart: build a Diag from dout."""
      diag = P.Diag()

      def bprop(x, out, dout):
          dx = diag(dout)
          return (dx,)

      return bprop
  def _GatherDropNegatives(params,
                           ids,
                           zero_clipped_indices=None,
                           is_positive=None):
      """Helper function for unsorted segment ops.

      Gathers rows of `params` (axis 0) at the ids clipped to >= 0, then zeroes
      every row whose original id was negative.

      Args:
          params: tensor to gather from along axis 0.
          ids: segment ids; may be None when both optional arguments come from
              an earlier call.
          zero_clipped_indices: maximum(ids, 0); computed when None.
          is_positive: mask of ids >= 0 broadcast to the gathered shape;
              computed when None.

      Returns:
          tuple: (masked gathered rows, zero_clipped_indices, is_positive).
      """
      maximum_op = P.Maximum()
      gather_op = P.GatherV2()
      greater_equal_op = P.GreaterEqual()
      rank_op = P.Rank()
      fill_op = P.Fill()
      select_op = P.Select()
      if zero_clipped_indices is None:
          zero_clipped_indices = maximum_op(ids, zeros_like(ids))
      gathered = gather_op(params, zero_clipped_indices, 0)
      if is_positive is None:
          is_positive = greater_equal_op(ids, 0)
          # Append trailing 1-dims so the mask can be reshaped to the rank of
          # `gathered`, then AND with an all-True tensor of gathered's shape to
          # give it exactly that shape (Select needs matching shapes).
          mask_shape = shape_op(is_positive)
          for _ in range(rank_op(gathered) - rank_op(is_positive)):
              mask_shape += (1,)
          is_positive = reshape(is_positive, mask_shape)
          all_true = fill_op(mstype.bool_, shape_op(gathered), 1)
          is_positive = logical_and(is_positive, all_true)
      zero_slice = zeros_like(gathered)
      return (select_op(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
  @bprop_getters.register(P.UnsortedSegmentMin)
  def get_bprop_unsorted_segment_min(self):
      """Generate bprop for UnsortedSegmentMin.

      Elements equal to their segment's minimum (and with a non-negative
      segment id) receive dout divided by the number of such elements in the
      segment; all other elements receive zero.
      """
      equal = P.Equal()
      cast = P.Cast()
      divide = P.RealDiv()
      get_dtype = P.DType()
      select = P.Select()

      def bprop(x, segment_ids, num_segments, out, dout):
          gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids)
          is_selected = logical_and(equal(x, gathered_outputs), is_positive)
          selected_as_float = cast(is_selected, get_dtype(dout))
          num_selected = unsorted_segment_sum(selected_as_float, segment_ids, num_segments)
          weighted_grads = divide(dout, num_selected)
          gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
                                                      zero_clipped_indices, is_positive)
          dx = select(is_selected, gathered_grads, zeros_like(gathered_grads))
          return dx, zeros_like(segment_ids), zeros_like(num_segments)

      return bprop
  @bprop_getters.register(P.SpaceToBatch)
  def get_bprop_space_to_batch(self):
      """Generate bprop for SpaceToBatch: BatchToSpace with the same block size, paddings used as crops."""
      inverse_op = P.BatchToSpace(self.block_size, self.paddings)

      def bprop(x, out, dout):
          return (inverse_op(dout),)

      return bprop
  @bprop_getters.register(P.BatchToSpace)
  def get_bprop_batch_to_space(self):
      """Generate bprop for BatchToSpace: SpaceToBatch with the same block size, crops used as paddings."""
      inverse_op = P.SpaceToBatch(self.block_size, self.crops)

      def bprop(x, out, dout):
          return (inverse_op(dout),)

      return bprop
  @bprop_getters.register(P.SpaceToBatchND)
  def get_bprop_space_to_batch_nd(self):
      """Generate bprop for SpaceToBatchND: BatchToSpaceND with the same block shape, paddings used as crops."""
      inverse_op = P.BatchToSpaceND(self.block_shape, self.paddings)

      def bprop(x, out, dout):
          return (inverse_op(dout),)

      return bprop
  @bprop_getters.register(P.BatchToSpaceND)
  def get_bprop_batch_to_space_nd(self):
      """Generate bprop for BatchToSpaceND: SpaceToBatchND with the same block shape, crops used as paddings."""
      inverse_op = P.SpaceToBatchND(self.block_shape, self.crops)

      def bprop(x, out, dout):
          return (inverse_op(dout),)

      return bprop
  @bprop_getters.register(P.BroadcastTo)
  def get_bprop_broadcast_to(self):
      """Generate bprop for BroadcastTo.

      If dout already has x's shape it is returned as is.  Otherwise dout is
      summed (keeping dims) over the axes that broadcast_gradient_args reports
      for x against the broadcast shape, then reshaped to x's shape.
      """
      reduce_keep_dim = P.ReduceSum(keep_dims=True)
      broadcast_shape = self.shape

      def bprop(x, out, dout):
          x_shape = shape_op(x)
          if x_shape == shape_op(dout):
              return (dout,)
          _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
          dx = reshape(reduce_keep_dim(dout, reduction_axes), x_shape)
          return (dx,)

      return bprop
  @bprop_getters.register(P.ReverseSequence)
  def get_bprop_reverse_sequence(self):
      """Generate bprop for ReverseSequence.

      dout is reversed again with the same batch/sequence dims and lengths;
      `seq_lengths` gets a zero gradient.
      """
      reverse_op = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)

      def bprop(x, seq_lengths, out, dout):
          dx = reverse_op(dout, seq_lengths)
          return dx, zeros_like(seq_lengths)

      return bprop