
!8669 Split op extending PrimitiveWithCheck

From: @tom__chen
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot (Gitee) committed 5 years ago · parent commit 3e2eb70fc0
7 changed files with 134 additions and 27 deletions
  1. mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h (+0, -4)
  2. mindspore/core/abstract/infer_functions.h (+2, -0)
  3. mindspore/core/abstract/prim_arrays.cc (+43, -0)
  4. mindspore/core/abstract/primitive_infer_map.cc (+1, -0)
  5. mindspore/core/base/core_ops.h (+1, -0)
  6. mindspore/ops/operations/array_ops.py (+5, -23)
  7. tests/st/ops/gpu/test_split.py (+82, -0)

mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h (+0, -4)

@@ -132,10 +132,6 @@ class SplitGpuFwdKernel : public GpuKernel {
       MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_];
       return false;
     }
-    if (input_shape[axis_] % output_num_ != 0) {
-      MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_];
-      return false;
-    }
     if (output_num_ != output_num) {
       MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_;
       return false;
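The divisibility check deleted here is not lost: it moves into the new C++ InferImplSplit in prim_arrays.cc below, so an indivisible split now fails during shape inference rather than at kernel launch. A minimal sketch of how that surfaces from Python, assuming a GPU build as in the tests below (the error comes from the infer pass, not this kernel):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

x = Tensor(np.zeros((2, 3)).astype(np.float32))
split = P.Split(axis=1, output_num=2)  # 3 % 2 != 0
split(x)  # expected to raise: dimension 3 is not divisible by output_num 2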


mindspore/core/abstract/infer_functions.h (+2, -0)

@@ -261,6 +261,8 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con
                                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list);
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple or list or dict.


mindspore/core/abstract/prim_arrays.cc (+43, -0)

@@ -784,5 +784,48 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
   auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
   return ret;
 }
+
+AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  ShapeVector x_shape = input_x->shape()->shape();
+  ShapeVector x_shape_min = input_x->shape()->min_shape();
+  if (x_shape_min.empty()) {
+    x_shape_min = x_shape;
+  }
+  ShapeVector x_shape_max = input_x->shape()->max_shape();
+  if (x_shape_max.empty()) {
+    x_shape_max = x_shape;
+  }
+  int64_t rank = SizeToLong(x_shape.size());
+
+  ValuePtr axis = primitive->GetAttr("axis");
+  int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank);
+  axis_value = GetPositiveAxis(axis_value, LongToSize(rank));
+  int64_t output_num_value = primitive->GetAttr("output_num")->cast<Int64ImmPtr>()->value();
+  if ((x_shape[axis_value] != Shape::SHP_ANY) && (x_shape[axis_value] % output_num_value != 0)) {
+    MS_LOG(EXCEPTION) << "x_shape[" << axis_value << "] = " << x_shape[axis_value]
+                      << " must be divisible by output_num = " << output_num_value;
+  }
+
+  ShapeVector output_shape = x_shape;
+  if (output_shape[axis_value] != Shape::SHP_ANY) {
+    output_shape[axis_value] = static_cast<int>(x_shape[axis_value] / output_num_value);
+  }
+  ShapeVector output_shape_min = x_shape_min;
+  output_shape_min[axis_value] = static_cast<int>(x_shape_min[axis_value] / output_num_value);
+  ShapeVector output_shape_max = x_shape_max;
+  output_shape_max[axis_value] = static_cast<int>(x_shape_max[axis_value] / output_num_value);
+
+  AbstractBasePtrList output_list;
+  for (int64_t i = 0; i < output_num_value; ++i) {
+    auto output = input_x->Broaden();
+    output->set_shape(std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max));
+    output_list.push_back(output);
+  }
+  return std::make_shared<AbstractTuple>(output_list);
+}
 }  // namespace abstract
 }  // namespace mindspore
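For readers who don't want to trace the C++, here is a rough Python sketch of the shape arithmetic InferImplSplit performs above, with SHP_ANY standing in for an unknown (dynamic) dimension. The helper name infer_split_shape and the simplified axis normalization are illustrative only, and min/max shape handling is omitted for brevity:

SHP_ANY = -1  # stands in for mindspore::abstract::Shape::SHP_ANY (unknown dim)

def infer_split_shape(x_shape, axis, output_num):
    """Hypothetical Python mirror of InferImplSplit's shape computation."""
    rank = len(x_shape)
    # Normalize a negative axis, as CheckAxis/GetPositiveAxis do in the C++.
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    axis = axis % rank
    dim = x_shape[axis]
    # A static dimension must divide evenly; SHP_ANY is only resolvable later.
    if dim != SHP_ANY and dim % output_num != 0:
        raise ValueError(f"x_shape[{axis}] = {dim} must be divisible by output_num = {output_num}")
    out = list(x_shape)
    if dim != SHP_ANY:
        out[axis] = dim // output_num
    # Every one of the output_num outputs shares the same inferred shape.
    return [tuple(out)] * output_num

# (3, 2, 3) split into 3 along axis 0 -> three outputs of shape (1, 2, 3)
print(infer_split_shape((3, 2, 3), axis=0, output_num=3))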

mindspore/core/abstract/primitive_infer_map.cc (+1, -0)

@@ -70,6 +70,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
     {prim::kPrimTranspose, {InferImplTranspose, true}},
     {prim::kPrimReshape, {InferImplReshape, true}},
+    {prim::kPrimSplit, {InferImplSplit, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},


mindspore/core/base/core_ops.h (+1, -0)

@@ -118,6 +118,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared<Primitive>("Dynam
 inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad");
 inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd");
 inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate");
+inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");

 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");


mindspore/ops/operations/array_ops.py (+5, -23)

@@ -876,18 +876,17 @@ class UniqueWithPad(PrimitiveWithInfer):
         return out


-class Split(PrimitiveWithInfer):
+class Split(PrimitiveWithCheck):
     """
     Splits the input tensor into output_num of tensors along the given axis and output numbers.

     Args:
         axis (int): Index of the split position. Default: 0.
-        output_num (int): The number of output tensors. Default: 1.
+        output_num (int): The number of output tensors. Must be positive int. Default: 1.

     Raises:
         ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)),
-            or if the `output_num` is less than or equal to 0, or if the
-            dimension which to split cannot be evenly divided by `output_num`.
+            or if the `output_num` is less than or equal to 0.

     Inputs:
         - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -916,32 +915,15 @@ class Split(PrimitiveWithInfer):
         """Initialize Split"""
         validator.check_value_type("axis", axis, [int], self.name)
         validator.check_value_type("output_num", output_num, [int], self.name)
+        validator.check_positive_int(output_num, "output_num", self.name)
         self.axis = axis
         self.output_num = output_num

-    def __infer__(self, x):
+    def __check__(self, x):
         validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         x_shape = list(x['shape'])
         dim = len(x_shape)
         validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
-        validator.check_positive_int(self.output_num, "output_num", self.name)
-        output_valid_check = x_shape[self.axis] % self.output_num
-        if output_valid_check != 0:
-            raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
-                             f" output_num {self.output_num}")
-
-        x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
-        out_shapes = []
-        out_dtypes = []
-        for _ in range(self.output_num):
-            out_shapes.append(tuple(x_shape))
-            out_dtypes.append(x['dtype'])
-        out_shapes = tuple(out_shapes)
-        out_dtypes = tuple(out_dtypes)
-        out = {'shape': out_shapes,
-               'dtype': out_dtypes,
-               'value': None}
-        return out


 class Rank(PrimitiveWithInfer):
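The public behaviour of the op is unchanged by this refactor: output shapes and dtypes are now produced by the C++ infer function instead of __infer__, while __check__ only validates inputs. A quick usage sketch, assuming a GPU build like the tests below:

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

x = Tensor(np.arange(12).reshape(2, 6).astype(np.float32))
split = P.Split(axis=1, output_num=3)  # 6 % 3 == 0, so the split is valid
outputs = split(x)  # a tuple of 3 tensors, each of shape (2, 2)
for out in outputs:
    print(out.shape)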


tests/st/ops/gpu/test_split.py (+82, -0)

@@ -18,6 +18,7 @@ import pytest
 import mindspore.context as context
 from mindspore import Tensor
 import mindspore.nn as nn
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops import operations as P



@@ -30,6 +31,18 @@ class Net(nn.Cell):
         return self.split(x)


+class NetDynamic(nn.Cell):
+    def __init__(self, axis=0, out_nums=1):
+        super(NetDynamic, self).__init__()
+        self.conv = inner.GpuConvertToDynamicShape()
+        self.split = P.Split(axis, out_nums)
+
+    def construct(self, x):
+        x_conv = self.conv(x)
+        x_split = self.split(x_conv)
+        return x_split
+
+
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


@@ -47,6 +60,9 @@ def test_split():
         assert (out.asnumpy() == x[i]).all()


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
 def test_split_4d():
     x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
     y = np.split(x_np, 3, axis=1)
@@ -56,3 +72,69 @@ def test_split_4d():

     for i, out in enumerate(outputs):
         assert (out.asnumpy() == y[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_dynamic():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
+
+    net = NetDynamic(0, 3)
+    x_split = net(Tensor(x))
+    for i, out in enumerate(x_split):
+        assert (out.asnumpy() == x[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_dynamic_axis1():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
+    y = np.split(x, 2, axis=1)
+
+    net = NetDynamic(1, 2)
+    x_split = net(Tensor(x))
+    for i, out in enumerate(x_split):
+        assert (out.asnumpy() == y[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_dynamic_axis2():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
+    y = np.split(x, 3, axis=2)
+
+    net = NetDynamic(2, 3)
+    x_split = net(Tensor(x))
+    for i, out in enumerate(x_split):
+        assert (out.asnumpy() == y[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_invalid_input():
+    with pytest.raises(TypeError):
+        _ = Net(0.1, 3)
+
+    with pytest.raises(TypeError):
+        _ = Net(0, 3.0)
+
+    with pytest.raises(ValueError):
+        _ = Net(0, -3)
+
+    x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
+    split_net = Net(2, 2)
+    with pytest.raises(ValueError):
+        _ = split_net(Tensor(x))
+
+    with pytest.raises(TypeError):
+        _ = split_net(x)
