Browse Source

!4005 unsortedsegsum grad

Merge pull request !4005 from fangzehua/unsortedsegsum
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
0b407dfe78
2 changed files with 12 additions and 4 deletions
  1. +10
    -0
      mindspore/ops/_grad/grad_array_ops.py
  2. +2
    -4
      tests/ut/python/ops/test_ops.py

+ 10
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -673,6 +673,16 @@ def _GatherDropNegatives(params,
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive) return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)




@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""

def bprop(x, segment_ids, num_segments, out, dout):
return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)

return bprop


@bprop_getters.register(P.UnsortedSegmentMin) @bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self): def get_bprop_unsorted_segment_min(self):
"""Generate bprop for UnsortedSegmentMin""" """Generate bprop for UnsortedSegmentMin"""


+ 2
- 4
tests/ut/python/ops/test_ops.py View File

@@ -1448,14 +1448,12 @@ test_case_nn_ops = [
'block': P.UnsortedSegmentSum(), 'block': P.UnsortedSegmentSum(),
'desc_const': [1280], 'desc_const': [1280],
'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))], 'desc_inputs': [[1280, 1024], Tensor(np.ones(1280).astype(np.int32))],
'desc_bprop': [[8192, 1024]],
'skip': ['backward']}),
'desc_bprop': [[1280, 1024]]}),
('UnsortedSegmentSum_1', { ('UnsortedSegmentSum_1', {
'block': P.UnsortedSegmentSum(), 'block': P.UnsortedSegmentSum(),
'desc_const': [4], 'desc_const': [4],
'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))], 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))],
'desc_bprop': [[4, 1, 3]],
'skip': ['backward']}),
'desc_bprop': [[4, 1, 3]]}),
('UnsortedSegmentMin', { ('UnsortedSegmentMin', {
'block': P.UnsortedSegmentMin(), 'block': P.UnsortedSegmentMin(),
'desc_const': [4], 'desc_const': [4],


Loading…
Cancel
Save