|
|
|
@@ -675,10 +675,10 @@ def get_bprop_diag_part(self): |
|
|
|
return bprop |
|
|
|
|
|
|
|
|
|
|
|
def _GatherDropNegatives(params, |
|
|
|
ids, |
|
|
|
zero_clipped_indices=None, |
|
|
|
is_positive=None): |
|
|
|
def _gather_drop_negatives(params, |
|
|
|
ids, |
|
|
|
zero_clipped_indices=None, |
|
|
|
is_positive=None): |
|
|
|
"""Helper function for unsorted segment ops.""" |
|
|
|
maximum = P.Maximum() |
|
|
|
gather = P.GatherV2() |
|
|
|
@@ -703,12 +703,32 @@ def _GatherDropNegatives(params, |
|
|
|
return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive) |
|
|
|
|
|
|
|
|
|
|
|
def _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout):
    """Shared backward computation for UnsortedSegmentMin and UnsortedSegmentMax.

    Each segment's incoming gradient is divided by the number of elements of
    `x` that equal that segment's forward output, and the quotient is routed
    back to exactly those elements. Every other element gets a zero gradient.

    Args:
        x: Forward input of the segment op.
        segment_ids: Segment id associated with the elements of `x`.
        num_segments: Number of segments passed to the forward op.
        out: Forward output of the segment op.
        dout: Gradient with respect to `out`.

    Returns:
        A 3-tuple: the gradient with respect to `x`, followed by
        `zeros_like(segment_ids)` and `zeros_like(num_segments)`.
    """
    eq_op = P.Equal()
    cast_op = P.Cast()
    div_op = P.RealDiv()
    dtype_of = P.DType()
    select_op = P.Select()

    # Gather each segment's output back onto the layout of `x`. The helper also
    # returns the clipped indices and validity mask, reused by the second gather.
    seg_out, clipped_ids, valid_mask = _gather_drop_negatives(out, segment_ids)

    # An element receives gradient only if it equals its segment's output and
    # the helper flagged its id as valid.
    hit_mask = logical_and(eq_op(x, seg_out), valid_mask)

    # Count the matching elements per segment, in the dtype of `dout` so the
    # division below is between like-typed tensors.
    hit_count = unsorted_segment_sum(cast_op(hit_mask, dtype_of(dout)), segment_ids, num_segments)

    # Split every segment's gradient evenly between its matching elements.
    shared_grad = div_op(dout, hit_count)
    grad_per_elem = _gather_drop_negatives(shared_grad, None, clipped_ids, valid_mask)[0]

    dx = select_op(hit_mask, grad_per_elem, zeros_like(grad_per_elem))
    return dx, zeros_like(segment_ids), zeros_like(num_segments)
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
    """Generate bprop for UnsortedSegmentSum.

    Returns:
        The `bprop` function registered for `P.UnsortedSegmentSum`.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        """Gather `dout` by `segment_ids` to form the gradient for `x`.

        `segment_ids` and `num_segments` get zeros-like gradients.
        """
        # Only the first element of the helper's result (the gathered values)
        # is needed here.
        dx = _gather_drop_negatives(dout, segment_ids)[0]
        return dx, zeros_like(segment_ids), zeros_like(num_segments)

    return bprop
|
|
|
|
|
|
|
@@ -716,23 +736,20 @@ def get_bprop_unsorted_segment_sum(self): |
|
|
|
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
    """Generate bprop for UnsortedSegmentMin.

    Returns:
        The `bprop` function registered for `P.UnsortedSegmentMin`.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        """Delegate to the gradient helper shared with UnsortedSegmentMax."""
        # The min and max gradients are computed the same way, so the logic
        # lives in one helper rather than being duplicated in this getter.
        return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)

    return bprop
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.UnsortedSegmentMax)
def get_bprop_unsorted_segment_max(self):
    """Generate bprop for UnsortedSegmentMax.

    Returns:
        The `bprop` function registered for `P.UnsortedSegmentMax`.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        """Compute the gradients via the helper shared by the min and max segment ops."""
        grads = _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
        return grads

    return bprop
|
|
|
|
|
|
|
|
|
|
|
@@ -759,7 +776,7 @@ def get_bprop_unsorted_segment_prod(self): |
|
|
|
gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0) |
|
|
|
prod_divided_by_x = gathered_prod / x |
|
|
|
partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x) |
|
|
|
gathered_grad, _, _ = _GatherDropNegatives(grad, segment_ids, zero_clipped_indices) |
|
|
|
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices) |
|
|
|
dx = gathered_grad * partial_derivative |
|
|
|
return dx, zeros_like(segment_ids), zeros_like(num_segments) |
|
|
|
|
|
|
|
|