| @@ -80,6 +80,8 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"concat", "concat_d"}, | |||
| {"slice", "slice_d"}, | |||
| {"reduce_sum", "reduce_sum_d"}, | |||
| {"inplace_add", "inplace_add_d"}, | |||
| {"inplace_sub", "inplace_sub_d"}, | |||
| {"one_hot", "one_hot_d"}, | |||
| {"sum", "reduce_sum_d"}, | |||
| {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"}, | |||
| @@ -171,6 +171,8 @@ const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); | |||
const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
// In-place row update primitives; each string matches the Python primitive class of the same name.
const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
| // NN | |||
| const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| @@ -180,6 +180,8 @@ extern const PrimitivePtr kPrimLessEqual; | |||
extern const PrimitivePtr kPrimCumSum;
extern const PrimitivePtr kPrimCumProd;
extern const PrimitivePtr kPrimSubscalar;
// Declarations only; the InplaceAdd / InplaceSub primitive objects are defined in a source file.
extern const PrimitivePtr kPrimInplaceAdd;
extern const PrimitivePtr kPrimInplaceSub;
| // NN | |||
| extern const PrimitivePtr kPrimFlatten; | |||
| @@ -133,6 +133,8 @@ constexpr auto kResizeNearestNeighborV2OpName = "ResizeNearestNeighborV2"; | |||
constexpr auto kResizeNearestNeighborV2GradOpName = "ResizeNearestNeighborV2Grad";
constexpr auto kApplyRMSPropOpname = "ApplyRMSProp";
constexpr auto kCumsumOpName = "Cumsum";
// Operator name strings for the in-place row add / subtract operators.
constexpr auto kInplaceAddOpName = "InplaceAdd";
constexpr auto kInplaceSubOpName = "InplaceSub";
// NOTE(review): this value starts with "k", unlike the neighbouring names; it looks like the
// constant's prefix leaked into the string -- confirm against the registered operator name.
constexpr auto kResizeBilinearV2OpName = "kResizeBilinearV2";
constexpr auto kReduceProdOpName = "ReduceProd";
constexpr auto kCumprodOpName = "Cumprod";
| @@ -15,6 +15,8 @@ | |||
| """tbe ops""" | |||
| from .abs import _abs_tbe | |||
| from .inplace_add import _inplace_add_tbe | |||
| from .inplace_sub import _inplace_sub_tbe | |||
| from .abs_grad import _abs_grad_tbe | |||
| from .acos import _acos_tbe | |||
| from .acos_grad import _acos_grad_tbe | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """InplaceAdd op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Registration record for the "InplaceAdd" TBE kernel (inplace_add_d).
# It declares one required listInt attribute ("indices"), two required
# inputs (x, v) and one required output (y).  Each dtype_format row gives
# one supported configuration; x, v and y always share the same dtype.
inplace_add_op_info = (
    TBERegOp("InplaceAdd")
    .fusion_type("ELEMWISE")
    .async_flag(False)
    .binfile_name("inplace_add_d.so")
    .compute_cost(10)
    .kernel_name("inplace_add_d")
    .partial_flag(True)
    .attr("indices", "required", "listInt", "all")
    .input(0, "x", False, "required", "all")
    .input(1, "v", False, "required", "all")
    .output(0, "y", False, "required", "all")
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default)
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
    .get_op_info()
)


@op_info_register(inplace_add_op_info)
def _inplace_add_tbe():
    """Placeholder passed to op_info_register together with the InplaceAdd op info; its body does nothing."""
    return
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """InplaceSub op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Registration record for the "InplaceSub" TBE kernel (inplace_sub_d).
# It declares one required listInt attribute ("indices"), two required
# inputs (x, v) and one required output (y).  Each dtype_format row gives
# one supported configuration; x, v and y always share the same dtype.
inplace_sub_op_info = (
    TBERegOp("InplaceSub")
    .fusion_type("ELEMWISE")
    .async_flag(False)
    .binfile_name("inplace_sub_d.so")
    .compute_cost(10)
    .kernel_name("inplace_sub_d")
    .partial_flag(True)
    .attr("indices", "required", "listInt", "all")
    .input(0, "x", False, "required", "all")
    .input(1, "v", False, "required", "all")
    .output(0, "y", False, "required", "all")
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default)
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
    .get_op_info()
)


@op_info_register(inplace_sub_op_info)
def _inplace_sub_tbe():
    """Placeholder passed to op_info_register together with the InplaceSub op info; its body does nothing."""
    return
| @@ -41,7 +41,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge | |||
| from .inner_ops import ScalarCast | |||
| from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, | |||
| BitwiseXor, Inv, Invert, ApproximateEqual, | |||
| BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub, | |||
| ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, | |||
| Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil, | |||
| Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, | |||
| @@ -178,6 +178,8 @@ __all__ = [ | |||
| 'DropoutGrad', | |||
| 'Dropout', | |||
| 'Neg', | |||
| 'InplaceAdd', | |||
| 'InplaceSub', | |||
| 'Slice', | |||
| 'DType', | |||
| 'NPUAllocFloatStatus', | |||
| @@ -772,6 +772,125 @@ class Neg(PrimitiveWithInfer): | |||
| return input_x | |||
class InplaceAdd(PrimitiveWithInfer):
    """
    Adds v into specified rows of x. Computes y = x; y[i, :] += v; return y.

    Args:
        indices (Union[int, tuple]): Indices into the left-most dimension of x, and determines which rows of x
            to add with v. It is an int or a tuple of ints, whose value is in [0, the first dimension size of x).

    Inputs:
        - **input_x** (Tensor) - The first input is a tensor whose data type is float16, float32 or int32.
        - **input_v** (Tensor) - The second input is a tensor which has the same dimension sizes as x except
          the first dimension, which must be the same as the size of indices. It has the same data type as x.

    Outputs:
        Tensor, has the same shape and dtype as input_x.

    Raises:
        ValueError: If x has no dimensions, or if any index is outside
            [0, the first dimension size of x).

    Examples:
        >>> indices = (0, 1)
        >>> input_x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
        >>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
        >>> inplaceAdd = P.InplaceAdd(indices)
        >>> inplaceAdd(input_x, input_v)
        [[1.5 3.]
         [4. 5.5]
         [5. 6.]]
    """

    @prim_attr_register
    def __init__(self, indices):
        """init InplaceAdd"""
        self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
        # Validate here rather than during inference: infer_shape takes len() of
        # and iterates over indices, so an unsupported type must be rejected first.
        validator.check_value_type('indices', indices, [tuple, int], self.name)
        if isinstance(indices, tuple):
            for item in indices:
                validator.check_value_type('item of indices', item, [int], self.name)
        # NOTE(review): the TBE registration declares "indices" as listInt; confirm that a
        # bare int is converted to a list before it reaches the kernel.
        self.indices = indices

    def infer_shape(self, x_shape, v_shape):
        """Check that v and indices are consistent with x and return the shape of x."""
        validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
        if not x_shape:
            # Without a first dimension there are no rows to index.
            raise ValueError(f"For '{self.name}', x must have at least one dimension, but got a scalar.")
        # Treat a single int as a one-element tuple so both forms share the checks below.
        indices = (self.indices,) if isinstance(self.indices, int) else self.indices
        validator.check("size of indices", len(indices), "v's first dimension", v_shape[0],
                        Rel.EQ, self.name)
        for i in indices:
            if i < 0 or i >= x_shape[0]:
                raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
        if len(x_shape) > 1:
            # All dimensions after the first must agree between x and v.
            validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:],
                            Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, v_dtype):
        """Require x and v to share one of the supported dtypes and return the dtype of x."""
        args = {'x': x_dtype, 'v': v_dtype}
        valid_type = [mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_type, self.name)
        return x_dtype
class InplaceSub(PrimitiveWithInfer):
    """
    Subtracts v from specified rows of x. Computes y = x; y[i, :] -= v; return y.

    Args:
        indices (Union[int, tuple]): Indices into the left-most dimension of x, and determines which rows of x
            to subtract v from. It is an int or a tuple of ints, whose value is in
            [0, the first dimension size of x).

    Inputs:
        - **input_x** (Tensor) - The first input is a tensor whose data type is float16, float32 or int32.
        - **input_v** (Tensor) - The second input is a tensor which has the same dimension sizes as x except
          the first dimension, which must be the same as the size of indices. It has the same data type as x.

    Outputs:
        Tensor, has the same shape and dtype as input_x.

    Raises:
        ValueError: If x has no dimensions, or if any index is outside
            [0, the first dimension size of x).

    Examples:
        >>> indices = (0, 1)
        >>> input_x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
        >>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
        >>> inplaceSub = P.InplaceSub(indices)
        >>> inplaceSub(input_x, input_v)
        [[0.5 1.]
         [2. 2.5]
         [5. 6.]]
    """

    @prim_attr_register
    def __init__(self, indices):
        """init InplaceSub"""
        self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
        # Validate here rather than during inference: infer_shape takes len() of
        # and iterates over indices, so an unsupported type must be rejected first.
        validator.check_value_type('indices', indices, [tuple, int], self.name)
        if isinstance(indices, tuple):
            for item in indices:
                validator.check_value_type('item of indices', item, [int], self.name)
        # NOTE(review): the TBE registration declares "indices" as listInt; confirm that a
        # bare int is converted to a list before it reaches the kernel.
        self.indices = indices

    def infer_shape(self, x_shape, v_shape):
        """Check that v and indices are consistent with x and return the shape of x."""
        validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
        if not x_shape:
            # Without a first dimension there are no rows to index.
            raise ValueError(f"For '{self.name}', x must have at least one dimension, but got a scalar.")
        # Treat a single int as a one-element tuple so both forms share the checks below.
        indices = (self.indices,) if isinstance(self.indices, int) else self.indices
        validator.check("size of indices", len(indices), "v's first dimension", v_shape[0],
                        Rel.EQ, self.name)
        for i in indices:
            if i < 0 or i >= x_shape[0]:
                raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
        if len(x_shape) > 1:
            # All dimensions after the first must agree between x and v.
            validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:],
                            Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, v_dtype):
        """Require x and v to share one of the supported dtypes and return the dtype of x."""
        args = {'x': x_dtype, 'v': v_dtype}
        valid_type = [mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_type, self.name)
        return x_dtype
| class Sub(_MathBinaryOp): | |||
| """ | |||
| Subtracts the second input tensor from the first input tensor element-wise. | |||
| @@ -367,6 +367,26 @@ class ApplyRMSNet(nn.Cell): | |||
| return out | |||
class InplaceAddNet(nn.Cell):
    """Test cell that wraps P.InplaceAdd configured with indices (0, 1)."""

    def __init__(self):
        super().__init__()
        self.inplace_add = P.InplaceAdd(indices=(0, 1))

    def construct(self, x, v):
        # Forward both inputs straight to the primitive and hand back its result.
        return self.inplace_add(x, v)
class InplaceSubNet(nn.Cell):
    """Test cell that wraps P.InplaceSub configured with indices (0, 1)."""

    def __init__(self):
        super().__init__()
        self.inplace_sub = P.InplaceSub(indices=(0, 1))

    def construct(self, x, v):
        # Forward both inputs straight to the primitive and hand back its result.
        return self.inplace_sub(x, v)
| test_case_math_ops = [ | |||
| ('BitwiseAnd', { | |||
| 'block': P.BitwiseAnd(), | |||
| @@ -492,6 +512,16 @@ test_case_math_ops = [ | |||
| 'desc_inputs': [[2, 512, 56, 56]], | |||
| 'desc_bprop': [[2, 512, 56, 56]], | |||
| 'skip': ['backward']}), | |||
| ('InplaceAdd', { | |||
| 'block': InplaceAddNet(), | |||
| 'desc_inputs': [Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32)), | |||
| Tensor(np.array([[0.5, 1], [1, 1.5]]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ('InplaceSub', { | |||
| 'block': InplaceSubNet(), | |||
| 'desc_inputs': [Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32)), | |||
| Tensor(np.array([[0.5, 1], [1, 1.5]]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ('ACos', { | |||
| 'block': P.ACos(), | |||
| 'desc_inputs': [Tensor(np.array([2., 3.]).astype(np.float32))], | |||