diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
index 157110d25c..2782775326 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
@@ -33,6 +33,11 @@ MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
 MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       ActivationGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ActivationGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      ActivationGpuFwdKernel, half)
+
 MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       ActivationGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       ActivationGpuFwdKernel, half)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
index 0853973872..2529cfc8ac 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
@@ -84,6 +84,10 @@ class ActivationGpuFwdKernel : public GpuKernel {
     }
     std::vector<size_t> shape;
     double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0;
+    if (mode_ == CUDNN_ACTIVATION_ELU) {
+      float alpha = GetAttr<float>(kernel_node, "alpha");
+      coef = static_cast<double>(alpha);
+    }
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef),
                                 "cudnnSetActivationDescriptor failed");
@@ -137,7 +141,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
   std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
                                                              {"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
                                                              {"Tanh", CUDNN_ACTIVATION_TANH},
-                                                             {"ELU", CUDNN_ACTIVATION_ELU},
+                                                             {"Elu", CUDNN_ACTIVATION_ELU},
                                                              {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
index 8e6f568031..c3ab7c1cfd 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
@@ -45,6 +45,15 @@ MS_REG_GPU_KERNEL_ONE(
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   ActivationGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  EluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ActivationGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  EluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ActivationGradGpuKernel, half)
+
 MS_REG_GPU_KERNEL_ONE(
   SigmoidGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   ActivationGradGpuKernel, float)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
index 86709bd76a..e6ee2d56a8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
@@ -91,6 +91,7 @@ class ActivationGradGpuKernel : public GpuKernel {
     }
     std::vector<size_t> shape;
     double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0;
+    if (mode_ == CUDNN_ACTIVATION_ELU) coef = 1.0;
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef),
                                 "SetActivationDescriptor failed");
@@ -143,7 +144,7 @@ class ActivationGradGpuKernel : public GpuKernel {
   std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU},
                                                              {"ReLU6Grad", CUDNN_ACTIVATION_CLIPPED_RELU},
                                                              {"TanhGrad", CUDNN_ACTIVATION_TANH},
-                                                             {"ELUGrad", CUDNN_ACTIVATION_ELU},
+                                                             {"EluGrad", CUDNN_ACTIVATION_ELU},
                                                              {"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}};
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
diff --git a/tests/st/ops/gpu/test_elu_grad_op.py b/tests/st/ops/gpu/test_elu_grad_op.py
new file mode 100644
index 0000000000..8e21ed1e4f
--- /dev/null
+++ b/tests/st/ops/gpu/test_elu_grad_op.py
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class NetEluGrad(nn.Cell):
+    def __init__(self):
+        super(NetEluGrad, self).__init__()
+        self.eluGrad = G.EluGrad()
+
+    def construct(self, x, dy):
+        return self.eluGrad(dy, x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_grad_fp16():
+    x = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float16))
+    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float16))
+    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float16)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu_grad = NetEluGrad()
+    output = elu_grad(x, dy)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_grad_fp32():
+    x = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float32))
+    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float32))
+    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float32)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu_grad = NetEluGrad()
+    output = elu_grad(x, dy)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
diff --git a/tests/st/ops/gpu/test_elu_op.py b/tests/st/ops/gpu/test_elu_op.py
new file mode 100644
index 0000000000..c0fbb792b8
--- /dev/null
+++ b/tests/st/ops/gpu/test_elu_op.py
@@ -0,0 +1,71 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+class NetElu(nn.Cell):
+    def __init__(self):
+        super(NetElu, self).__init__()
+        self.elu = P.Elu()
+
+    def construct(self, x):
+        return self.elu(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_fp16():
+    x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]).astype(np.float16))
+    expect = np.array([[-0.632, 4.0, -0.999], [2.0, -0.993, 9.0]]).astype(np.float16)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_fp32():
+    x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]).astype(np.float32))
+    expect = np.array([[-0.632, 4.0, -0.999], [2.0, -0.993, 9.0]]).astype(np.float32)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)