Browse Source

Add TransShape Operator.

tags/v0.6.0-beta
leilei_snow leilei_snow 5 years ago
parent
commit
5cccfbc61b
7 changed files with 53 additions and 1 deletions
  1. +2
    -0
      mindspore/ccsrc/transform/convert.cc
  2. +6
    -0
      mindspore/ccsrc/transform/op_declare.cc
  3. +3
    -0
      mindspore/ccsrc/transform/op_declare.h
  4. +10
    -0
      mindspore/ops/_grad/grad_array_ops.py
  5. +1
    -1
      mindspore/ops/operations/__init__.py
  6. +25
    -0
      mindspore/ops/operations/array_ops.py
  7. +6
    -0
      tests/ut/python/ops/test_ops.py

+ 2
- 0
mindspore/ccsrc/transform/convert.cc View File

@@ -134,6 +134,7 @@ const char kNameAssignSub[] = "AssignSub";
const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus"; const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus"; const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
const char kNameReshape[] = "Reshape"; const char kNameReshape[] = "Reshape";
const char kNameTransShape[] = "TransShape";
const char kNameRealDiv[] = "RealDiv"; const char kNameRealDiv[] = "RealDiv";
const char kNameTile[] = "Tile"; const char kNameTile[] = "Tile";
const char kNameCos[] = "Cos"; const char kNameCos[] = "Cos";
@@ -242,6 +243,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBatchNorm), ADPT_DESC(BatchNorm)}, {string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
{string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)}, {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
{string(kNameReshape), ADPT_DESC(Reshape)}, {string(kNameReshape), ADPT_DESC(Reshape)},
{string(kNameTransShape), ADPT_DESC(TransShape)},
{string(kNameFlattenGrad), ADPT_DESC(Reshape)}, {string(kNameFlattenGrad), ADPT_DESC(Reshape)},
{prim::kPrimFlatten->name(), ADPT_DESC(Flatten)}, {prim::kPrimFlatten->name(), ADPT_DESC(Flatten)},
{string(kNameAddN), ADPT_DESC(AddN)}, {string(kNameAddN), ADPT_DESC(AddN)},


+ 6
- 0
mindspore/ccsrc/transform/op_declare.cc View File

@@ -442,6 +442,12 @@ INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
ATTR_MAP(Reshape) = EMPTY_ATTR_MAP; ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}};


// TransShape
INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}};
INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(TransShape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}};

// BiasAdd // BiasAdd
INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}}; INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}};
ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};


+ 3
- 0
mindspore/ccsrc/transform/op_declare.h View File

@@ -112,6 +112,9 @@ DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD) DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
DECLARE_OP_ADAPTER(Reshape) DECLARE_OP_ADAPTER(Reshape)
DECLARE_OP_USE_OUTPUT(Reshape) DECLARE_OP_USE_OUTPUT(Reshape)
DECLARE_OP_ADAPTER(TransShape)
DECLARE_OP_USE_INPUT_ATTR(TransShape)
DECLARE_OP_USE_OUTPUT(TransShape)
DECLARE_OP_ADAPTER(Iou) DECLARE_OP_ADAPTER(Iou)
DECLARE_OP_USE_OUTPUT(Iou) DECLARE_OP_USE_OUTPUT(Iou)
DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D) DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)


+ 10
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -696,3 +696,13 @@ def get_bprop_reverse_sequence(self):
dx = reverse_sequence_grad(dout, seq_lengths) dx = reverse_sequence_grad(dout, seq_lengths)
return dx, zeros_like(seq_lengths) return dx, zeros_like(seq_lengths)
return bprop return bprop


@bprop_getters.register(P.TransShape)
def get_bprop_trans_shape(self):
"""Generate bprop for TransShape"""
op = P.TransShape()
def bprop(x, shape, out, dout):
dx = op(dout, shape_op(x))
return (dx, zeros_like(shape))
return bprop

+ 1
- 1
mindspore/ops/operations/__init__.py View File

@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Shape, Size, Slice, Split, TransShape,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,


+ 25
- 0
mindspore/ops/operations/array_ops.py View File

@@ -3106,3 +3106,28 @@ class ReverseSequence(PrimitiveWithInfer):
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name) validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name) validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
return x return x


class TransShape(PrimitiveWithInfer):
"""
Transform the shape of input tensor to target shape.

Inputs:
- **input_x** (Tensor) - A input tensor.
- **out_shape** (tuple[int]) - The shape of output data.

Outputs:
Tensor, a tensor whose data type is same as 'input_x', and the shape is same as the `out_shape`.
"""
@prim_attr_register
def __init__(self):
self.__setattr_flag__ = True

def __infer__(self, x, shape):
shp = shape['value']
dtype = x['dtype']
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('out_shape', tuple(shp))
return {'shape': shp,
'dtype': dtype,
'value': None}

+ 6
- 0
tests/ut/python/ops/test_ops.py View File

@@ -1865,6 +1865,12 @@ test_case_array_ops = [
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
'skip': ['backward'], 'skip': ['backward'],
}), }),
('TransShape', {
'block': P.TransShape(),
'desc_const': [(1, 12, 24, 24)],
'desc_inputs': [[1, 3, 24, 24]],
'desc_bprop': [[1, 12, 24, 24]],
}),
] ]


test_case_other_ops = [ test_case_other_ops = [


Loading…
Cancel
Save