From: @tom__chen (tag: v1.1.0)
@@ -132,10 +132,6 @@ class SplitGpuFwdKernel : public GpuKernel {
       MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_];
       return false;
     }
-    if (input_shape[axis_] % output_num_ != 0) {
-      MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_];
-      return false;
-    }
     if (output_num_ != output_num) {
       MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_;
       return false;
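The divisibility check removed here is not lost: this PR moves it out of the GPU kernel and into the frontend shape inference (the new `InferImplSplit` below), so an invalid split is rejected before any kernel is launched. A minimal Python sketch of the rule for a statically known dimension; `check_split_divisible` is a hypothetical helper for illustration, not MindSpore API:

```python
def check_split_divisible(dim_size: int, output_num: int) -> None:
    """Reject splits where the axis length is not a multiple of output_num."""
    if dim_size % output_num != 0:
        raise ValueError(
            f"dimension of size {dim_size} cannot be split into "
            f"{output_num} equal parts")
```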
@@ -261,6 +261,8 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con
                                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list);
 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple or list or dict.
@@ -784,5 +784,48 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt
   auto ret = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
   return ret;
 }
+
+AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  ShapeVector x_shape = input_x->shape()->shape();
+  ShapeVector x_shape_min = input_x->shape()->min_shape();
+  if (x_shape_min.empty()) {
+    x_shape_min = x_shape;
+  }
+  ShapeVector x_shape_max = input_x->shape()->max_shape();
+  if (x_shape_max.empty()) {
+    x_shape_max = x_shape;
+  }
+  int64_t rank = SizeToLong(x_shape.size());
+
+  ValuePtr axis = primitive->GetAttr("axis");
+  int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank);
+  axis_value = GetPositiveAxis(axis_value, LongToSize(rank));
+  int64_t output_num_value = primitive->GetAttr("output_num")->cast<Int64ImmPtr>()->value();
+  if ((x_shape[axis_value] != Shape::SHP_ANY) && (x_shape[axis_value] % output_num_value != 0)) {
+    MS_LOG(EXCEPTION) << "x_shape[" << axis_value << "] = " << x_shape[axis_value]
+                      << " must be divisible by output_num = " << output_num_value;
+  }
+
+  ShapeVector output_shape = x_shape;
+  if (output_shape[axis_value] != Shape::SHP_ANY) {
+    output_shape[axis_value] = static_cast<int>(x_shape[axis_value] / output_num_value);
+  }
+  ShapeVector output_shape_min = x_shape_min;
+  output_shape_min[axis_value] = static_cast<int>(x_shape_min[axis_value] / output_num_value);
+  ShapeVector output_shape_max = x_shape_max;
+  output_shape_max[axis_value] = static_cast<int>(x_shape_max[axis_value] / output_num_value);
+
+  AbstractBasePtrList output_list;
+  for (int64_t i = 0; i < output_num_value; ++i) {
+    auto output = input_x->Broaden();
+    output->set_shape(std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max));
+    output_list.push_back(output);
+  }
+  return std::make_shared<AbstractTuple>(output_list);
+}
 }  // namespace abstract
 }  // namespace mindspore
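For readers who do not want to trace the abstract-value machinery, here is a rough Python sketch of the shape computation `InferImplSplit` performs (min/max shape handling omitted). `infer_split_shape` is a hypothetical helper for illustration only, with `-1` standing in for `Shape::SHP_ANY`, an unknown dynamic dimension:

```python
def infer_split_shape(x_shape, axis, output_num):
    """Return the per-output shapes Split would infer for x_shape."""
    rank = len(x_shape)
    axis = axis % rank  # normalize a negative axis, as GetPositiveAxis does
    dim = x_shape[axis]
    if dim != -1 and dim % output_num != 0:
        raise ValueError(
            f"x_shape[{axis}] = {dim} must be divisible by output_num = {output_num}")
    out_shape = list(x_shape)
    if dim != -1:  # a dynamic (-1) dimension is left as-is
        out_shape[axis] = dim // output_num
    # Split yields a tuple of output_num tensors, all with the same shape.
    return tuple(tuple(out_shape) for _ in range(output_num))

# infer_split_shape([3, 2, 3], axis=0, output_num=3) -> ((1, 2, 3),) * 3
```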
@@ -70,6 +70,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
     {prim::kPrimTranspose, {InferImplTranspose, true}},
     {prim::kPrimReshape, {InferImplReshape, true}},
+    {prim::kPrimSplit, {InferImplSplit, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},
@@ -118,6 +118,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared<Primitive>("Dynam
 inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad");
 inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd");
 inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate");
+inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");
 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
@@ -876,18 +876,17 @@ class UniqueWithPad(PrimitiveWithInfer):
         return out


-class Split(PrimitiveWithInfer):
+class Split(PrimitiveWithCheck):
     """
     Splits the input tensor into output_num of tensors along the given axis and output numbers.

     Args:
         axis (int): Index of the split position. Default: 0.
-        output_num (int): The number of output tensors. Default: 1.
+        output_num (int): The number of output tensors. Must be a positive int. Default: 1.

     Raises:
         ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)),
-            or if the `output_num` is less than or equal to 0, or if the
-            dimension which to split cannot be evenly divided by `output_num`.
+            or if the `output_num` is less than or equal to 0.

     Inputs:
         - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -916,32 +915,15 @@ class Split(PrimitiveWithInfer):
         """Initialize Split"""
         validator.check_value_type("axis", axis, [int], self.name)
         validator.check_value_type("output_num", output_num, [int], self.name)
+        validator.check_positive_int(output_num, "output_num", self.name)
         self.axis = axis
         self.output_num = output_num

-    def __infer__(self, x):
+    def __check__(self, x):
         validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
         x_shape = list(x['shape'])
         dim = len(x_shape)
         validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
-        validator.check_positive_int(self.output_num, "output_num", self.name)
-        output_valid_check = x_shape[self.axis] % self.output_num
-        if output_valid_check != 0:
-            raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
-                             f" output_num {self.output_num}")
-        x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
-        out_shapes = []
-        out_dtypes = []
-        for _ in range(self.output_num):
-            out_shapes.append(tuple(x_shape))
-            out_dtypes.append(x['dtype'])
-        out_shapes = tuple(out_shapes)
-        out_dtypes = tuple(out_dtypes)
-        out = {'shape': out_shapes,
-               'dtype': out_dtypes,
-               'value': None}
-        return out


 class Rank(PrimitiveWithInfer):
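With `PrimitiveWithCheck`, the Python side only validates; the output shapes now come from the C++ `InferImplSplit` registered above rather than a Python `__infer__`. For static shapes the observable behavior should be unchanged; a small usage sketch, untested here:

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

split = P.Split(axis=1, output_num=2)
x = Tensor(np.arange(12).reshape(2, 6).astype(np.float32))
out0, out1 = split(x)
# Each output has shape (2, 3): the axis-1 length 6 divided by output_num 2.
```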
@@ -18,6 +18,7 @@ import pytest
 import mindspore.context as context
 from mindspore import Tensor
 import mindspore.nn as nn
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore.ops import operations as P
@@ -30,6 +31,18 @@ class Net(nn.Cell):
         return self.split(x)


+class NetDynamic(nn.Cell):
+    def __init__(self, axis=0, out_nums=1):
+        super(NetDynamic, self).__init__()
+        self.conv = inner.GpuConvertToDynamicShape()
+        self.split = P.Split(axis, out_nums)
+
+    def construct(self, x):
+        x_conv = self.conv(x)
+        x_split = self.split(x_conv)
+        return x_split
+
+
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
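`NetDynamic` routes its input through `inner.GpuConvertToDynamicShape` before `Split`, so the split sees an unknown shape and exercises the new `InferImplSplit` path. The dynamic tests below compare against NumPy's `np.split`, which has the same divisibility contract; a quick reference:

```python
import numpy as np

x = np.zeros((3, 2, 3), dtype=np.float32)
parts = np.split(x, 3, axis=0)  # three arrays of shape (1, 2, 3)
assert [p.shape for p in parts] == [(1, 2, 3)] * 3
# np.split raises ValueError if the axis length is not divisible by the
# section count, matching the check InferImplSplit performs for Split.
```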
@@ -47,6 +60,9 @@ def test_split():
     assert (out.asnumpy() == x[i]).all()


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
 def test_split_4d():
     x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
     y = np.split(x_np, 3, axis=1)
@@ -56,3 +72,69 @@ def test_split_4d():
     for i, out in enumerate(outputs):
         assert (out.asnumpy() == y[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_dynamic():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
+    net = NetDynamic(0, 3)
+    x_split = net(Tensor(x))
+    for i, out in enumerate(x_split):
+        assert (out.asnumpy() == x[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_dynamic_axis1():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
+    y = np.split(x, 2, axis=1)
+    net = NetDynamic(1, 2)
+    x_split = net(Tensor(x))
+    for i, out in enumerate(x_split):
+        assert (out.asnumpy() == y[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_dynamic_axis2():
+    x = np.array([[[1, -1, 1], [2, -2, 2]],
+                  [[3, -3, 3], [4, -4, 4]],
+                  [[5, -5, 5], [6, -6, 6]]]).astype(np.int32)
+    y = np.split(x, 3, axis=2)
+    net = NetDynamic(2, 3)
+    x_split = net(Tensor(x))
+    for i, out in enumerate(x_split):
+        assert (out.asnumpy() == y[i]).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_split_invalid_input():
+    with pytest.raises(TypeError):
+        # axis must be an int
+        _ = Net(0.1, 3)
+    with pytest.raises(TypeError):
+        # output_num must be an int
+        _ = Net(0, 3.0)
+    with pytest.raises(ValueError):
+        # output_num must be positive
+        _ = Net(0, -3)
+    x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
+    split_net = Net(2, 2)
+    with pytest.raises(ValueError):
+        # axis 2 is out of range for a 2-D input
+        _ = split_net(Tensor(x))
+    with pytest.raises(TypeError):
+        # the input must be a Tensor, not a raw numpy array
+        _ = split_net(x)