| @@ -138,6 +138,7 @@ const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size"); | |||
| const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax"); | |||
| const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack"); | |||
| const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum"); | |||
| const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | |||
| const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | |||
| const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | |||
| const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | |||
| @@ -143,6 +143,7 @@ extern const PrimitivePtr kPrimSize; | |||
| extern const PrimitivePtr kPrimArgMax; | |||
| extern const PrimitivePtr kPrimPack; | |||
| extern const PrimitivePtr kPrimUnpack; | |||
| extern const PrimitivePtr kPrimUnsortedSegmentMin; | |||
| extern const PrimitivePtr kPrimUnsortedSegmentSum; | |||
| extern const PrimitivePtr kPrimConcatOffset; | |||
| extern const PrimitivePtr kPrimReshape; | |||
| @@ -340,6 +340,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||
| {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, | |||
| {prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)}, | |||
| {string(kNameStridedSlice), ADPT_DESC(StridedSlice)}, | |||
| {prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMinD)}, | |||
| {prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)}, | |||
| {string(kNameExpandDims), ADPT_DESC(ExpandDims)}, | |||
| {prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)}, | |||
| @@ -1048,6 +1048,12 @@ INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int | |||
| ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}}; | |||
| // UnsortedSegmentMin | |||
| INPUT_MAP(UnsortedSegmentMinD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}}; | |||
| INPUT_ATTR_MAP(UnsortedSegmentMinD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}}; | |||
| ATTR_MAP(UnsortedSegmentMinD) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(UnsortedSegmentMinD) = {{0, OUTPUT_DESC(y)}}; | |||
| // ExpandDims | |||
| INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}}; | |||
| ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP; | |||
| @@ -281,6 +281,9 @@ DECLARE_OP_USE_OUTPUT(StridedSlice) | |||
| DECLARE_OP_ADAPTER(UnsortedSegmentSumD) | |||
| DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD) | |||
| DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD) | |||
| DECLARE_OP_ADAPTER(UnsortedSegmentMinD) | |||
| DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentMinD) | |||
| DECLARE_OP_USE_OUTPUT(UnsortedSegmentMinD) | |||
| DECLARE_OP_ADAPTER(ExpandDims) | |||
| DECLARE_OP_USE_OUTPUT(ExpandDims) | |||
| DECLARE_OP_ADAPTER(Squeeze) | |||
| @@ -22,6 +22,7 @@ from .. import functional as F | |||
| from .grad_base import bprop_getters | |||
| from ..primitive import constexpr | |||
| from ... import context | |||
| from ...common import dtype as mstype | |||
| reduce_sum = P.ReduceSum() | |||
| unsorted_segment_sum = P.UnsortedSegmentSum() | |||
| @@ -29,6 +30,7 @@ transpose = P.Transpose() | |||
| shape_op = P.Shape() | |||
| reshape = P.Reshape() | |||
| invert_permutation = P.InvertPermutation() | |||
| logical_and = P.LogicalAnd() | |||
| @bprop_getters.register(P.Fill) | |||
| @@ -456,6 +458,57 @@ def get_bprop_diag_part(self): | |||
| return bprop | |||
| def _GatherDropNegatives(params, | |||
| ids, | |||
| zero_clipped_indices=None, | |||
| is_positive=None): | |||
| """Helper function for unsorted segment ops.""" | |||
| maximum = P.Maximum() | |||
| gather = P.GatherV2() | |||
| greater_equal = P.GreaterEqual() | |||
| rank = P.Rank() | |||
| fill = P.Fill() | |||
| select = P.Select() | |||
| if zero_clipped_indices is None: | |||
| zero_clipped_indices = maximum(ids, zeros_like(ids)) | |||
| gathered = gather(params, zero_clipped_indices, 0) | |||
| if is_positive is None: | |||
| is_positive = greater_equal(ids, 0) | |||
| is_positive_shape = shape_op(is_positive) | |||
| broadcastable_shape = is_positive_shape | |||
| for _ in range(rank(gathered) - rank(is_positive)): | |||
| broadcastable_shape += (1,) | |||
| is_positive = reshape(is_positive, broadcastable_shape) | |||
| gathered_shape = shape_op(gathered) | |||
| is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1)) | |||
| zero_slice = zeros_like(gathered) | |||
| return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive) | |||
| @bprop_getters.register(P.UnsortedSegmentMin) | |||
| def get_bprop_unsorted_segment_min(self): | |||
| """Generate bprop for UnsortedSegmentMin""" | |||
| equal = P.Equal() | |||
| cast = P.Cast() | |||
| divide = P.RealDiv() | |||
| get_dtype = P.DType() | |||
| select = P.Select() | |||
| def bprop(x, segment_ids, num_segments, out, dout): | |||
| gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids) | |||
| is_selected = equal(x, gathered_outputs) | |||
| is_selected = logical_and(is_selected, is_positive) | |||
| num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)), | |||
| segment_ids, num_segments) | |||
| weighted_grads = divide(dout, num_selected) | |||
| gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None, | |||
| zero_clipped_indices, is_positive) | |||
| zeros = zeros_like(gathered_grads) | |||
| return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments) | |||
| return bprop | |||
| @bprop_getters.register(P.SpaceToBatch) | |||
| def get_bprop_space_to_batch(self): | |||
| """Generate bprop for SpaceToBatch""" | |||
| @@ -28,7 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | |||
| Shape, Size, Slice, Split, | |||
| Squeeze, StridedSlice, Tile, | |||
| Transpose, TruncatedNormal, TupleToArray, | |||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, | |||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| @@ -96,6 +96,7 @@ __all__ = [ | |||
| 'MaxPool', | |||
| 'TopK', | |||
| 'Adam', | |||
| 'Softplus', | |||
| 'Softmax', | |||
| 'LogSoftmax', | |||
| 'SoftmaxCrossEntropyWithLogits', | |||
| @@ -210,6 +211,7 @@ __all__ = [ | |||
| 'Size', | |||
| 'DepthwiseConv2dNative', | |||
| 'UnsortedSegmentSum', | |||
| 'UnsortedSegmentMin', | |||
| "AllGather", | |||
| "AllReduce", | |||
| "ReduceScatter", | |||
| @@ -1253,6 +1253,54 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||
| return out | |||
| class UnsortedSegmentMin(PrimitiveWithInfer): | |||
| """ | |||
| Computes the minimum along segments of a tensor. | |||
| If the given segment_ids is negative, the value will be ignored. | |||
| Inputs: | |||
| - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. | |||
| - **segment_ids** (Tensor) - A `1-D` tensor whose shape is a prefix of `x_shape`. | |||
| - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`. | |||
| Outputs: | |||
| Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`. | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)) | |||
| >>> segment_ids = Tensor(np.array([0, 1, 1]).np.int32) | |||
| >>> num_segments = 2 | |||
| >>> unsorted_segment_min = P.UnsortedSegmentMin() | |||
| >>> unsorted_segment_min(input_x, segment_ids, num_segments) | |||
| [[1., 2., 3.], [4., 2., 1.]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init UnsortedSegmentMin""" | |||
| self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) | |||
| def __infer__(self, x, segment_ids, num_segments): | |||
| x_type = x['dtype'] | |||
| x_shape = x['shape'] | |||
| segment_ids_shape = segment_ids['shape'] | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||
| validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) | |||
| num_segments_v = num_segments['value'] | |||
| validator.check_value_type('num_segments', num_segments_v, [int], self.name) | |||
| validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) | |||
| segment_ids_shape_len = len(segment_ids_shape) | |||
| out_shape = [num_segments_v] | |||
| out_shape += x_shape[segment_ids_shape_len:] | |||
| out = {'shape': out_shape, | |||
| 'dtype': x_type, | |||
| 'value': None} | |||
| return out | |||
| class Concat(PrimitiveWithInfer): | |||
| r""" | |||
| Concat tensor in specified axis. | |||
| @@ -773,6 +773,11 @@ test_case_nn_ops = [ | |||
| 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([[0, 1], [0, 1], [0, 1]]).astype(np.int32))], | |||
| 'desc_bprop': [[4, 1, 3]], | |||
| 'skip': ['backward']}), | |||
| ('UnsortedSegmentMin', { | |||
| 'block': P.UnsortedSegmentMin(), | |||
| 'desc_const': [4], | |||
| 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))], | |||
| 'desc_bprop': [[4, 2, 1, 3]]}), | |||
| ('DropoutGenMask', { | |||
| 'block': P.DropoutGenMask(), | |||
| 'desc_const': [(2, 2), Tensor(0.5, mstype.float32)], | |||