Browse Source

unsortedsegsum grad

tags/v0.7.0-beta
fangzehua 5 years ago
parent
commit
99f2be7064
2 changed files with 12 additions and 4 deletions
  1. +10
    -0
      mindspore/ops/_grad/grad_array_ops.py
  2. +2
    -4
      tests/ut/python/ops/test_ops.py

+ 10
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -673,6 +673,16 @@ def _GatherDropNegatives(params,
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)


@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""

def bprop(x, segment_ids, num_segments, out, dout):
return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)

return bprop


@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
"""Generate bprop for UnsortedSegmentMin"""


+ 2
- 4
tests/ut/python/ops/test_ops.py View File

@@ -1447,14 +1447,12 @@ test_case_nn_ops = [
'block': P.UnsortedSegmentSum(),
'desc_const': [1280],
'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))],
'desc_bprop': [[8192, 1024]],
'skip': ['backward']}),
'desc_bprop': [[1280, 1024]]}),
('UnsortedSegmentSum_1', {
'block': P.UnsortedSegmentSum(),
'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
'desc_bprop': [[4, 1, 3]],
'skip': ['backward']}),
'desc_bprop': [[4, 1, 3]]}),
('UnsortedSegmentMin', {
'block': P.UnsortedSegmentMin(),
'desc_const': [4],


Loading…
Cancel
Save