From 0db45147a4e7a07e69147b0f0efa84fba79d445a Mon Sep 17 00:00:00 2001 From: buxue Date: Thu, 9 Apr 2020 21:50:31 +0800 Subject: [PATCH] dock relu6 for open source process and fix pow bprop --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 2 ++ mindspore/ops/_grad/grad_math_ops.py | 6 ++-- mindspore/ops/_op_impl/tbe/__init__.py | 2 ++ mindspore/ops/_op_impl/tbe/relu6.py | 40 +++++++++++++++++++++ mindspore/ops/_op_impl/tbe/relu6_grad.py | 43 +++++++++++++++++++++++ 5 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/relu6.py create mode 100644 mindspore/ops/_op_impl/tbe/relu6_grad.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 229a3eb34a..5336d1e67f 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -30,6 +30,8 @@ namespace mindspore { namespace kernel { namespace tbe { static std::map tbe_func_adapter_map = { + {"re_lu6", "relu6"}, + {"re_lu6_grad", "relu6_grad"}, {"re_lu", "relu"}, {"tensor_add", "add"}, {"reduce_mean", "reduce_mean_d"}, diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 81e078dc98..2d819718c8 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -340,9 +340,9 @@ def get_bprop_pow(self): ln = P.Log() def bprop(x, power, out, dout): - dx = power * pow_op(x, power - 1.0) * dout - dpower = pow_op(x, power) * ln(x) * dout - return dx, dpower + bc_dx = power * pow_op(x, power - 1.0) * dout + bc_dpower = out * ln(x) * dout + return binop_grad_common(x, power, bc_dx, bc_dpower) return bprop diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 0b79ae845b..9ec5446165 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -42,6 +42,8 @@ from .mul import _mul_tbe from .real_div import _real_div_tbe from .relu import _relu_tbe from .relu_grad import _relu_grad_tbe +from .relu6 import _relu6_tbe +from .relu6_grad import _relu6_grad_tbe from .softmax_cross_entropy_with_logits import _softmax_cross_entropy_with_logits_tbe from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logits_tbe from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe diff --git a/mindspore/ops/_op_impl/tbe/relu6.py b/mindspore/ops/_op_impl/tbe/relu6.py new file mode 100644 index 0000000000..bbedfdeb0f --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/relu6.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReLU6 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +relu6_op_info = TBERegOp("ReLU6") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("relu6.so") \ + .compute_cost(10) \ + .kernel_name("relu6") \ + .partial_flag(True) \ + .input(0, "features", False, "required", "all") \ + .output(0, "activations", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ + .get_op_info() + + +@op_info_register(relu6_op_info) +def _relu6_tbe(): + """Relu6 TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/relu6_grad.py b/mindspore/ops/_op_impl/tbe/relu6_grad.py new file mode 100644 index 0000000000..eaf3449fe7 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/relu6_grad.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReLU6Grad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +relu6_grad_op_info = TBERegOp("ReLU6Grad") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("relu6_grad.so") \ + .compute_cost(10) \ + .kernel_name("relu6_grad") \ + .partial_flag(True) \ + .input(0, "gradients", False, "required", "all") \ + .input(1, "features", False, "required", "all") \ + .output(0, "backprops", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + .get_op_info() + + +@op_info_register(relu6_grad_op_info) +def _relu6_grad_tbe(): + """Relu6Grad TBE register""" + return