| @@ -306,6 +306,34 @@ def get_bprop_floormod(self): | |||
| return bprop | |||
| @bprop_getters.register(P.TruncateDiv) | |||
| def get_bprop_truncate_div(self): | |||
| """Grad definition for `TruncateDiv` operation.""" | |||
| div_op = P.TruncateDiv() | |||
| neg = P.Neg() | |||
| mul_op = P.Mul() | |||
| def bprop(x, y, out, dout): | |||
| bc_x = div_op(dout, y) | |||
| bc_y = neg(mul_op(bc_x, out)) | |||
| return binop_grad_common(x, y, bc_x, bc_y) | |||
| return bprop | |||
| @bprop_getters.register(P.TruncateMod) | |||
| def get_bprop_truncate_mod(self): | |||
| """Grad definition for `TruncateMod` operation.""" | |||
| div_op = P.TruncateDiv() | |||
| def bprop(x, y, out, dout): | |||
| bc_x = dout | |||
| bc_y = -dout * div_op(x, y) | |||
| return binop_grad_common(x, y, bc_x, bc_y) | |||
| return bprop | |||
| @bprop_getters.register(P.Mod) | |||
| def get_bprop_mod(self): | |||
| """Grad definition for `Mod` operation.""" | |||
| @@ -1027,6 +1055,22 @@ def get_bprop_atan(self): | |||
| return bprop | |||
| @bprop_getters.register(P.Tan) | |||
| def get_bprop_tan(self): | |||
| """Grad definition for `Tan` operation.""" | |||
| reciprocal = P.Reciprocal() | |||
| square = P.Square() | |||
| cos = P.Cos() | |||
| def bprop(x, out, dout): | |||
| cosx = cos(x) | |||
| secx2 = square(reciprocal(cosx)) | |||
| dx = secx2 * dout | |||
| return (dx,) | |||
| return bprop | |||
| @bprop_getters.register(P.BesselI1e) | |||
| def get_bprop_bessel_i1e(self): | |||
| """Generate bprop for BesselI1e""" | |||
| @@ -132,6 +132,8 @@ from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d | |||
| from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad | |||
| from .apply_proximal_adagrad import _apply_proximal_adagrad | |||
| from .transpose_d import _transpose_d_tbe | |||
| from .truncate_div import _truncate_div_tbe | |||
| from .truncate_mod import _truncate_mod_tbe | |||
| from .unsorted_segment_sum import _unsorted_segment_sum_tbe | |||
| from .unsorted_segment_prod import _unsorted_segment_prod_tbe | |||
| from .logsoftmax_grad import _logsoftmax_grad_tbe | |||
| @@ -222,6 +224,7 @@ from .binary_cross_entropy import _binary_cross_entropy_tbe | |||
| from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe | |||
| from .sin import _sin_tbe | |||
| from .cos import _cos_tbe | |||
| from .tan import _tan_tbe | |||
| from .cum_sum import _cum_sum_tbe | |||
| from .apply_rms_prop import _apply_rms_prop_tbe | |||
| from .cumprod import _cumprop_tbe | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Tan op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| tan_op_info = TBERegOp("Tan") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("tan.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("tan") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("formatAgnostic") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.I32_None) \ | |||
| .get_op_info() | |||
| @op_info_register(tan_op_info) | |||
| def _tan_tbe(): | |||
| """Tan TBE register""" | |||
| return | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """TruncateDiv op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| truncate_div_op_info = TBERegOp("TruncateDiv") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("truncate_div.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("truncate_div") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("broadcast") \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ | |||
| .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \ | |||
| .get_op_info() | |||
| @op_info_register(truncate_div_op_info) | |||
| def _truncate_div_tbe(): | |||
| """TruncateDiv TBE register""" | |||
| return | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """TruncateMod op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| truncate_mod_op_info = TBERegOp("TruncateMod") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("truncate_mod.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("truncate_mod") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("broadcast") \ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .input(1, "x2", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ | |||
| .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \ | |||
| .get_op_info() | |||
| @op_info_register(truncate_mod_op_info) | |||
| def _truncate_mod_tbe(): | |||
| """TruncateMod TBE register""" | |||
| return | |||
| @@ -52,8 +52,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||
| NPUAllocFloatStatus, NPUClearFloatStatus, | |||
| NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | |||
| Reciprocal, CumSum, HistogramFixedWidth, | |||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, | |||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps) | |||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | |||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | |||
| from .random_ops import (RandomChoiceWithMask, Normal) | |||
| from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, | |||
| @@ -267,6 +267,8 @@ __all__ = [ | |||
| 'SigmoidCrossEntropyWithLogits', | |||
| 'FloorDiv', | |||
| 'FloorMod', | |||
| 'TruncateDiv', | |||
| 'TruncateMod', | |||
| 'Ceil', | |||
| 'Acosh', | |||
| 'Asinh', | |||
| @@ -323,6 +325,7 @@ __all__ = [ | |||
| "BesselI1e", | |||
| "Atan", | |||
| "Atanh", | |||
| "Tan", | |||
| "BasicLSTMCell", | |||
| "BroadcastTo", | |||
| "DataFormatDimMap", | |||
| @@ -1744,6 +1744,65 @@ class FloorDiv(_MathBinaryOp): | |||
| """ | |||
| class TruncateDiv(_MathBinaryOp): | |||
| """ | |||
| Divide the first input tensor by the second input tensor element-wise for integer types, negative numbers will | |||
| round fractional quantities towards zero. | |||
| The inputs must be two tensors or one tensor and one scalar. | |||
| When the inputs are two tensors, | |||
| both dtypes cannot be bool, and the shapes of them could be broadcast. | |||
| When the inputs are one tensor and one scalar, | |||
| the scalar only could be a constant. | |||
| Inputs: | |||
| - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or | |||
| a bool or a tensor whose data type is number or bool. | |||
| - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or | |||
| a bool when the first input is a tensor or a tensor whose data type is number or bool. | |||
| Outputs: | |||
| Tensor, the shape is same as the shape after broadcasting, | |||
| and the data type is the one with high precision or high digits among the two inputs. | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32) | |||
| >>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32) | |||
| >>> truncate_div = P.TruncateDiv() | |||
| >>> truncate_div(input_x, input_y) | |||
| [0, 1, 0] | |||
| """ | |||
| class TruncateMod(_MathBinaryOp): | |||
| """ | |||
| Returns element-wise remainder of division. | |||
| The inputs must be two tensors or one tensor and one scalar. | |||
| When the inputs are two tensors, | |||
| both dtypes cannot be bool, and the shapes of them could be broadcast. | |||
| When the inputs are one tensor and one scalar, | |||
| the scalar only could be a constant. | |||
| Inputs: | |||
| - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or | |||
| a bool or a tensor whose data type is number or bool. | |||
| - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or | |||
| a bool when the first input is a tensor or a tensor whose data type is number or bool. | |||
| Outputs: | |||
| Tensor, the shape is same as the shape after broadcasting, | |||
| and the data type is the one with high precision or high digits among the two inputs. | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32) | |||
| >>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32) | |||
| >>> truncate_mod = P.TruncateMod() | |||
| >>> truncate_mod(input_x, input_y) | |||
| [2, 1, -1] | |||
| """ | |||
| class Mod(_MathBinaryOp): | |||
| """ | |||
| Computes the remainder of dividing the first input tensor by the second input tensor element-wise. | |||
| @@ -2870,6 +2929,34 @@ class Round(PrimitiveWithInfer): | |||
| return x_type | |||
| class Tan(PrimitiveWithInfer): | |||
| """ | |||
| Computes tan of `input_x` element-wise. | |||
| Inputs: | |||
| - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||
| Outputs: | |||
| Tensor, has the same shape as `input_x`. | |||
| Examples: | |||
| >>> tan = P.Tan() | |||
| >>> input_x = Tensor(np.array([-1.0, 0.0, 1.0]), mindspore.float32) | |||
| >>> output = tan(input_x) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init Tan""" | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||
| return x_type | |||
| class Atan(PrimitiveWithInfer): | |||
| """ | |||
| Computes the trignometric inverse tangent of x element-wise. | |||
| @@ -768,6 +768,10 @@ test_case_math_ops = [ | |||
| 'block': P.Asinh(), | |||
| 'desc_inputs': [[3, 4, 5]], | |||
| 'desc_bprop': [[3, 4, 5]]}), | |||
| ('Tan', { | |||
| 'block': P.Tan(), | |||
| 'desc_inputs': [[2, 3]], | |||
| 'desc_bprop': [[2, 3]]}), | |||
| ('Reciprocal', { | |||
| 'block': P.Reciprocal(), | |||
| 'desc_inputs': [[2, 3, 3, 5]], | |||
| @@ -850,6 +854,14 @@ test_case_math_ops = [ | |||
| 'block': P.FloorMod(), | |||
| 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | |||
| 'desc_bprop': [[2, 3, 4, 5]]}), | |||
| ('TruncateDiv', { | |||
| 'block': P.TruncateDiv(), | |||
| 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | |||
| 'desc_bprop': [[2, 3, 4, 5]]}), | |||
| ('TruncateMod', { | |||
| 'block': P.TruncateMod(), | |||
| 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | |||
| 'desc_bprop': [[2, 3, 4, 5]]}), | |||
| ('identity', { | |||
| 'block': ops.functional.identity, | |||
| 'desc_inputs': [[2, 2]], | |||