@@ -33,6 +33,11 @@ MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
 MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       ActivationGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ActivationGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      ActivationGpuFwdKernel, half)
 MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       ActivationGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@@ -84,6 +84,10 @@ class ActivationGpuFwdKernel : public GpuKernel {
     }
     std::vector<size_t> shape;
     double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0;
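+    // cuDNN treats the descriptor's coef as the ELU alpha:
+    // f(x) = x for x > 0, coef * (exp(x) - 1) otherwise.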
+    if (mode_ == CUDNN_ACTIVATION_ELU) {
+      float alpha = GetAttr<float>(kernel_node, "alpha");
+      coef = static_cast<double>(alpha);
+    }
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef),
                                 "cudnnSetActivationDescriptor failed");
@@ -137,7 +141,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
   std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
                                                              {"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
                                                              {"Tanh", CUDNN_ACTIVATION_TANH},
-                                                             {"ELU", CUDNN_ACTIVATION_ELU},
+                                                             {"Elu", CUDNN_ACTIVATION_ELU},
                                                              {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
   cudnnHandle_t cudnn_handle_;
@@ -45,6 +45,15 @@ MS_REG_GPU_KERNEL_ONE(
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   ActivationGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  EluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ActivationGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  EluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ActivationGradGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(
   SigmoidGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@@ -91,6 +91,7 @@ class ActivationGradGpuKernel : public GpuKernel {
     }
     std::vector<size_t> shape;
     double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0;
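+    // Unlike the forward kernel, the backward pass does not read an "alpha" attribute; it assumes alpha = 1.0.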
+    if (mode_ == CUDNN_ACTIVATION_ELU) coef = 1.0;
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef),
                                 "SetActivationDescriptor failed");
@@ -143,7 +144,7 @@ class ActivationGradGpuKernel : public GpuKernel {
   std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU},
                                                              {"ReLU6Grad", CUDNN_ACTIVATION_CLIPPED_RELU},
                                                              {"TanhGrad", CUDNN_ACTIVATION_TANH},
-                                                             {"ELUGrad", CUDNN_ACTIVATION_ELU},
+                                                             {"EluGrad", CUDNN_ACTIVATION_ELU},
                                                              {"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}};
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class NetEluGrad(nn.Cell):
+    def __init__(self):
+        super(NetEluGrad, self).__init__()
+        self.eluGrad = G.EluGrad()
+
+    def construct(self, x, dy):
+        return self.eluGrad(dy, x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_grad_fp16():
+    x = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float16))
+    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float16))
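+    # expect: dy passes through where x > 0; elsewhere it is dy * (x + 1), i.e. ELU backward with alpha = 1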
+    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float16)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu_grad = NetEluGrad()
+    output = elu_grad(x, dy)
+    diff = np.abs(output.asnumpy() - expect)
+    assert np.all(diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_grad_fp32():
+    x = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float32))
+    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float32))
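+    # same derivation as the fp16 case above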
+    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float32)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu_grad = NetEluGrad()
+    output = elu_grad(x, dy)
+    diff = np.abs(output.asnumpy() - expect)
+    assert np.all(diff < error)
@@ -0,0 +1,71 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class NetElu(nn.Cell):
+    def __init__(self):
+        super(NetElu, self).__init__()
+        self.elu = P.Elu()
+
+    def construct(self, x):
+        return self.elu(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_fp16():
+    x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]).astype(np.float16))
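+    # expect: elu(x) = x for x > 0, exp(x) - 1 otherwise; the negative entries are given to
+    # three decimals, so the tolerance below is 1e-3 rather than 1e-6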
+    expect = np.array([[-0.632, 4.0, -0.999], [2.0, -0.993, 9.0]]).astype(np.float16)
+    error = np.ones(shape=[2, 3]) * 1.0e-3
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = np.abs(output.asnumpy() - expect)
+    assert np.all(diff < error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = np.abs(output.asnumpy() - expect)
+    assert np.all(diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_fp32():
+    x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]).astype(np.float32))
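+    # same expectation as the fp16 case; tolerance is 1e-3 because the expected values are rounded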
+    expect = np.array([[-0.632, 4.0, -0.999], [2.0, -0.993, 9.0]]).astype(np.float32)
+    error = np.ones(shape=[2, 3]) * 1.0e-3
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = np.abs(output.asnumpy() - expect)
+    assert np.all(diff < error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = np.abs(output.asnumpy() - expect)
+    assert np.all(diff < error)