Merge pull request !2884 from liuxiao93/UnsortedSegmentProdtags/v0.6.0-beta
| @@ -84,6 +84,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"transpose", "transpose_d"}, | {"transpose", "transpose_d"}, | ||||
| {"fill", "fill_d"}, | {"fill", "fill_d"}, | ||||
| {"unsorted_segment_sum", "unsorted_segment_sum_d"}, | {"unsorted_segment_sum", "unsorted_segment_sum_d"}, | ||||
| {"unsorted_segment_prod", "unsorted_segment_prod_d"}, | |||||
| {"concat", "concat_d"}, | {"concat", "concat_d"}, | ||||
| {"slice", "slice_d"}, | {"slice", "slice_d"}, | ||||
| {"reduce_sum", "reduce_sum_d"}, | {"reduce_sum", "reduce_sum_d"}, | ||||
| @@ -625,6 +625,36 @@ def get_bprop_unsorted_segment_min(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.UnsortedSegmentProd) | |||||
| def get_bprop_unsorted_segment_prod(self): | |||||
| """Generate bprop for UnsortedSegmentProd""" | |||||
| equal = P.Equal() | |||||
| cast = P.Cast() | |||||
| select = P.Select() | |||||
| gather = P.GatherV2() | |||||
| greater = P.Greater() | |||||
| ones_like = P.OnesLike() | |||||
| maximum = P.Maximum() | |||||
| unsorted_segment_prod = P.UnsortedSegmentProd() | |||||
| def bprop(x, segment_ids, num_segments, out, dout): | |||||
| is_zero = equal(x, 0) | |||||
| num_zero = unsorted_segment_sum(cast(is_zero, mstype.int32), segment_ids, num_segments) | |||||
| grad = select(greater(num_zero, 1), zeros_like(dout), dout) | |||||
| non_zero_data = select(is_zero, ones_like(x), x) | |||||
| non_zero_prod = unsorted_segment_prod(non_zero_data, segment_ids, num_segments) | |||||
| zero_clipped_indices = maximum(segment_ids, zeros_like(segment_ids)) | |||||
| gathered_prod = gather(out, zero_clipped_indices, 0) | |||||
| gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0) | |||||
| prod_divided_by_x = gathered_prod / x | |||||
| partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x) | |||||
| gathered_grad, _, _ = _GatherDropNegatives(grad, segment_ids, zero_clipped_indices) | |||||
| dx = gathered_grad * partial_derivative | |||||
| return dx, zeros_like(segment_ids), zeros_like(num_segments) | |||||
| return bprop | |||||
| @bprop_getters.register(P.SpaceToBatch) | @bprop_getters.register(P.SpaceToBatch) | ||||
| def get_bprop_space_to_batch(self): | def get_bprop_space_to_batch(self): | ||||
| """Generate bprop for SpaceToBatch""" | """Generate bprop for SpaceToBatch""" | ||||
| @@ -133,6 +133,7 @@ from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad | |||||
| from .apply_proximal_adagrad import _apply_proximal_adagrad | from .apply_proximal_adagrad import _apply_proximal_adagrad | ||||
| from .transpose_d import _transpose_d_tbe | from .transpose_d import _transpose_d_tbe | ||||
| from .unsorted_segment_sum import _unsorted_segment_sum_tbe | from .unsorted_segment_sum import _unsorted_segment_sum_tbe | ||||
| from .unsorted_segment_prod import _unsorted_segment_prod_tbe | |||||
| from .logsoftmax_grad import _logsoftmax_grad_tbe | from .logsoftmax_grad import _logsoftmax_grad_tbe | ||||
| from .logsoftmax import _logsoftmax_tbe | from .logsoftmax import _logsoftmax_tbe | ||||
| from .select import _select_tbe | from .select import _select_tbe | ||||
| @@ -0,0 +1,48 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """UnsortedSegmentProdD op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| unsorted_segment_prod_d_op_info = TBERegOp("UnsortedSegmentProd") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("unsorted_segment_prod_d.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("unsorted_segment_prod_d") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("num_segments", "required", "int", "all") \ | |||||
| .input(0, "data", False, "required", "all") \ | |||||
| .input(1, "segment_ids", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_FracZ, DataType.I32_Default, DataType.F16_FracZ) \ | |||||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.I32_Default, DataType.F16_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.I32_Default, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.I32_Default, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.I32_Default, DataType.F32_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_5HD, DataType.I32_Default, DataType.I32_5HD) \ | |||||
| .dtype_format(DataType.I32_FracZ, DataType.I32_Default, DataType.I32_FracZ) \ | |||||
| .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_Default, DataType.I32_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(unsorted_segment_prod_d_op_info) | |||||
| def _unsorted_segment_prod_tbe(): | |||||
| """UnsortedSegmentProdD TBE register""" | |||||
| return | |||||
| @@ -30,7 +30,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| Shape, Size, Slice, Split, TransShape, | Shape, Size, Slice, Split, TransShape, | ||||
| ParallelConcat, | ParallelConcat, | ||||
| Squeeze, StridedSlice, Tile, TensorScatterUpdate, | Squeeze, StridedSlice, Tile, TensorScatterUpdate, | ||||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, | |||||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, | |||||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | ||||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence) | SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence) | ||||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | ||||
| @@ -250,6 +250,7 @@ __all__ = [ | |||||
| 'DepthwiseConv2dNative', | 'DepthwiseConv2dNative', | ||||
| 'UnsortedSegmentSum', | 'UnsortedSegmentSum', | ||||
| 'UnsortedSegmentMin', | 'UnsortedSegmentMin', | ||||
| 'UnsortedSegmentProd', | |||||
| "AllGather", | "AllGather", | ||||
| "AllReduce", | "AllReduce", | ||||
| "ReduceScatter", | "ReduceScatter", | ||||
| @@ -1412,6 +1412,58 @@ class UnsortedSegmentMin(PrimitiveWithInfer): | |||||
| return out | return out | ||||
| class UnsortedSegmentProd(PrimitiveWithInfer): | |||||
| """ | |||||
| Computes the product along segments of a tensor. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. | |||||
| With float16, float32 or int32 data type. | |||||
| - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. Data type must be int32. | |||||
| - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`, | |||||
| should be greater than 0. | |||||
| Outputs: | |||||
| Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`. | |||||
| Examples: | |||||
| >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)) | |||||
| >>> segment_ids = Tensor(np.array([0, 1, 0]).astype(np.int32)) | |||||
| >>> num_segments = 2 | |||||
| >>> unsorted_segment_prod = P.UnsortedSegmentProd() | |||||
| >>> unsorted_segment_prod(input_x, segment_ids, num_segments) | |||||
| [[4., 4., 3.], [4., 5., 6.]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init UnsortedSegmentProd""" | |||||
| self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) | |||||
| def __infer__(self, x, segment_ids, num_segments): | |||||
| x_type = x['dtype'] | |||||
| x_shape = x['shape'] | |||||
| segment_ids_shape = segment_ids['shape'] | |||||
| validator.check_subclass("input_x", x_type, mstype.tensor, self.name) | |||||
| validator.check_value_type("x_shape", x_shape, [list], self.name) | |||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||||
| validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) | |||||
| validator.check(f'first shape of input_x', x_shape[0], | |||||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | |||||
| num_segments_v = num_segments['value'] | |||||
| validator.check_value_type('num_segments', num_segments_v, [int], self.name) | |||||
| validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) | |||||
| segment_ids_shape_len = len(segment_ids_shape) | |||||
| out_shape = [num_segments_v] | |||||
| out_shape += x_shape[segment_ids_shape_len:] | |||||
| out = {'shape': out_shape, | |||||
| 'dtype': mstype.tensor_type(x_type.element_type()), | |||||
| 'value': None} | |||||
| return out | |||||
| class Concat(PrimitiveWithInfer): | class Concat(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Concat tensor in specified axis. | Concat tensor in specified axis. | ||||
| @@ -1327,6 +1327,11 @@ test_case_nn_ops = [ | |||||
| 'desc_const': [4], | 'desc_const': [4], | ||||
| 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))], | 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))], | ||||
| 'desc_bprop': [[4, 2, 1, 3]]}), | 'desc_bprop': [[4, 2, 1, 3]]}), | ||||
| ('UnsortedSegmentProd', { | |||||
| 'block': P.UnsortedSegmentProd(), | |||||
| 'desc_const': [4], | |||||
| 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([0, 1, 0]).astype(np.int32))], | |||||
| 'desc_bprop': [[4, 2, 1, 3]]}), | |||||
| ('DropoutGenMask', { | ('DropoutGenMask', { | ||||
| 'block': P.DropoutGenMask(), | 'block': P.DropoutGenMask(), | ||||
| 'desc_const': [(2, 2), Tensor(0.5, mstype.float32)], | 'desc_const': [(2, 2), Tensor(0.5, mstype.float32)], | ||||