From: @tom__chen Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -132,10 +132,6 @@ class SplitGpuFwdKernel : public GpuKernel { | |||
| MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_]; | |||
| return false; | |||
| } | |||
| if (input_shape[axis_] % output_num_ != 0) { | |||
| MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_]; | |||
| return false; | |||
| } | |||
| if (output_num_ != output_num) { | |||
| MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_; | |||
| return false; | |||
| @@ -261,6 +261,8 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list or dict. | |||
| @@ -784,5 +784,48 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt | |||
| auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, shape_min, shape_max)); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| ShapeVector x_shape = input_x->shape()->shape(); | |||
| ShapeVector x_shape_min = input_x->shape()->min_shape(); | |||
| if (x_shape_min.empty()) { | |||
| x_shape_min = x_shape; | |||
| } | |||
| ShapeVector x_shape_max = input_x->shape()->max_shape(); | |||
| if (x_shape_max.empty()) { | |||
| x_shape_max = x_shape; | |||
| } | |||
| int64_t rank = SizeToLong(x_shape.size()); | |||
| ValuePtr axis = primitive->GetAttr("axis"); | |||
| int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank); | |||
| axis_value = GetPositiveAxis(axis_value, LongToSize(rank)); | |||
| int64_t output_num_value = primitive->GetAttr("output_num")->cast<Int64ImmPtr>()->value(); | |||
| if ((x_shape[axis_value] != Shape::SHP_ANY) && (x_shape[axis_value] % output_num_value != 0)) { | |||
| MS_LOG(EXCEPTION) << "x_shape[" << axis_value << "] = " << x_shape[axis_value] | |||
| << " must be divisible by output_num = " << output_num_value; | |||
| } | |||
| ShapeVector output_shape = x_shape; | |||
| if (output_shape[axis_value] != Shape::SHP_ANY) { | |||
| output_shape[axis_value] = static_cast<int>(x_shape[axis_value] / output_num_value); | |||
| } | |||
| ShapeVector output_shape_min = x_shape_min; | |||
| output_shape_min[axis_value] = static_cast<int>(x_shape_min[axis_value] / output_num_value); | |||
| ShapeVector output_shape_max = x_shape_max; | |||
| output_shape_max[axis_value] = static_cast<int>(x_shape_max[axis_value] / output_num_value); | |||
| AbstractBasePtrList output_list; | |||
| for (int64_t i = 0; i < output_num_value; ++i) { | |||
| auto output = input_x->Broaden(); | |||
| output->set_shape(std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max)); | |||
| output_list.push_back(output); | |||
| } | |||
| return std::make_shared<AbstractTuple>(output_list); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -70,6 +70,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, | |||
| {prim::kPrimTranspose, {InferImplTranspose, true}}, | |||
| {prim::kPrimReshape, {InferImplReshape, true}}, | |||
| {prim::kPrimSplit, {InferImplSplit, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, true}}, | |||
| @@ -118,6 +118,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared<Primitive>("Dynam | |||
| inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad"); | |||
| inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd"); | |||
| inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate"); | |||
| inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split"); | |||
| // NN | |||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| @@ -876,18 +876,17 @@ class UniqueWithPad(PrimitiveWithInfer): | |||
| return out | |||
| class Split(PrimitiveWithInfer): | |||
| class Split(PrimitiveWithCheck): | |||
| """ | |||
| Splits the input tensor into output_num of tensors along the given axis and output numbers. | |||
| Args: | |||
| axis (int): Index of the split position. Default: 0. | |||
| output_num (int): The number of output tensors. Default: 1. | |||
| output_num (int): The number of output tensors. Must be postive int. Default: 1. | |||
| Raises: | |||
| ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)), | |||
| or if the `output_num` is less than or equal to 0, or if the | |||
| dimension which to split cannot be evenly divided by `output_num`. | |||
| or if the `output_num` is less than or equal to 0. | |||
| Inputs: | |||
| - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||
| @@ -916,32 +915,15 @@ class Split(PrimitiveWithInfer): | |||
| """Initialize Split""" | |||
| validator.check_value_type("axis", axis, [int], self.name) | |||
| validator.check_value_type("output_num", output_num, [int], self.name) | |||
| validator.check_positive_int(output_num, "output_num", self.name) | |||
| self.axis = axis | |||
| self.output_num = output_num | |||
| def __infer__(self, x): | |||
| def __check__(self, x): | |||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||
| x_shape = list(x['shape']) | |||
| dim = len(x_shape) | |||
| validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) | |||
| validator.check_positive_int(self.output_num, "output_num", self.name) | |||
| output_valid_check = x_shape[self.axis] % self.output_num | |||
| if output_valid_check != 0: | |||
| raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by" | |||
| f" output_num {self.output_num}") | |||
| x_shape[self.axis] = int(x_shape[self.axis] / self.output_num) | |||
| out_shapes = [] | |||
| out_dtypes = [] | |||
| for _ in range(self.output_num): | |||
| out_shapes.append(tuple(x_shape)) | |||
| out_dtypes.append(x['dtype']) | |||
| out_shapes = tuple(out_shapes) | |||
| out_dtypes = tuple(out_dtypes) | |||
| out = {'shape': out_shapes, | |||
| 'dtype': out_dtypes, | |||
| 'value': None} | |||
| return out | |||
| class Rank(PrimitiveWithInfer): | |||
| @@ -18,6 +18,7 @@ import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| import mindspore.nn as nn | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.ops import operations as P | |||
| @@ -30,6 +31,18 @@ class Net(nn.Cell): | |||
| return self.split(x) | |||
| class NetDynamic(nn.Cell): | |||
| def __init__(self, axis=0, out_nums=1): | |||
| super(NetDynamic, self).__init__() | |||
| self.conv = inner.GpuConvertToDynamicShape() | |||
| self.split = P.Split(axis, out_nums) | |||
| def construct(self, x): | |||
| x_conv = self.conv(x) | |||
| x_split = self.split(x_conv) | |||
| return x_split | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| @@ -47,6 +60,9 @@ def test_split(): | |||
| assert (out.asnumpy() == x[i]).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_split_4d(): | |||
| x_np = np.random.randn(2, 6, 4, 4).astype(np.float32) | |||
| y = np.split(x_np, 3, axis=1) | |||
| @@ -56,3 +72,69 @@ def test_split_4d(): | |||
| for i, out in enumerate(outputs): | |||
| assert (out.asnumpy() == y[i]).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_split_dynamic(): | |||
| x = np.array([[[1, -1, 1], [2, -2, 2]], | |||
| [[3, -3, 3], [4, -4, 4]], | |||
| [[5, -5, 5], [6, -6, 6]]]).astype(np.float32) | |||
| net = NetDynamic(0, 3) | |||
| x_split = net(Tensor(x)) | |||
| for i, out in enumerate(x_split): | |||
| assert (out.asnumpy() == x[i]).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_split_dynamic_axis1(): | |||
| x = np.array([[[1, -1, 1], [2, -2, 2]], | |||
| [[3, -3, 3], [4, -4, 4]], | |||
| [[5, -5, 5], [6, -6, 6]]]).astype(np.int32) | |||
| y = np.split(x, 2, axis=1) | |||
| net = NetDynamic(1, 2) | |||
| x_split = net(Tensor(x)) | |||
| for i, out in enumerate(x_split): | |||
| assert (out.asnumpy() == y[i]).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_split_dynamic_axis2(): | |||
| x = np.array([[[1, -1, 1], [2, -2, 2]], | |||
| [[3, -3, 3], [4, -4, 4]], | |||
| [[5, -5, 5], [6, -6, 6]]]).astype(np.int32) | |||
| y = np.split(x, 3, axis=2) | |||
| net = NetDynamic(2, 3) | |||
| x_split = net(Tensor(x)) | |||
| for i, out in enumerate(x_split): | |||
| assert (out.asnumpy() == y[i]).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_split_invalid_input(): | |||
| with pytest.raises(TypeError): | |||
| _ = Net(0.1, 3) | |||
| with pytest.raises(TypeError): | |||
| _ = Net(0, 3.0) | |||
| with pytest.raises(ValueError): | |||
| _ = Net(0, -3) | |||
| x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32) | |||
| split_net = Net(2, 2) | |||
| with pytest.raises(ValueError): | |||
| _ = split_net(Tensor(x)) | |||
| with pytest.raises(TypeError): | |||
| _ = split_net(x) | |||