@@ -33,6 +33,11 @@ MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
 MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       ActivationGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ActivationGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(Elu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      ActivationGpuFwdKernel, half)
 MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       ActivationGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@@ -84,6 +84,10 @@ class ActivationGpuFwdKernel : public GpuKernel {
     }
     std::vector<size_t> shape;
     double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0;
+    if (mode_ == CUDNN_ACTIVATION_ELU) {
+      float alpha = GetAttr<float>(kernel_node, "alpha");
+      coef = static_cast<double>(alpha);
+    }
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef),
                                 "cudnnSetActivationDescriptor failed");
@@ -137,7 +141,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
   std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
                                                              {"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
                                                              {"Tanh", CUDNN_ACTIVATION_TANH},
-                                                             {"ELU", CUDNN_ACTIVATION_ELU},
+                                                             {"Elu", CUDNN_ACTIVATION_ELU},
                                                              {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
   cudnnHandle_t cudnn_handle_;
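The forward hookup above maps the op name "Elu" to CUDNN_ACTIVATION_ELU and forwards the op's "alpha" attribute as the coef argument of cudnnSetActivationDescriptor. For reference, a minimal NumPy sketch of the function cuDNN then evaluates (not part of the patch; elu_reference is a hypothetical name):

```python
import numpy as np

def elu_reference(x, alpha=1.0):
    # CUDNN_ACTIVATION_ELU: identity for positive inputs,
    # alpha * (exp(x) - 1) for non-positive ones.
    return np.where(x > 0, x, alpha * (np.exp(x) - 1.0))

# elu(-1) ~ -0.632, elu(-5) ~ -0.993, elu(-8) ~ -0.999, which is
# where the expect arrays in the forward test below come from.
print(elu_reference(np.array([-1.0, 4.0, -8.0, 2.0, -5.0, 9.0])))
```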
@@ -45,6 +45,15 @@ MS_REG_GPU_KERNEL_ONE(
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   ActivationGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  EluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ActivationGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  EluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ActivationGradGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(
   SigmoidGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@@ -91,6 +91,7 @@ class ActivationGradGpuKernel : public GpuKernel {
     }
     std::vector<size_t> shape;
     double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0;
+    if (mode_ == CUDNN_ACTIVATION_ELU) coef = 1.0;
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef),
                                 "SetActivationDescriptor failed");
@@ -143,7 +144,7 @@ class ActivationGradGpuKernel : public GpuKernel {
   std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU},
                                                              {"ReLU6Grad", CUDNN_ACTIVATION_CLIPPED_RELU},
                                                              {"TanhGrad", CUDNN_ACTIVATION_TANH},
-                                                             {"ELUGrad", CUDNN_ACTIVATION_ELU},
+                                                             {"EluGrad", CUDNN_ACTIVATION_ELU},
                                                              {"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}};
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
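The backward hookup mirrors the forward one, but note that the grad kernel pins coef (ELU's alpha) to 1.0 rather than reading the attribute, so the tests below exercise the default alpha only. The backward pass works from the forward output y rather than the input x; a minimal NumPy sketch of that gradient (hypothetical elu_grad_reference, assuming y = elu(x)):

```python
import numpy as np

def elu_grad_reference(dy, y, alpha=1.0):
    # d/dx elu(x) is 1 for x > 0 and alpha * exp(x) for x <= 0;
    # with y = elu(x), alpha * exp(x) equals y + alpha, so the
    # gradient can be computed from the forward output alone.
    return np.where(y > 0, dy, dy * (y + alpha))
```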
@@ -0,0 +1,62 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class NetEluGrad(nn.Cell):
+    def __init__(self):
+        super(NetEluGrad, self).__init__()
+        self.eluGrad = G.EluGrad()
+
+    def construct(self, x, dy):
+        return self.eluGrad(dy, x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_grad_fp16():
+    x = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float16))
+    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float16))
+    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float16)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu_grad = NetEluGrad()
+    output = elu_grad(x, dy)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_grad_fp32():
+    x = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float32))
+    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float32))
+    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float32)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu_grad = NetEluGrad()
+    output = elu_grad(x, dy)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
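As a quick sanity check (not part of the patch), the gradient sketch above reproduces the expect array used in both grad tests, where the first EluGrad input is the forward output y:

```python
import numpy as np

y = np.array([[0.5, 2, 5.5], [4.5, -2, 0]], dtype=np.float32)
dy = np.array([[2, 1, 1.5], [-0.5, -1, -3]], dtype=np.float32)
# alpha = 1.0, matching the hardcoded coef in ActivationGradGpuKernel;
# e.g. for y = -2, dy = -1: dy * (y + alpha) = (-1) * (-2 + 1) = 1.
print(np.where(y > 0, dy, dy * (y + 1.0)))
# -> [[ 2.   1.   1.5]
#     [-0.5  1.  -3. ]]
```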
@@ -0,0 +1,71 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class NetElu(nn.Cell):
+    def __init__(self):
+        super(NetElu, self).__init__()
+        self.elu = P.Elu()
+
+    def construct(self, x):
+        return self.elu(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_fp16():
+    x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]).astype(np.float16))
+    expect = np.array([[-0.632, 4.0, -0.999], [2.0, -0.993, 9.0]]).astype(np.float16)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_elu_fp32():
+    x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]).astype(np.float32))
+    expect = np.array([[-0.632, 4.0, -0.999], [2.0, -0.993, 9.0]]).astype(np.float32)
+    error = np.ones(shape=[2, 3]) * 1.0e-6
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elu = NetElu()
+    output = elu(x)
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
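Once the kernels are registered, the op is reachable from the standard Python API. A minimal usage sketch under the same GPU context the tests use:

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
elu = P.Elu()  # default alpha, constructed the same way as in the test above
x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]).astype(np.float32))
print(elu(x))  # roughly [[-0.632, 4.0, -0.999], [2.0, -0.993, 9.0]]
```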