| @@ -306,6 +306,34 @@ def get_bprop_floormod(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.TruncateDiv) | |||||
| def get_bprop_truncate_div(self): | |||||
| """Grad definition for `TruncateDiv` operation.""" | |||||
| div_op = P.TruncateDiv() | |||||
| neg = P.Neg() | |||||
| mul_op = P.Mul() | |||||
| def bprop(x, y, out, dout): | |||||
| bc_x = div_op(dout, y) | |||||
| bc_y = neg(mul_op(bc_x, out)) | |||||
| return binop_grad_common(x, y, bc_x, bc_y) | |||||
| return bprop | |||||
| @bprop_getters.register(P.TruncateMod) | |||||
| def get_bprop_truncate_mod(self): | |||||
| """Grad definition for `TruncateMod` operation.""" | |||||
| div_op = P.TruncateDiv() | |||||
| def bprop(x, y, out, dout): | |||||
| bc_x = dout | |||||
| bc_y = -dout * div_op(x, y) | |||||
| return binop_grad_common(x, y, bc_x, bc_y) | |||||
| return bprop | |||||
| @bprop_getters.register(P.Mod) | @bprop_getters.register(P.Mod) | ||||
| def get_bprop_mod(self): | def get_bprop_mod(self): | ||||
| """Grad definition for `Mod` operation.""" | """Grad definition for `Mod` operation.""" | ||||
| @@ -1027,6 +1055,22 @@ def get_bprop_atan(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.Tan) | |||||
| def get_bprop_tan(self): | |||||
| """Grad definition for `Tan` operation.""" | |||||
| reciprocal = P.Reciprocal() | |||||
| square = P.Square() | |||||
| cos = P.Cos() | |||||
| def bprop(x, out, dout): | |||||
| cosx = cos(x) | |||||
| secx2 = square(reciprocal(cosx)) | |||||
| dx = secx2 * dout | |||||
| return (dx,) | |||||
| return bprop | |||||
| @bprop_getters.register(P.BesselI1e) | @bprop_getters.register(P.BesselI1e) | ||||
| def get_bprop_bessel_i1e(self): | def get_bprop_bessel_i1e(self): | ||||
| """Generate bprop for BesselI1e""" | """Generate bprop for BesselI1e""" | ||||
| @@ -132,6 +132,8 @@ from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d | |||||
| from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad | from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad | ||||
| from .apply_proximal_adagrad import _apply_proximal_adagrad | from .apply_proximal_adagrad import _apply_proximal_adagrad | ||||
| from .transpose_d import _transpose_d_tbe | from .transpose_d import _transpose_d_tbe | ||||
| from .truncate_div import _truncate_div_tbe | |||||
| from .truncate_mod import _truncate_mod_tbe | |||||
| from .unsorted_segment_sum import _unsorted_segment_sum_tbe | from .unsorted_segment_sum import _unsorted_segment_sum_tbe | ||||
| from .unsorted_segment_prod import _unsorted_segment_prod_tbe | from .unsorted_segment_prod import _unsorted_segment_prod_tbe | ||||
| from .logsoftmax_grad import _logsoftmax_grad_tbe | from .logsoftmax_grad import _logsoftmax_grad_tbe | ||||
| @@ -222,6 +224,7 @@ from .binary_cross_entropy import _binary_cross_entropy_tbe | |||||
| from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe | from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe | ||||
| from .sin import _sin_tbe | from .sin import _sin_tbe | ||||
| from .cos import _cos_tbe | from .cos import _cos_tbe | ||||
| from .tan import _tan_tbe | |||||
| from .cum_sum import _cum_sum_tbe | from .cum_sum import _cum_sum_tbe | ||||
| from .apply_rms_prop import _apply_rms_prop_tbe | from .apply_rms_prop import _apply_rms_prop_tbe | ||||
| from .cumprod import _cumprop_tbe | from .cumprod import _cumprop_tbe | ||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Tan op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| tan_op_info = TBERegOp("Tan") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("tan.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("tan") \ | |||||
| .partial_flag(True) \ | |||||
| .op_pattern("formatAgnostic") \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.I32_None) \ | |||||
| .get_op_info() | |||||
| @op_info_register(tan_op_info) | |||||
| def _tan_tbe(): | |||||
| """Tan TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """TruncateDiv op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| truncate_div_op_info = TBERegOp("TruncateDiv") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("truncate_div.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("truncate_div") \ | |||||
| .partial_flag(True) \ | |||||
| .op_pattern("broadcast") \ | |||||
| .input(0, "x1", False, "required", "all") \ | |||||
| .input(1, "x2", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ | |||||
| .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ | |||||
| .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \ | |||||
| .get_op_info() | |||||
| @op_info_register(truncate_div_op_info) | |||||
| def _truncate_div_tbe(): | |||||
| """TruncateDiv TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """TruncateMod op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| truncate_mod_op_info = TBERegOp("TruncateMod") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("truncate_mod.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("truncate_mod") \ | |||||
| .partial_flag(True) \ | |||||
| .op_pattern("broadcast") \ | |||||
| .input(0, "x1", False, "required", "all") \ | |||||
| .input(1, "x2", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ | |||||
| .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ | |||||
| .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \ | |||||
| .get_op_info() | |||||
| @op_info_register(truncate_mod_op_info) | |||||
| def _truncate_mod_tbe(): | |||||
| """TruncateMod TBE register""" | |||||
| return | |||||
| @@ -52,8 +52,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||||
| NPUAllocFloatStatus, NPUClearFloatStatus, | NPUAllocFloatStatus, NPUClearFloatStatus, | ||||
| NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | ||||
| Reciprocal, CumSum, HistogramFixedWidth, | Reciprocal, CumSum, HistogramFixedWidth, | ||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, | |||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps) | |||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | |||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | |||||
| from .random_ops import (RandomChoiceWithMask, Normal) | from .random_ops import (RandomChoiceWithMask, Normal) | ||||
| from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, | from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, | ||||
| @@ -267,6 +267,8 @@ __all__ = [ | |||||
| 'SigmoidCrossEntropyWithLogits', | 'SigmoidCrossEntropyWithLogits', | ||||
| 'FloorDiv', | 'FloorDiv', | ||||
| 'FloorMod', | 'FloorMod', | ||||
| 'TruncateDiv', | |||||
| 'TruncateMod', | |||||
| 'Ceil', | 'Ceil', | ||||
| 'Acosh', | 'Acosh', | ||||
| 'Asinh', | 'Asinh', | ||||
| @@ -323,6 +325,7 @@ __all__ = [ | |||||
| "BesselI1e", | "BesselI1e", | ||||
| "Atan", | "Atan", | ||||
| "Atanh", | "Atanh", | ||||
| "Tan", | |||||
| "BasicLSTMCell", | "BasicLSTMCell", | ||||
| "BroadcastTo", | "BroadcastTo", | ||||
| "DataFormatDimMap", | "DataFormatDimMap", | ||||
| @@ -1744,6 +1744,65 @@ class FloorDiv(_MathBinaryOp): | |||||
| """ | """ | ||||
| class TruncateDiv(_MathBinaryOp): | |||||
| """ | |||||
| Divide the first input tensor by the second input tensor element-wise for integer types, negative numbers will | |||||
| round fractional quantities towards zero. | |||||
| The inputs must be two tensors or one tensor and one scalar. | |||||
| When the inputs are two tensors, | |||||
| both dtypes cannot be bool, and the shapes of them could be broadcast. | |||||
| When the inputs are one tensor and one scalar, | |||||
| the scalar only could be a constant. | |||||
| Inputs: | |||||
| - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or | |||||
| a bool or a tensor whose data type is number or bool. | |||||
| - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or | |||||
| a bool when the first input is a tensor or a tensor whose data type is number or bool. | |||||
| Outputs: | |||||
| Tensor, the shape is same as the shape after broadcasting, | |||||
| and the data type is the one with high precision or high digits among the two inputs. | |||||
| Examples: | |||||
| >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32) | |||||
| >>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32) | |||||
| >>> truncate_div = P.TruncateDiv() | |||||
| >>> truncate_div(input_x, input_y) | |||||
| [0, 1, 0] | |||||
| """ | |||||
| class TruncateMod(_MathBinaryOp): | |||||
| """ | |||||
| Returns element-wise remainder of division. | |||||
| The inputs must be two tensors or one tensor and one scalar. | |||||
| When the inputs are two tensors, | |||||
| both dtypes cannot be bool, and the shapes of them could be broadcast. | |||||
| When the inputs are one tensor and one scalar, | |||||
| the scalar only could be a constant. | |||||
| Inputs: | |||||
| - **input_x** (Union[Tensor, Number, bool]) - The first input is a number or | |||||
| a bool or a tensor whose data type is number or bool. | |||||
| - **input_y** (Union[Tensor, Number, bool]) - The second input is a number or | |||||
| a bool when the first input is a tensor or a tensor whose data type is number or bool. | |||||
| Outputs: | |||||
| Tensor, the shape is same as the shape after broadcasting, | |||||
| and the data type is the one with high precision or high digits among the two inputs. | |||||
| Examples: | |||||
| >>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32) | |||||
| >>> input_y = Tensor(np.array([3, 3, 3]), mindspore.int32) | |||||
| >>> truncate_mod = P.TruncateMod() | |||||
| >>> truncate_mod(input_x, input_y) | |||||
| [2, 1, -1] | |||||
| """ | |||||
| class Mod(_MathBinaryOp): | class Mod(_MathBinaryOp): | ||||
| """ | """ | ||||
| Computes the remainder of dividing the first input tensor by the second input tensor element-wise. | Computes the remainder of dividing the first input tensor by the second input tensor element-wise. | ||||
| @@ -2870,6 +2929,34 @@ class Round(PrimitiveWithInfer): | |||||
| return x_type | return x_type | ||||
| class Tan(PrimitiveWithInfer): | |||||
| """ | |||||
| Computes tan of `input_x` element-wise. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||||
| Outputs: | |||||
| Tensor, has the same shape as `input_x`. | |||||
| Examples: | |||||
| >>> tan = P.Tan() | |||||
| >>> input_x = Tensor(np.array([-1.0, 0.0, 1.0]), mindspore.float32) | |||||
| >>> output = tan(input_x) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init Tan""" | |||||
| def infer_shape(self, x_shape): | |||||
| return x_shape | |||||
| def infer_dtype(self, x_type): | |||||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||||
| return x_type | |||||
| class Atan(PrimitiveWithInfer): | class Atan(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Computes the trignometric inverse tangent of x element-wise. | Computes the trignometric inverse tangent of x element-wise. | ||||
| @@ -768,6 +768,10 @@ test_case_math_ops = [ | |||||
| 'block': P.Asinh(), | 'block': P.Asinh(), | ||||
| 'desc_inputs': [[3, 4, 5]], | 'desc_inputs': [[3, 4, 5]], | ||||
| 'desc_bprop': [[3, 4, 5]]}), | 'desc_bprop': [[3, 4, 5]]}), | ||||
| ('Tan', { | |||||
| 'block': P.Tan(), | |||||
| 'desc_inputs': [[2, 3]], | |||||
| 'desc_bprop': [[2, 3]]}), | |||||
| ('Reciprocal', { | ('Reciprocal', { | ||||
| 'block': P.Reciprocal(), | 'block': P.Reciprocal(), | ||||
| 'desc_inputs': [[2, 3, 3, 5]], | 'desc_inputs': [[2, 3, 3, 5]], | ||||
| @@ -850,6 +854,14 @@ test_case_math_ops = [ | |||||
| 'block': P.FloorMod(), | 'block': P.FloorMod(), | ||||
| 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | ||||
| 'desc_bprop': [[2, 3, 4, 5]]}), | 'desc_bprop': [[2, 3, 4, 5]]}), | ||||
| ('TruncateDiv', { | |||||
| 'block': P.TruncateDiv(), | |||||
| 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | |||||
| 'desc_bprop': [[2, 3, 4, 5]]}), | |||||
| ('TruncateMod', { | |||||
| 'block': P.TruncateMod(), | |||||
| 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | |||||
| 'desc_bprop': [[2, 3, 4, 5]]}), | |||||
| ('identity', { | ('identity', { | ||||
| 'block': ops.functional.identity, | 'block': ops.functional.identity, | ||||
| 'desc_inputs': [[2, 2]], | 'desc_inputs': [[2, 2]], | ||||