diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 73afd6aff0..6d03cfa27a 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -675,10 +675,10 @@ def get_bprop_diag_part(self): return bprop -def _GatherDropNegatives(params, - ids, - zero_clipped_indices=None, - is_positive=None): +def _gather_drop_negatives(params, + ids, + zero_clipped_indices=None, + is_positive=None): """Helper function for unsorted segment ops.""" maximum = P.Maximum() gather = P.GatherV2() @@ -703,12 +703,32 @@ def _GatherDropNegatives(params, return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive) +def _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout): + """Gradient for UnsortedSegmentMin or UnsortedSegmentMax""" + equal = P.Equal() + cast = P.Cast() + divide = P.RealDiv() + get_dtype = P.DType() + select = P.Select() + + gathered_outputs, zero_clipped_indices, is_positive = _gather_drop_negatives(out, segment_ids, None, None) + is_selected = equal(x, gathered_outputs) + is_selected = logical_and(is_selected, is_positive) + num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)), + segment_ids, num_segments) + weighted_grads = divide(dout, num_selected) + gathered_grads, _, _ = _gather_drop_negatives(weighted_grads, None, + zero_clipped_indices, is_positive) + zeros = zeros_like(gathered_grads) + return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments) + + @bprop_getters.register(P.UnsortedSegmentSum) def get_bprop_unsorted_segment_sum(self): """Generate bprop for UnsortedSegmentSum""" def bprop(x, segment_ids, num_segments, out, dout): - return _GatherDropNegatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments) + return _gather_drop_negatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments) return bprop @@ -716,23 +736,20 @@ def get_bprop_unsorted_segment_sum(self): @bprop_getters.register(P.UnsortedSegmentMin) def get_bprop_unsorted_segment_min(self): """Generate bprop for UnsortedSegmentMin""" - equal = P.Equal() - cast = P.Cast() - divide = P.RealDiv() - get_dtype = P.DType() - select = P.Select() def bprop(x, segment_ids, num_segments, out, dout): - gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids, None, None) - is_selected = equal(x, gathered_outputs) - is_selected = logical_and(is_selected, is_positive) - num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)), - segment_ids, num_segments) - weighted_grads = divide(dout, num_selected) - gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None, - zero_clipped_indices, is_positive) - zeros = zeros_like(gathered_grads) - return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments) + return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout) + + return bprop + + +@bprop_getters.register(P.UnsortedSegmentMax) +def get_bprop_unsorted_segment_max(self): + """Generate bprop for UnsortedSegmentMax""" + + def bprop(x, segment_ids, num_segments, out, dout): + return _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout) + return bprop @@ -759,7 +776,7 @@ def get_bprop_unsorted_segment_prod(self): gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0) prod_divided_by_x = gathered_prod / x partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x) - gathered_grad, _, _ = _GatherDropNegatives(grad, segment_ids, zero_clipped_indices) + gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices) dx = gathered_grad * partial_derivative return dx, zeros_like(segment_ids), zeros_like(num_segments) diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 3ce89abed3..108b3ae257 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -272,6 +272,7 @@ from .reduce_all import _reduce_all_tbe from .reduce_any import _reduce_any_tbe from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe from .unsorted_segment_min import _unsorted_segment_min_tbe +from .unsorted_segment_max import _unsorted_segment_max_tbe from .asin import _asin_tbe from .asin_grad import _asin_grad_tbe from .asinh import _asinh_tbe diff --git a/mindspore/ops/_op_impl/tbe/unsorted_segment_max.py b/mindspore/ops/_op_impl/tbe/unsorted_segment_max.py new file mode 100644 index 0000000000..63596fdb70 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/unsorted_segment_max.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""UnsortedSegmentMax op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +unsorted_segment_max_op_info = TBERegOp("UnsortedSegmentMax") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("unsorted_segment_max_d.so") \ + .compute_cost(10) \ + .kernel_name("unsorted_segment_max_d") \ + .partial_flag(True) \ + .attr("num_segments", "required", "int", "all") \ + .input(0, "data", False, "required", "all") \ + .input(1, "segment_ids", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \ + .dtype_format(DataType.F16_FracZ, DataType.I32_Default, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.I32_Default, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.I32_Default, DataType.F32_5HD) \ + .dtype_format(DataType.F32_FracZ, DataType.I32_Default, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.I32_Default, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_Default, DataType.I32_5HD) \ + .dtype_format(DataType.I32_FracZ, DataType.I32_Default, DataType.I32_FracZ) \ + .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_Default, DataType.I32_C1HWNCoC0) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(unsorted_segment_max_op_info) +def _unsorted_segment_max_tbe(): + """UnsortedSegmentMax TBE register""" + return diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 94cff9ff0f..b4cdd45e12 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1648,6 +1648,11 @@ test_case_nn_ops = [ 'desc_const': [4], 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))], 'desc_bprop': [[4, 2, 1, 3]]}), + ('UnsortedSegmentMax', { + 'block': P.UnsortedSegmentMax(), + 'desc_const': [4], + 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))], + 'desc_bprop': [[4, 2, 1, 3]]}), ('UnsortedSegmentProd', { 'block': P.UnsortedSegmentProd(), 'desc_const': [4],