@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void CalReLUGradKernel(int size, T *dy, T *y, T *dx) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    dx[pos] = y[pos] > static_cast<T>(0) ? dy[pos] : static_cast<T>(0);
+  }
+}
+
+template <typename T>
+void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
+  CalReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, y, dx);
+  return;
+}
+
+template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
+template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
+template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStream_t cuda_stream);
+template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
+template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);
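
Review note: the kernel body is a standard grid-stride loop, so correctness does not depend on the launch geometry that GET_BLOCKS/GET_THREADS (from cuda_common.h) pick. As a minimal sketch of the same dx = (y > 0) ? dy : 0 rule outside the MindSpore tree — a hypothetical standalone harness, assuming only a plain CUDA toolchain, with hand-picked launch parameters standing in for GET_BLOCKS/GET_THREADS:

// relu_grad_sketch.cu — hypothetical harness, not part of this PR.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void ReluGradSketch(int size, const float *dy, const float *y, float *dx) {
  // Grid-stride loop: each thread covers pos, pos + blockDim.x * gridDim.x, ...
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    dx[pos] = y[pos] > 0.0f ? dy[pos] : 0.0f;  // pass the gradient only where the forward output was positive
  }
}

int main() {
  const int size = 4;
  const float h_y[size] = {0.0f, 2.0f, 0.0f, 3.0f};   // forward ReLU outputs
  const float h_dy[size] = {1.0f, 1.0f, 1.0f, 1.0f};  // upstream gradients
  float h_dx[size] = {0.0f};
  float *d_y = nullptr, *d_dy = nullptr, *d_dx = nullptr;
  cudaMalloc(&d_y, size * sizeof(float));
  cudaMalloc(&d_dy, size * sizeof(float));
  cudaMalloc(&d_dx, size * sizeof(float));
  cudaMemcpy(d_y, h_y, size * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_dy, h_dy, size * sizeof(float), cudaMemcpyHostToDevice);
  ReluGradSketch<<<1, 128>>>(size, d_dy, d_y, d_dx);  // one block is plenty for 4 elements
  cudaMemcpy(h_dx, d_dx, size * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < size; ++i) printf("%g ", h_dx[i]);  // prints: 0 1 0 1
  printf("\n");
  cudaFree(d_y);
  cudaFree(d_dy);
  cudaFree(d_dx);
  return 0;
}
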
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
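
Review note: the header declares only the host-side launcher, so ordinary C++ translation units (such as the kernel header below) can include the .cuh and call CalReLUGrad without being compiled by nvcc; the __global__ kernel stays private to the .cu file. The parameter is named input_size here but size in the definition — harmless in C++, though aligning the two names would read better.
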
@@ -33,6 +33,7 @@ void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream)
 template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
+template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
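
Review note: CalReLU was already a template, so forward int8 support only needs this one extra explicit instantiation; the kernel body itself is untouched.
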
@@ -22,6 +22,8 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       ActivationGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       ActivationGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                      ActivationGpuFwdKernel, int8_t)
 MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       ActivationGpuFwdKernel, int32_t)
@@ -26,6 +26,12 @@ MS_REG_GPU_KERNEL_ONE(
   ReluGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   ActivationGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  ActivationGradGpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+  ActivationGradGpuKernel, int8_t)
 MS_REG_GPU_KERNEL_ONE(
   ReLU6Grad,
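
Review note: each MS_REG_GPU_KERNEL_ONE entry binds one dtype signature — here two inputs and one output of the same integer type — to the matching template instantiation of ActivationGradGpuKernel, which is how the framework selects a kernel from the inferred tensor types. The tests below also exercise ReluGrad with int64, so an analogous Int64 registration is presumably part of this change even though it is not visible in this excerpt.
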
@@ -23,6 +23,7 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"

 namespace mindspore {
 namespace kernel {
@@ -36,7 +37,7 @@ class ActivationGradGpuKernel : public GpuKernel {
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *) override {
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     if (is_null_input_) {
       return true;
     }
@@ -51,13 +52,18 @@ class ActivationGradGpuKernel : public GpuKernel {
     }
     T *dx = GetDeviceAddress<T>(outputs, 0);

-    const float alpha = 1;
-    const float beta = 0;
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
-                              data_descriptor_, y, &beta, data_descriptor_, dx),
-      "cudnnActivationBackward failed");
+    if (mode_ == CUDNN_ACTIVATION_RELU) {
+      const int size = input_size_ / sizeof(T);
+      CalReLUGrad(size, dy, y, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      const float alpha = 1;
+      const float beta = 0;
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
+                                data_descriptor_, y, &beta, data_descriptor_, dx),
+        "cudnnActivationBackward failed");
+    }
     return true;
   }
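
Review note: with this change, every ReluGrad launch whose mode is CUDNN_ACTIVATION_RELU — float and half included, not only the new integer types — goes through the hand-written CalReLUGrad kernel, while the other activation modes still call cudnnActivationBackward, presumably because cuDNN's activation backward does not support the integer types being added here. input_size_ holds a byte count, hence the division by sizeof(T) to recover the element count.
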
@@ -31,17 +31,14 @@ class NetReluGrad(nn.Cell):
         return self.rekuGrad(dy, x)


-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_relu_grad():
+def relu_grad_base(dtype):
     x = Tensor(np.array([[[[-1, 1, 1],
                            [1, -1, 1],
-                           [1, 1, -1]]]]).astype(np.float32))
+                           [1, 1, -1]]]]).astype(dtype))
     dy = Tensor(np.array([[[[1, 0, 1],
                             [0, 1, 0],
-                            [1, 1, 1]]]]).astype(np.float32))
-    expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(np.float32)
+                            [1, 1, 1]]]]).astype(dtype))
+    expect = np.array([[[[0, 0, 1], [0, 0, 0], [1, 1, 0]]]]).astype(dtype)
     error = np.ones(shape=[3, 3]) * 1.0e-6
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@@ -49,3 +46,39 @@ def test_relu_grad():
     output = relu_grad(x, dy)
     diff = output.asnumpy() - expect
     assert np.all(diff < error)
+    assert output.asnumpy().dtype == dtype
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_float16():
+    relu_grad_base(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_float32():
+    relu_grad_base(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_int8():
+    relu_grad_base(np.int8)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_int32():
+    relu_grad_base(np.int32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_int64():
+    relu_grad_base(np.int64)
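
Review note: for the data in relu_grad_base, ReluGrad computes dx = dy where the mask is positive and 0 elsewhere. The test passes the forward input x in place of the forward output y, which is equivalent for ReLU since x > 0 exactly when relu(x) > 0; masking dy with the positive entries of x gives [[0, 0, 1], [0, 0, 0], [1, 1, 0]], matching expect.
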
@@ -65,6 +65,28 @@ def test_relu_float32():
     assert (output.asnumpy() == expect).all()


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_int8():
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [1, -1, 1],
+                           [10, 1, -1]]]]).astype(np.int8))
+    expect = np.array([[[[0, 1, 10],
+                         [1, 0, 1],
+                         [10, 1, 0]]]]).astype(np.int8)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard