| @@ -138,6 +138,7 @@ const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); | |||||
| const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | ||||
| const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | ||||
| const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | ||||
| const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | |||||
| const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | ||||
| const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | ||||
| const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | ||||
| @@ -143,6 +143,7 @@ extern const PrimitivePtr kPrimSize; | |||||
| extern const PrimitivePtr kPrimArgMax; | extern const PrimitivePtr kPrimArgMax; | ||||
| extern const PrimitivePtr kPrimPack; | extern const PrimitivePtr kPrimPack; | ||||
| extern const PrimitivePtr kPrimUnpack; | extern const PrimitivePtr kPrimUnpack; | ||||
| extern const PrimitivePtr kPrimUnsortedSegmentMin; | |||||
| extern const PrimitivePtr kPrimUnsortedSegmentSum; | extern const PrimitivePtr kPrimUnsortedSegmentSum; | ||||
| extern const PrimitivePtr kPrimConcatOffset; | extern const PrimitivePtr kPrimConcatOffset; | ||||
| extern const PrimitivePtr kPrimReshape; | extern const PrimitivePtr kPrimReshape; | ||||
| @@ -340,6 +340,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, | {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, | ||||
| {prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)}, | {prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)}, | ||||
| {string(kNameStridedSlice), ADPT_DESC(StridedSlice)}, | {string(kNameStridedSlice), ADPT_DESC(StridedSlice)}, | ||||
| {prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMinD)}, | |||||
| {prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)}, | {prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)}, | ||||
| {string(kNameExpandDims), ADPT_DESC(ExpandDims)}, | {string(kNameExpandDims), ADPT_DESC(ExpandDims)}, | ||||
| {prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)}, | {prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)}, | ||||
| @@ -1048,6 +1048,12 @@ INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int | |||||
| ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP; | ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP; | ||||
| OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}}; | ||||
| // UnsortedSegmentMin | |||||
| INPUT_MAP(UnsortedSegmentMinD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}}; | |||||
| INPUT_ATTR_MAP(UnsortedSegmentMinD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}}; | |||||
| ATTR_MAP(UnsortedSegmentMinD) = EMPTY_ATTR_MAP; | |||||
| OUTPUT_MAP(UnsortedSegmentMinD) = {{0, OUTPUT_DESC(y)}}; | |||||
| // ExpandDims | // ExpandDims | ||||
| INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}}; | INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}}; | ||||
| ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP; | ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP; | ||||
| @@ -281,6 +281,9 @@ DECLARE_OP_USE_OUTPUT(StridedSlice) | |||||
| DECLARE_OP_ADAPTER(UnsortedSegmentSumD) | DECLARE_OP_ADAPTER(UnsortedSegmentSumD) | ||||
| DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD) | DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD) | ||||
| DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD) | DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD) | ||||
| DECLARE_OP_ADAPTER(UnsortedSegmentMinD) | |||||
| DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentMinD) | |||||
| DECLARE_OP_USE_OUTPUT(UnsortedSegmentMinD) | |||||
| DECLARE_OP_ADAPTER(ExpandDims) | DECLARE_OP_ADAPTER(ExpandDims) | ||||
| DECLARE_OP_USE_OUTPUT(ExpandDims) | DECLARE_OP_USE_OUTPUT(ExpandDims) | ||||
| DECLARE_OP_ADAPTER(Squeeze) | DECLARE_OP_ADAPTER(Squeeze) | ||||
| @@ -22,6 +22,7 @@ from .. import functional as F | |||||
| from .grad_base import bprop_getters | from .grad_base import bprop_getters | ||||
| from ..primitive import constexpr | from ..primitive import constexpr | ||||
| from ... import context | from ... import context | ||||
| from ...common import dtype as mstype | |||||
| reduce_sum = P.ReduceSum() | reduce_sum = P.ReduceSum() | ||||
| unsorted_segment_sum = P.UnsortedSegmentSum() | unsorted_segment_sum = P.UnsortedSegmentSum() | ||||
| @@ -29,6 +30,7 @@ transpose = P.Transpose() | |||||
| shape_op = P.Shape() | shape_op = P.Shape() | ||||
| reshape = P.Reshape() | reshape = P.Reshape() | ||||
| invert_permutation = P.InvertPermutation() | invert_permutation = P.InvertPermutation() | ||||
| logical_and = P.LogicalAnd() | |||||
| @bprop_getters.register(P.Fill) | @bprop_getters.register(P.Fill) | ||||
| @@ -456,6 +458,57 @@ def get_bprop_diag_part(self): | |||||
| return bprop | return bprop | ||||
| def _GatherDropNegatives(params, | |||||
| ids, | |||||
| zero_clipped_indices=None, | |||||
| is_positive=None): | |||||
| """Helper function for unsorted segment ops.""" | |||||
| maximum = P.Maximum() | |||||
| gather = P.GatherV2() | |||||
| greater_equal = P.GreaterEqual() | |||||
| rank = P.Rank() | |||||
| fill = P.Fill() | |||||
| select = P.Select() | |||||
| if zero_clipped_indices is None: | |||||
| zero_clipped_indices = maximum(ids, zeros_like(ids)) | |||||
| gathered = gather(params, zero_clipped_indices, 0) | |||||
| if is_positive is None: | |||||
| is_positive = greater_equal(ids, 0) | |||||
| is_positive_shape = shape_op(is_positive) | |||||
| broadcastable_shape = is_positive_shape | |||||
| for _ in range(rank(gathered) - rank(is_positive)): | |||||
| broadcastable_shape += (1,) | |||||
| is_positive = reshape(is_positive, broadcastable_shape) | |||||
| gathered_shape = shape_op(gathered) | |||||
| is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1)) | |||||
| zero_slice = zeros_like(gathered) | |||||
| return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive) | |||||
| @bprop_getters.register(P.UnsortedSegmentMin) | |||||
| def get_bprop_unsorted_segment_min(self): | |||||
| """Generate bprop for UnsortedSegmentMin""" | |||||
| equal = P.Equal() | |||||
| cast = P.Cast() | |||||
| divide = P.RealDiv() | |||||
| get_dtype = P.DType() | |||||
| select = P.Select() | |||||
| def bprop(x, segment_ids, num_segments, out, dout): | |||||
| gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids) | |||||
| is_selected = equal(x, gathered_outputs) | |||||
| is_selected = logical_and(is_selected, is_positive) | |||||
| num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)), | |||||
| segment_ids, num_segments) | |||||
| weighted_grads = divide(dout, num_selected) | |||||
| gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None, | |||||
| zero_clipped_indices, is_positive) | |||||
| zeros = zeros_like(gathered_grads) | |||||
| return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments) | |||||
| return bprop | |||||
| @bprop_getters.register(P.SpaceToBatch) | @bprop_getters.register(P.SpaceToBatch) | ||||
| def get_bprop_space_to_batch(self): | def get_bprop_space_to_batch(self): | ||||
| """Generate bprop for SpaceToBatch""" | """Generate bprop for SpaceToBatch""" | ||||
| @@ -28,7 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ||||
| Shape, Size, Slice, Split, | Shape, Size, Slice, Split, | ||||
| Squeeze, StridedSlice, Tile, | Squeeze, StridedSlice, Tile, | ||||
| Transpose, TruncatedNormal, TupleToArray, | |||||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, | |||||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace) | UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace) | ||||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | ||||
| _MirrorOperator, ReduceOp, _VirtualDataset, | _MirrorOperator, ReduceOp, _VirtualDataset, | ||||
| @@ -96,6 +96,7 @@ __all__ = [ | |||||
| 'MaxPool', | 'MaxPool', | ||||
| 'TopK', | 'TopK', | ||||
| 'Adam', | 'Adam', | ||||
| 'Softplus', | |||||
| 'Softmax', | 'Softmax', | ||||
| 'LogSoftmax', | 'LogSoftmax', | ||||
| 'SoftmaxCrossEntropyWithLogits', | 'SoftmaxCrossEntropyWithLogits', | ||||
| @@ -210,6 +211,7 @@ __all__ = [ | |||||
| 'Size', | 'Size', | ||||
| 'DepthwiseConv2dNative', | 'DepthwiseConv2dNative', | ||||
| 'UnsortedSegmentSum', | 'UnsortedSegmentSum', | ||||
| 'UnsortedSegmentMin', | |||||
| "AllGather", | "AllGather", | ||||
| "AllReduce", | "AllReduce", | ||||
| "ReduceScatter", | "ReduceScatter", | ||||
| @@ -1253,6 +1253,54 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||||
| return out | return out | ||||
| class UnsortedSegmentMin(PrimitiveWithInfer): | |||||
| """ | |||||
| Computes the minimum along segments of a tensor. | |||||
| If the given segment_ids is negative, the value will be ignored. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. | |||||
| - **segment_ids** (Tensor) - A `1-D` tensor whose shape is a prefix of `x_shape`. | |||||
| - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`. | |||||
| Outputs: | |||||
| Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`. | |||||
| Examples: | |||||
| >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)) | |||||
| >>> segment_ids = Tensor(np.array([0, 1, 1]).np.int32) | |||||
| >>> num_segments = 2 | |||||
| >>> unsorted_segment_min = P.UnsortedSegmentMin() | |||||
| >>> unsorted_segment_min(input_x, segment_ids, num_segments) | |||||
| [[1., 2., 3.], [4., 2., 1.]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init UnsortedSegmentMin""" | |||||
| self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) | |||||
| def __infer__(self, x, segment_ids, num_segments): | |||||
| x_type = x['dtype'] | |||||
| x_shape = x['shape'] | |||||
| segment_ids_shape = segment_ids['shape'] | |||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||||
| validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) | |||||
| num_segments_v = num_segments['value'] | |||||
| validator.check_value_type('num_segments', num_segments_v, [int], self.name) | |||||
| validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) | |||||
| segment_ids_shape_len = len(segment_ids_shape) | |||||
| out_shape = [num_segments_v] | |||||
| out_shape += x_shape[segment_ids_shape_len:] | |||||
| out = {'shape': out_shape, | |||||
| 'dtype': x_type, | |||||
| 'value': None} | |||||
| return out | |||||
| class Concat(PrimitiveWithInfer): | class Concat(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Concat tensor in specified axis. | Concat tensor in specified axis. | ||||
| @@ -773,6 +773,11 @@ test_case_nn_ops = [ | |||||
| 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))], | 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))], | ||||
| 'desc_bprop': [[4, 1, 3]], | 'desc_bprop': [[4, 1, 3]], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('UnsortedSegmentMin', { | |||||
| 'block': P.UnsortedSegmentMin(), | |||||
| 'desc_const': [4], | |||||
| 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))], | |||||
| 'desc_bprop': [[4, 2, 1, 3]]}), | |||||
| ('DropoutGenMask', { | ('DropoutGenMask', { | ||||
| 'block': P.DropoutGenMask(), | 'block': P.DropoutGenMask(), | ||||
| 'desc_const': [(2, 2), Tensor(0.5, mstype.float32)], | 'desc_const': [(2, 2), Tensor(0.5, mstype.float32)], | ||||