|
|
|
@@ -779,7 +779,8 @@ def get_bprop_unsorted_segment_sum(self): |
|
|
|
"""Generate bprop for UnsortedSegmentSum""" |
|
|
|
|
|
|
|
def bprop(x, segment_ids, num_segments, out, dout): |
|
|
|
return _gather_drop_negatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments) |
|
|
|
return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_like(segment_ids), \ |
|
|
|
zeros_like(num_segments) |
|
|
|
|
|
|
|
return bprop |
|
|
|
|
|
|
|
@@ -827,7 +828,7 @@ def get_bprop_unsorted_segment_prod(self): |
|
|
|
gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0) |
|
|
|
prod_divided_by_x = gathered_prod / x |
|
|
|
partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x) |
|
|
|
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices) |
|
|
|
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices, None) |
|
|
|
dx = gathered_grad * partial_derivative |
|
|
|
return dx, zeros_like(segment_ids), zeros_like(num_segments) |
|
|
|
|
|
|
|
|