Merge pull request !559 from liuxiao/fill-optags/v0.2.0-alpha
@@ -57,6 +57,7 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"strided_slice", "strided_slice_d"},
   {"strided_slice_grad", "strided_slice_grad_d"},
   {"transpose", "transpose_d"},
+  {"fill", "fill_d"},
   {"unsorted_segment_sum", "unsorted_segment_sum_d"},
   {"concat", "concat_d"},
   {"slice", "slice_d"},
@@ -53,6 +53,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(kExpandDimsOpName, {1});
   Register(kSplitOpName, {0});
   Register(kTopKOpName, {1});
+  Register(kErfOpName, {1});
   Register(kSparseApplyAdagradOpName, {2});
   Register(kResizeNearestNeighborGrad, {1});
 }
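`ConstInputToAttrInfoRegistry` records which inputs of an op should be converted from constant tensor inputs into node attributes before kernel selection; the new entry nominates input index 1 of `Erf`. A hedged sketch of the registry's contract, with hypothetical Python names:

```python
# Hypothetical sketch of the const-input-to-attr registry: input
# indices registered for an op name are folded from constant tensor
# inputs into node attributes during graph optimization.
const_input_to_attr = {
    "ExpandDims": {1},  # e.g. the axis input becomes an attribute
    "Erf": {1},
}

def should_fold_to_attr(op_name: str, input_index: int) -> bool:
    return input_index in const_input_to_attr.get(op_name, set())
```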
@@ -92,6 +92,7 @@ constexpr auto kClipByNormNoDivSumOpName = "ClipByNormNoDivSum";
 constexpr auto kGreaterOpName = "Greater";
 constexpr auto kSqrtOpName = "Sqrt";
 constexpr auto kRsqrtOpName = "Rsqrt";
+constexpr auto kErfOpName = "Erf";
 constexpr auto kRealDivOpName = "RealDiv";
 constexpr auto kLambUpdateWithLROpName = "LambUpdateWithLR";
 constexpr auto kLambNextMVWithDecayOpName = "LambNextMVWithDecay";
@@ -17,6 +17,7 @@
 from functools import reduce
+import numpy as np
 from .. import functional as F
 from .. import operations as P
 from ..operations import _grad_ops as G
@@ -333,6 +334,23 @@ def get_bprop_log(self):
     return bprop


+@bprop_getters.register(P.Erf)
+def get_bprop_erf(self):
+    """Grad definition for `Erf` operation."""
+    exp = P.Exp()
+    square = P.Square()
+    sqrt = P.Sqrt()
+    cast = P.Cast()
+    dtype = P.DType()
+
+    def bprop(x, out, dout):
+        half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
+        x_square = square(x)
+        dx = dout * half_root_pi * exp(-x_square)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(P.Pow)
 def get_bprop_pow(self):
     """Grad definition for `Pow` operation."""
@@ -139,6 +139,8 @@ from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
 from .fused_mul_add import _fused_mul_add_tbe
 from .fused_mul_add_n import _fused_mul_add_n_tbe
 from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
+from .fill_d import _fill_d_op_tbe
+from .erf import _erf_op_tbe
 from .depthwise_conv2d import _depthwise_conv2d_tbe
 from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe
 from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Erf op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+erf_op_info = TBERegOp("Erf") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("erf.so") \
+    .compute_cost(10) \
+    .kernel_name("erf") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(erf_op_info)
+def _erf_op_tbe():
+    """Erf TBE register"""
+    return
@@ -0,0 +1,55 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FillD op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+fill_d_op_info = TBERegOp("FillD") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("fill_d.so") \
+    .compute_cost(10) \
+    .kernel_name("fill_d") \
+    .partial_flag(True) \
+    .attr("dims", "required", "listInt", "all") \
+    .input(0, "value", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
+    .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \
+    .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
+    .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \
+    .dtype_format(DataType.I8_C1HWNCoC0, DataType.I8_C1HWNCoC0) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
+    .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \
+    .dtype_format(DataType.U8_C1HWNCoC0, DataType.U8_C1HWNCoC0) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .get_op_info()
+
+
+@op_info_register(fill_d_op_info)
+def _fill_d_op_tbe():
+    """FillD TBE register"""
+    return
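Unlike `Erf`, `FillD` carries the target shape as the required `dims` attribute and broadcasts its scalar `value` input, which is why the adapter map at the top rewrites `fill` to `fill_d`. A usage sketch of the front-end primitive that should lower to this kernel on Ascend; the `P.Fill` call signature is an assumption here, taken from MindSpore docs of this era:

```python
import mindspore
from mindspore.ops import operations as P

# Assumed signature: Fill()(dtype, shape, value); on Ascend the shape
# tuple should be folded into the `dims` attribute of the FillD kernel.
fill = P.Fill()
out = fill(mindspore.float32, (2, 3), 1.0)  # 2x3 float32 tensor of 1.0
```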
@@ -39,7 +39,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge
 from .inner_ops import ScalarCast
 from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
                        ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
-                       Cos, Div, Equal, EqualCount, Exp, Floor, FloorDiv, FloorMod, Acosh,
+                       Cos, Div, Equal, EqualCount, Exp, Erf, Floor, FloorDiv, FloorMod, Acosh,
                        Greater, GreaterEqual, Less, LessEqual, Log, LogicalAnd,
                        LogicalNot, LogicalOr, MatMul, Maximum,
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
@@ -139,6 +139,7 @@ __all__ = [
     'ReLU',
     'ReLU6',
     'Elu',
+    'Erf',
     'Sigmoid',
     'HSwish',
     'HSigmoid',
@@ -1007,6 +1007,36 @@ class Log(PrimitiveWithInfer):
         return x


+class Erf(PrimitiveWithInfer):
+    r"""
+    Computes the Gauss error function of `input_x` element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, with the same shape and dtype as `input_x`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([-1, 0, 1, 2, 3]), mindspore.float32)
+        >>> erf = P.Erf()
+        >>> erf(input_x)
+        [-0.8427168, 0., 0.8427168, 0.99530876, 0.99997765]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init Erf"""
+        self.init_prim_io_names(inputs=['x'], outputs=['y'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_type):
+        validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
+        return x_type
+
+
 class Minimum(_MathBinaryOp):
     """
     Computes the element-wise minimum of input tensors.
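With the primitive in place, the forward values shown in the docstring can be cross-checked against SciPy's reference `erf`. A minimal sketch, assuming a build where the TBE kernel registered above is available:

```python
import numpy as np
from scipy.special import erf as scipy_erf
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
out = P.Erf()(Tensor(x, mindspore.float32))
# A loose tolerance for the check; the fp32 kernel should be far more
# accurate than this in practice.
assert np.allclose(out.asnumpy(), scipy_erf(x), atol=1e-3)
```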
@@ -250,6 +250,10 @@ test_case_math_ops = [
         'block': P.Exp(),
         'desc_inputs': [[2, 3]],
         'desc_bprop': [[2, 3]]}),
+    ('Erf', {
+        'block': P.Erf(),
+        'desc_inputs': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))],
+        'desc_bprop': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))]}),
     ('Floor', {
         'block': P.Floor(),
         'desc_inputs': [[2, 512, 56, 56]],