From 2b0ecfd2b139024303bb1f74502bfa738154fe61 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Mon, 6 Jul 2020 11:49:17 +0800 Subject: [PATCH] Add TBE op UnsortedSegmentProd for VM. --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 1 + mindspore/ops/_grad/grad_array_ops.py | 30 +++++++++++ mindspore/ops/_op_impl/tbe/__init__.py | 1 + .../ops/_op_impl/tbe/unsorted_segment_prod.py | 48 +++++++++++++++++ mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/array_ops.py | 52 +++++++++++++++++++ tests/ut/python/ops/test_ops.py | 5 ++ 7 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index c38f48763e..052b7eb2df 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -84,6 +84,7 @@ static std::map tbe_func_adapter_map = { {"transpose", "transpose_d"}, {"fill", "fill_d"}, {"unsorted_segment_sum", "unsorted_segment_sum_d"}, + {"unsorted_segment_prod", "unsorted_segment_prod_d"}, {"concat", "concat_d"}, {"slice", "slice_d"}, {"reduce_sum", "reduce_sum_d"}, diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index e216a4f0d0..6a89ac9309 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -625,6 +625,36 @@ def get_bprop_unsorted_segment_min(self): return bprop +@bprop_getters.register(P.UnsortedSegmentProd) +def get_bprop_unsorted_segment_prod(self): + """Generate bprop for UnsortedSegmentProd""" + equal = P.Equal() + cast = P.Cast() + select = P.Select() + gather = P.GatherV2() + greater = P.Greater() + ones_like = P.OnesLike() + maximum = P.Maximum() + unsorted_segment_prod = P.UnsortedSegmentProd() + + def bprop(x, segment_ids, num_segments, out, dout): + is_zero = equal(x, 0) + num_zero = unsorted_segment_sum(cast(is_zero, mstype.int32), segment_ids, num_segments) + grad = select(greater(num_zero, 1), zeros_like(dout), dout) + non_zero_data = select(is_zero, ones_like(x), x) + non_zero_prod = unsorted_segment_prod(non_zero_data, segment_ids, num_segments) + zero_clipped_indices = maximum(segment_ids, zeros_like(segment_ids)) + gathered_prod = gather(out, zero_clipped_indices, 0) + gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0) + prod_divided_by_x = gathered_prod / x + partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x) + gathered_grad, _, _ = _GatherDropNegatives(grad, segment_ids, zero_clipped_indices) + dx = gathered_grad * partial_derivative + return dx, zeros_like(segment_ids), zeros_like(num_segments) + + return bprop + + @bprop_getters.register(P.SpaceToBatch) def get_bprop_space_to_batch(self): """Generate bprop for SpaceToBatch""" diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 76cea197ba..12bf4df9a1 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -133,6 +133,7 @@ from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad from .apply_proximal_adagrad import _apply_proximal_adagrad from .transpose_d import _transpose_d_tbe from .unsorted_segment_sum import _unsorted_segment_sum_tbe +from .unsorted_segment_prod import _unsorted_segment_prod_tbe from .logsoftmax_grad import _logsoftmax_grad_tbe from .logsoftmax import _logsoftmax_tbe from .select import _select_tbe diff --git a/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py b/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py new file mode 100644 index 0000000000..40b04d17c3 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""UnsortedSegmentProdD op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +unsorted_segment_prod_d_op_info = TBERegOp("UnsortedSegmentProd") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("unsorted_segment_prod_d.so") \ + .compute_cost(10) \ + .kernel_name("unsorted_segment_prod_d") \ + .partial_flag(True) \ + .attr("num_segments", "required", "int", "all") \ + .input(0, "data", False, "required", "all") \ + .input(1, "segment_ids", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \ + .dtype_format(DataType.F16_FracZ, DataType.I32_Default, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.I32_Default, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.I32_Default, DataType.F32_5HD) \ + .dtype_format(DataType.F32_FracZ, DataType.I32_Default, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.I32_Default, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_Default, DataType.I32_5HD) \ + .dtype_format(DataType.I32_FracZ, DataType.I32_Default, DataType.I32_FracZ) \ + .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_Default, DataType.I32_C1HWNCoC0) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(unsorted_segment_prod_d_op_info) +def _unsorted_segment_prod_tbe(): + """UnsortedSegmentProdD TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index fe224e8850..21a1ca6505 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, TransShape, Squeeze, StridedSlice, Tile, TensorScatterUpdate, - Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, + Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence) from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, @@ -249,6 +249,7 @@ __all__ = [ 'DepthwiseConv2dNative', 'UnsortedSegmentSum', 'UnsortedSegmentMin', + 'UnsortedSegmentProd', "AllGather", "AllReduce", "ReduceScatter", diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b30a03d604..128ba479a5 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1412,6 +1412,58 @@ class UnsortedSegmentMin(PrimitiveWithInfer): return out +class UnsortedSegmentProd(PrimitiveWithInfer): + """ + Computes the product along segments of a tensor. + + Inputs: + - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. + With float16, float32 or int32 data type. + - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. Data type must be int32. + - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`, + should be greater than 0. + + Outputs: + Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`. + + Examples: + >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)) + >>> segment_ids = Tensor(np.array([0, 1, 0]).astype(np.int32)) + >>> num_segments = 2 + >>> unsorted_segment_prod = P.UnsortedSegmentProd() + >>> unsorted_segment_prod(input_x, segment_ids, num_segments) + [[4., 4., 3.], [4., 5., 6.]] + """ + + @prim_attr_register + def __init__(self): + """init UnsortedSegmentProd""" + self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) + + def __infer__(self, x, segment_ids, num_segments): + x_type = x['dtype'] + x_shape = x['shape'] + segment_ids_shape = segment_ids['shape'] + validator.check_subclass("input_x", x_type, mstype.tensor, self.name) + validator.check_value_type("x_shape", x_shape, [list], self.name) + valid_type = [mstype.float16, mstype.float32, mstype.int32] + validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) + validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) + validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) + validator.check(f'first shape of input_x', x_shape[0], + 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) + num_segments_v = num_segments['value'] + validator.check_value_type('num_segments', num_segments_v, [int], self.name) + validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) + segment_ids_shape_len = len(segment_ids_shape) + out_shape = [num_segments_v] + out_shape += x_shape[segment_ids_shape_len:] + out = {'shape': out_shape, + 'dtype': mstype.tensor_type(x_type.element_type()), + 'value': None} + return out + + class Concat(PrimitiveWithInfer): r""" Concat tensor in specified axis. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index fa79275ce3..c746ca7689 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1318,6 +1318,11 @@ test_case_nn_ops = [ 'desc_const': [4], 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))], 'desc_bprop': [[4, 2, 1, 3]]}), + ('UnsortedSegmentProd', { + 'block': P.UnsortedSegmentProd(), + 'desc_const': [4], + 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([0, 1, 0]).astype(np.int32))], + 'desc_bprop': [[4, 2, 1, 3]]}), ('DropoutGenMask', { 'block': P.DropoutGenMask(), 'desc_const': [(2, 2), Tensor(0.5, mstype.float32)],