Merge pull request !3846 from liuxiao93/Add-ops-SeluSquaredDifference
@@ -252,6 +252,21 @@ def get_bprop_div_no_nan(self):
    return bprop

@bprop_getters.register(P.Xdivy)
def get_bprop_xdivy(self):
    """Grad definition for `Xdivy` operation."""
    div_op = P.Xdivy()

    def bprop(x, y, out, dout):
        x_dtype = F.dtype(x)
        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
        bc_x = div_op(not_zero_x, y) * dout
        bc_y = div_op(-x, F.square(y)) * dout
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
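The bprop above encodes the derivatives of xdivy: for out = x / y (with out = 0 wherever x == 0), d(out)/dx = 1/y at nonzero x and 0 elsewhere, and d(out)/dy = -x / y**2. A minimal NumPy sketch of the same formulas with a finite-difference check (illustrative only; xdivy_ref is a stand-in helper, not a MindSpore API):

import numpy as np

def xdivy_ref(x, y):
    # x / y where x != 0, else 0 (`where` also avoids a 0/0 warning)
    return np.divide(x, y, out=np.zeros_like(x), where=(x != 0.0))

x = np.array([0.0, 2.0, -3.0], np.float32)
y = np.array([5.0, 4.0, 2.0], np.float32)
dout = np.ones_like(x)

dx = xdivy_ref((x != 0.0).astype(x.dtype), y) * dout   # 1/y where x != 0, else 0
dy = xdivy_ref(-x, np.square(y)) * dout                # -x / y**2

eps = 1e-3
num_dx = (xdivy_ref(x + eps, y) - xdivy_ref(x - eps, y)) / (2 * eps)
print(np.allclose(dx[x != 0], num_dx[x != 0]))         # True away from x == 0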
@bprop_getters.register(P.Floor)
def get_bprop_floor(self):
    """Grad definition for `floor` operation."""

@@ -353,6 +368,36 @@ def get_bprop_square(self):
    return bprop

@bprop_getters.register(P.SquaredDifference)
def get_bprop_squared_difference(self):
    """Grad definition for `SquaredDifference` operation."""
    neg = P.Neg()

    def bprop(x, y, out, dout):
        x_grad = 2 * dout * (x - y)
        bc_x = x_grad
        bc_y = neg(x_grad)
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
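Here x_grad is the analytic derivative d/dx (x - y)**2 = 2 * (x - y) scaled by dout, and binop_grad_common reduces each broadcast gradient back to its input's shape. A NumPy sketch of that reduction step, with a hypothetical reduce_to helper standing in for binop_grad_common (shapes mirror the test cases at the end of this diff):

import numpy as np

def reduce_to(grad, shape):
    # Sum `grad` down to `shape`, undoing NumPy-style broadcasting.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for axis, dim in enumerate(shape):
        if dim == 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

x = np.random.randn(4, 5).astype(np.float32)
y = np.random.randn(2, 3, 4, 5).astype(np.float32)
dout = np.ones((2, 3, 4, 5), np.float32)

x_grad = 2 * dout * (x - y)        # broadcast to dout's shape
dx = reduce_to(x_grad, x.shape)    # back to (4, 5)
dy = reduce_to(-x_grad, y.shape)   # (2, 3, 4, 5), the negation
print(dx.shape, dy.shape)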
@bprop_getters.register(P.Xlogy)
def get_bprop_xlogy(self):
    """Grad definition for `Xlogy` operation."""
    log_op = P.Xlogy()
    div_op = P.Xdivy()

    def bprop(x, y, out, dout):
        x_dtype = F.dtype(x)
        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
        bc_x = log_op(not_zero_x, y) * dout
        bc_y = div_op(x, y) * dout
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop
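Analogously, d/dx xlogy(x, y) = log(y) wherever x != 0 (and 0 at x == 0, which routing the not_zero_x mask through log_op produces), while d/dy xlogy(x, y) = x / y, computed with Xdivy so that x == 0 again yields zero. A NumPy model of both pieces (illustrative, not MindSpore code):

import numpy as np

x = np.array([0.0, 2.0, 3.0], np.float32)
y = np.array([5.0, 4.0, 2.0], np.float32)
dout = np.ones_like(x)

dx = np.where(x != 0.0, np.log(y), 0.0) * dout   # log(y) where x != 0, else 0
dy = np.where(x != 0.0, x / y, 0.0) * dout       # Xdivy(x, y) semantics
print(dx, dy)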
@bprop_getters.register(P.Sqrt)
def get_bprop_sqrt(self):
    """Grad definition for `Sqrt` operation."""

@@ -108,6 +108,8 @@ from .elu import _elu_tbe
from .elu_grad import _elu_grad_tbe
from .div import _div_tbe
from .log import _log_tbe
from .xdivy import _xdivy_tbe
from .xlogy import _xlogy_tbe
from .floor_div import _floor_div_tbe
from .zeros_like import _zeros_like_tbe
from .neg import _neg_tbe

@@ -133,6 +135,7 @@ from .softplus import _softplus_tbe
from .softplus_grad import _softplus_grad_tbe
from .softmax_grad_ext import _softmax_grad_ext_tbe
from .square import _square_tbe
from .squared_difference import _squared_difference_tbe
from .sqrt import _sqrt_tbe
from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d
from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SquaredDifference op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

squared_difference_op_info = TBERegOp("SquaredDifference") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("squared_difference.so") \
    .compute_cost(10) \
    .kernel_name("squared_difference") \
    .partial_flag(True) \
    .op_pattern("broadcast") \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()

@op_info_register(squared_difference_op_info)
def _squared_difference_tbe():
    """SquaredDifference TBE register"""
    return
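The op_pattern("broadcast") entry declares that the kernel accepts inputs of different but broadcast-compatible shapes, matching the docstrings and test cases later in this diff. A usage sketch at the Python layer (illustrative; assumes an Ascend backend, since this registration targets TBE):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

squared_difference = P.SquaredDifference()
x = Tensor(np.ones((4, 5)), mindspore.float32)
y = Tensor(np.ones((2, 3, 4, 5)), mindspore.float32)
out = squared_difference(x, y)   # (4, 5) broadcasts against (2, 3, 4, 5)
print(out.shape)                 # (2, 3, 4, 5)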
@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Xdivy op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

xdivy_op_info = TBERegOp("Xdivy") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("xdivy.so") \
    .compute_cost(10) \
    .kernel_name("xdivy") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()

@op_info_register(xdivy_op_info)
def _xdivy_tbe():
    """Xdivy TBE register"""
    return

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Xlogy op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

xlogy_op_info = TBERegOp("Xlogy") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("xlogy.so") \
    .compute_cost(10) \
    .kernel_name("xlogy") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()

@op_info_register(xlogy_op_info)
def _xlogy_tbe():
    """Xlogy TBE register"""
    return
@@ -51,7 +51,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
                        NPUAllocFloatStatus, NPUClearFloatStatus,
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum, HistogramFixedWidth,
                        Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
                        Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)

@@ -107,6 +107,9 @@ __all__ = [
    'Rsqrt',
    'Sqrt',
    'Square',
    'SquaredDifference',
    'Xdivy',
    'Xlogy',
    'Conv2D',
    'Flatten',
    'MaxPoolWithArgmax',

@@ -1121,6 +1121,40 @@ class Mul(_MathBinaryOp):
        return None
class SquaredDifference(_MathBinaryOp):
    """
    Subtracts the second input tensor from the first input tensor element-wise and returns the square of the difference.

    The inputs must be two tensors or one tensor and one scalar.
    When the inputs are two tensors,
    neither dtype can be bool, and their shapes can be broadcast.
    When the inputs are one tensor and one scalar,
    the scalar can only be a constant.

    Inputs:
        - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
          a bool or a tensor whose data type is float16, float32, int32 or bool.
        - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
          a bool when the first input is a tensor or a tensor whose data type is
          float16, float32, int32 or bool.

    Outputs:
        Tensor, the shape is the same as the shape after broadcasting,
        and the data type is the one with higher precision or more digits among the two inputs.

    Examples:
        >>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
        >>> input_y = Tensor(np.array([2.0, 4.0, 6.0]), mindspore.float32)
        >>> squared_difference = P.SquaredDifference()
        >>> squared_difference(input_x, input_y)
        [1.0, 4.0, 9.0]
    """

    def infer_dtype(self, x_dtype, y_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32]
        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, valid_type, self.name)
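For checking the docstring example, the semantics reduce to broadcasting followed by an element-wise squared difference; a one-line NumPy reference (illustrative; squared_difference_ref is an assumed helper name):

import numpy as np

def squared_difference_ref(x, y):
    # Broadcasting subtraction, then element-wise square.
    return np.square(x - y)

print(squared_difference_ref(np.array([1.0, 2.0, 3.0], np.float32),
                             np.array([2.0, 4.0, 6.0], np.float32)))  # [1. 4. 9.]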
class Square(PrimitiveWithInfer):
    """
    Returns square of a tensor element-wise.

@@ -1962,6 +1996,72 @@ class Ceil(PrimitiveWithInfer):
        return x_dtype
class Xdivy(_MathBinaryOp):
    """
    Divides the first input tensor by the second input tensor element-wise. Returns zero when `x` is zero.

    The inputs must be two tensors or one tensor and one scalar.
    When the inputs are two tensors,
    neither dtype can be bool, and their shapes can be broadcast.
    When the inputs are one tensor and one scalar,
    the scalar can only be a constant.

    Inputs:
        - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
          a bool or a tensor whose data type is float16, float32 or bool.
        - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
          a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool.

    Outputs:
        Tensor, the shape is the same as the shape after broadcasting,
        and the data type is the one with higher precision or more digits among the two inputs.

    Examples:
        >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.float32)
        >>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32)
        >>> xdivy = P.Xdivy()
        >>> xdivy(input_x, input_y)
        [1.0, 2.0, -0.5]
    """

    def infer_dtype(self, x_dtype, y_dtype):
        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, [mstype.float16, mstype.float32], self.name)
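The "returns zero when `x` is zero" clause is what separates Xdivy from plain division: 0/0 yields 0 rather than NaN. A NumPy model of that behavior (illustrative; not the kernel's implementation):

import numpy as np

def xdivy_ref(x, y):
    # Divide only where x != 0; positions with x == 0 stay 0, so 0/0 -> 0.
    return np.divide(x, y, out=np.zeros_like(x), where=(x != 0.0))

x = np.array([0.0, 2.0], np.float32)
y = np.array([0.0, 2.0], np.float32)
print(xdivy_ref(x, y))   # [0. 1.]; plain x / y would give [nan  1.]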
class Xlogy(_MathBinaryOp):
    """
    Computes the first input tensor multiplied by the logarithm of the second input tensor element-wise.
    Returns zero when `x` is zero.

    The inputs must be two tensors or one tensor and one scalar.
    When the inputs are two tensors,
    neither dtype can be bool, and their shapes can be broadcast.
    When the inputs are one tensor and one scalar,
    the scalar can only be a constant.

    Inputs:
        - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
          a bool or a tensor whose data type is float16, float32 or bool.
        - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
          a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool.
          The value must be positive.

    Outputs:
        Tensor, the shape is the same as the shape after broadcasting,
        and the data type is the one with higher precision or more digits among the two inputs.

    Examples:
        >>> input_x = Tensor(np.array([-5, 0, 4]), mindspore.float32)
        >>> input_y = Tensor(np.array([2, 2, 2]), mindspore.float32)
        >>> xlogy = P.Xlogy()
        >>> xlogy(input_x, input_y)
        [-3.465736, 0.0, 2.7725887]
    """

    def infer_dtype(self, x_dtype, y_dtype):
        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, [mstype.float16, mstype.float32], self.name)
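Likewise, Xlogy forces the output to zero wherever x == 0, even where log(y) alone would not be finite; scipy.special.xlogy follows the same convention. A NumPy model (illustrative; xlogy_ref is an assumed helper name):

import numpy as np

def xlogy_ref(x, y):
    # x * log(y) where x != 0; zero elsewhere, so log(y) there is never evaluated.
    out = np.zeros_like(x)
    mask = x != 0.0
    out[mask] = x[mask] * np.log(y[mask])
    return out

print(xlogy_ref(np.array([-5.0, 0.0, 4.0], np.float32),
                np.array([2.0, 2.0, 2.0], np.float32)))
# [-3.465736  0.        2.7725887]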
class Acosh(PrimitiveWithInfer):
    """
    Compute inverse hyperbolic cosine of x element-wise.

@@ -3205,11 +3205,11 @@ class FusedSparseFtrl(PrimitiveWithInfer):
        use_locking (bool): Use locks for update operation if True . Default: False.

    Inputs:
        - **var** (Parameter): The variable to be updated. The data type must be float32.
        - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
        - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
        - **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
        - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape
        - **var** (Parameter) - The variable to be updated. The data type must be float32.
        - **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`.
        - **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`.
        - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The shape
          of `indices` must be the same as `grad` in first dimension. The type must be int32.

    Outputs:

@@ -3300,9 +3300,9 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer):
    Inputs:
        - **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
        - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
        - **lr** (Tensor): The learning rate value. The data type must be float32.
        - **l1** (Tensor): l1 regularization strength. The data type must be float32.
        - **l2** (Tensor): l2 regularization strength. The data type must be float32.
        - **lr** (Tensor) - The learning rate value. The data type must be float32.
        - **l1** (Tensor) - l1 regularization strength. The data type must be float32.
        - **l2** (Tensor) - l2 regularization strength. The data type must be float32.
        - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type
          must be int32.

@@ -4670,16 +4670,16 @@ class ApplyFtrl(PrimitiveWithInfer):
        use_locking (bool): Use locks for update operation if True . Default: False.

    Inputs:
        - **var** (Tensor): The variable to be updated.
        - **accum** (Tensor): The accum to be updated, must be same type and shape as `var`.
        - **linear** (Tensor): The linear to be updated, must be same type and shape as `var`.
        - **grad** (Tensor): Gradient.
        - **lr** (Union[Number, Tensor]): The learning rate value, must be positive. Default: 0.001.
        - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero.
        - **var** (Tensor) - The variable to be updated.
        - **accum** (Tensor) - The accum to be updated, must be same type and shape as `var`.
        - **linear** (Tensor) - The linear to be updated, must be same type and shape as `var`.
        - **grad** (Tensor) - Gradient.
        - **lr** (Union[Number, Tensor]) - The learning rate value, must be positive. Default: 0.001.
        - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be greater than or equal to zero.
          Default: 0.0.
        - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero.
        - **l2** (Union[Number, Tensor]) - l2 regularization strength, must be greater than or equal to zero.
          Default: 0.0.
        - **lr_power** (Union[Number, Tensor]): Learning rate power controls how the learning rate decreases
        - **lr_power** (Union[Number, Tensor]) - Learning rate power controls how the learning rate decreases
          during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero.
          Default: -0.5.

@@ -4760,17 +4760,17 @@ class SparseApplyFtrl(PrimitiveWithInfer):
        use_locking (bool): Use locks for update operation if True . Default: False.

    Inputs:
        - **var** (Parameter): The variable to be updated. The data type must be float32.
        - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
        - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
        - **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
        - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`.
        - **var** (Parameter) - The variable to be updated. The data type must be float32.
        - **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`.
        - **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`.
        - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
          The shape of `indices` must be the same as `grad` in first dimension. The type must be int32.

    Outputs:
        - **var** (Tensor): Tensor, has the same shape and type as `var`.
        - **accum** (Tensor): Tensor, has the same shape and type as `accum`.
        - **linear** (Tensor): Tensor, has the same shape and type as `linear`.
        - **var** (Tensor) - Tensor, has the same shape and type as `var`.
        - **accum** (Tensor) - Tensor, has the same shape and type as `accum`.
        - **linear** (Tensor) - Tensor, has the same shape and type as `linear`.

    Examples:
        >>> import mindspore

@@ -4858,9 +4858,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
    Outputs:
        Tuple of 3 Tensor, the updated parameters.

        - **var** (Tensor): Tensor, has the same shape and type as `var`.
        - **accum** (Tensor): Tensor, has the same shape and type as `accum`.
        - **linear** (Tensor): Tensor, has the same shape and type as `linear`.
        - **var** (Tensor) - Tensor, has the same shape and type as `var`.
        - **accum** (Tensor) - Tensor, has the same shape and type as `accum`.
        - **linear** (Tensor) - Tensor, has the same shape and type as `linear`.

    Examples:
        >>> import mindspore
@@ -1013,6 +1013,18 @@ test_case_math_ops = [
        'desc_const': [(0, 3, 1, 2)],
        'desc_inputs': [],
        'skip': ['backward']}),
    ('Xdivy', {
        'block': P.Xdivy(),
        'desc_inputs': [[4, 5], [2, 3, 4, 5]],
        'desc_bprop': [[2, 3, 4, 5]]}),
    ('Xlogy', {
        'block': P.Xlogy(),
        'desc_inputs': [[4, 5], [2, 3, 4, 5]],
        'desc_bprop': [[2, 3, 4, 5]]}),
    ('SquaredDifference', {
        'block': P.SquaredDifference(),
        'desc_inputs': [[4, 5], [2, 3, 4, 5]],
        'desc_bprop': [[2, 3, 4, 5]]}),
    ('Square', {
        'block': P.Square(),
        'desc_inputs': [[4]],