Browse Source

fix-bug-of-gather-drop-negatives-without-default-parameter

tags/v1.2.0-rc1
lvliang 5 years ago
parent
commit
7da4f49c41
2 changed files with 4 additions and 4 deletions
  1. +1
    -2
      mindspore/ccsrc/pipeline/jit/pass.cc
  2. +3
    -2
      mindspore/ops/_grad/grad_array_ops.py

+ 1
- 2
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -485,8 +485,7 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
{"opt_prepare", PrepareGroup},
{"cconv", CconvPass}};

std::vector<PassItem> kPynativePasses = {{"opt_grad_epilogue", OptPassGradEpilogueGroup},
{"opt_a", OptPassAGroup},
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup},
{"cconv", CconvPass},
{"transform_top", TransformTopGraphPass},


+ 3
- 2
mindspore/ops/_grad/grad_array_ops.py View File

@@ -779,7 +779,8 @@ def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""

def bprop(x, segment_ids, num_segments, out, dout):
return _gather_drop_negatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)
return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_like(segment_ids), \
zeros_like(num_segments)

return bprop

@@ -827,7 +828,7 @@ def get_bprop_unsorted_segment_prod(self):
gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
prod_divided_by_x = gathered_prod / x
partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x)
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices)
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices, None)
dx = gathered_grad * partial_derivative
return dx, zeros_like(segment_ids), zeros_like(num_segments)



Loading…
Cancel
Save