From: @yuan_shen_zhou
Reviewed-by: @wilfchen, @liangchenghui, @linqingke
Signed-off-by: @liangchenghui
Tag: tags/v1.1.0
@@ -1,37 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void CalReLUGradKernel(int size, T *dy, T *y, T *dx) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    dx[pos] = y[pos] > static_cast<T>(0) ? dy[pos] : static_cast<T>(0);
  }
}
template <typename T>
void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
  CalReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, y, dx);
  return;
}
template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);
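Aside: the deleted launcher above is the standard CUDA grid-stride pattern, and the element-wise rule (pass dy through where the forward output was positive) now runs via cuDNN in ActivationGradGpuKernel further down. For reference, a self-contained sketch of the same computation; kThreads and NumBlocks are local stand-ins for GET_THREADS and GET_BLOCKS from cuda_common.h, and their exact values here are assumptions:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Local stand-ins for GET_THREADS / GET_BLOCKS from cuda_common.h (values assumed).
constexpr int kThreads = 256;
inline int NumBlocks(int n) { return (n + kThreads - 1) / kThreads; }

// Same rule as the deleted CalReLUGradKernel: keep dy where the forward output y was positive.
__global__ void ReluGradSketch(int size, const float *dy, const float *y, float *dx) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    dx[pos] = y[pos] > 0.0f ? dy[pos] : 0.0f;
  }
}

int main() {
  const int n = 8;
  float hy[n] = {-2, -1, 0, 1, 2, 3, -4, 5};  // forward outputs (only the sign matters)
  float hdy[n] = {1, 1, 1, 1, 1, 1, 1, 1};    // incoming gradient
  float *dy_d, *y_d, *dx_d;
  cudaMalloc(&y_d, n * sizeof(float));
  cudaMalloc(&dy_d, n * sizeof(float));
  cudaMalloc(&dx_d, n * sizeof(float));
  cudaMemcpy(y_d, hy, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dy_d, hdy, n * sizeof(float), cudaMemcpyHostToDevice);
  ReluGradSketch<<<NumBlocks(n), kThreads>>>(n, dy_d, y_d, dx_d);
  float hdx[n];
  cudaMemcpy(hdx, dx_d, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%g ", hdx[i]);  // prints: 0 0 0 1 1 1 0 1
  printf("\n");
  cudaFree(y_d);
  cudaFree(dy_d);
  cudaFree(dx_d);
  return 0;
}
```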
@@ -1,23 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
@@ -31,11 +31,14 @@ void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream)
  return;
}
template void CalReLU(int size, double *input_addr, double *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int16_t *input_addr, int16_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, uint8_t *input_addr, uint8_t *output_addr, cudaStream_t cuda_stream);
template <typename T>
__global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) {
@@ -69,14 +72,26 @@ void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cuda
  ReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, dy, mask, dx);
}
template void ReluV2(const size_t num, const double *x, double *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const float *x, float *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const half *x, half *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int8_t *x, int8_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int16_t *x, int16_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const uint8_t *x, uint8_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const double *dy, const uint32_t *mask, double *dx,
                         cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const float *dy, const uint32_t *mask, float *dx, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const half *dy, const uint32_t *mask, half *dx, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int8_t *dy, const uint32_t *mask, int8_t *dx,
                         cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int16_t *dy, const uint32_t *mask, int16_t *dx,
                         cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *mask, int32_t *dx,
                         cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx,
                         cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
                         cudaStream_t cuda_stream);
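The ReluV2/ReluGradV2 pair above threads a packed uint32_t mask from the forward pass to the backward pass, so the backward kernel never re-reads x or y. A simplified sketch of that pairing, assuming a plain one-bit-per-element layout with 32 elements per mask word; the actual packing in relu_impl.cu may differ:

```cpp
#include <cstdint>
#include <cuda_runtime.h>

// Forward: write y = max(x, 0) and record "x was positive" as one bit per element.
// mask must hold (num + 31) / 32 words and be zeroed (e.g. cudaMemset) before launch.
__global__ void ReluV2Sketch(size_t num, const float *x, float *y, uint32_t *mask) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
    const bool pos = x[i] > 0.0f;
    y[i] = pos ? x[i] : 0.0f;
    if (pos) atomicOr(&mask[i / 32], 1u << (i % 32));
  }
}

// Backward: the recorded bit alone decides where the gradient flows.
__global__ void ReluGradV2Sketch(size_t num, const float *dy, const uint32_t *mask, float *dx) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
    dx[i] = ((mask[i / 32] >> (i % 32)) & 1u) ? dy[i] : 0.0f;
  }
}
```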
@@ -46,7 +46,8 @@ static constexpr float kSignedMinFloat = -3.402823466e+38F;
static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {
    {"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, {"kNumberTypeFloat16", CUDNN_DATA_HALF},
    {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32},
    {"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8}};
    {"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8},
    {"kNumberTypeUInt8", CUDNN_DATA_UINT8}};
// Used by mixprecision, cuda dtype select
static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},
                                                              {"kNumberTypeFloat16", CUDA_R_16F}};
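The added kNumberTypeUInt8 entry is what the new uint8 registrations below rely on when a cuDNN path is taken. A hypothetical lookup helper in the style of this table; the map contents mirror the hunk above, but the function name and error handling are illustrative rather than MindSpore API:

```cpp
#include <cudnn.h>
#include <map>
#include <stdexcept>
#include <string>

cudnnDataType_t ResolveCudnnDtype(const std::string &ms_type) {
  static const std::map<std::string, cudnnDataType_t> kMap = {
      {"kNumberTypeFloat32", CUDNN_DATA_FLOAT},  {"kNumberTypeFloat16", CUDNN_DATA_HALF},
      {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32},
      {"kNumberTypeBool", CUDNN_DATA_INT8},      {"kNumberTypeInt8", CUDNN_DATA_INT8},
      {"kNumberTypeUInt8", CUDNN_DATA_UINT8}};  // the entry this hunk adds
  const auto it = kMap.find(ms_type);
  if (it == kMap.end()) {
    throw std::invalid_argument("no cuDNN dtype for " + ms_type);
  }
  return it->second;
}
```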
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -18,15 +18,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      ActivationGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
                      ActivationGpuFwdKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      ActivationGpuFwdKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
#include <vector>
#include <map>
@@ -44,17 +44,12 @@ class ActivationGpuFwdKernel : public GpuKernel {
    T *input = GetDeviceAddress<T>(inputs, 0);
    T *output = GetDeviceAddress<T>(outputs, 0);
    if (mode_ == CUDNN_ACTIVATION_RELU) {
      const int size = input_size_ / sizeof(T);
      CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      const float alpha = 1;
      const float beta = 0;
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                  cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_,
                                                         input, &beta, data_descriptor_, output),
                                  "cudnnActivationForward failed");
    }
    const float alpha = 1;
    const float beta = 0;
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
                                                       &beta, data_descriptor_, output),
                                "cudnnActivationForward failed");
    return true;
  }
@@ -125,7 +120,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
  void ResetResource() noexcept override {
    cudnn_handle_ = nullptr;
    activation_desc_ = nullptr;
    mode_ = CUDNN_ACTIVATION_RELU;
    mode_ = CUDNN_ACTIVATION_SIGMOID;
    data_descriptor_ = nullptr;
    is_null_input_ = false;
    input_size_list_.clear();
@@ -154,11 +149,11 @@ class ActivationGpuFwdKernel : public GpuKernel {
    }
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
    workspace_size_list_.push_back(workspace_size_);
  }

 private:
  std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
                                                             {"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
  std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
                                                             {"Tanh", CUDNN_ACTIVATION_TANH},
                                                             {"Elu", CUDNN_ACTIVATION_ELU},
                                                             {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
@@ -179,4 +174,4 @@ class ActivationGpuFwdKernel : public GpuKernel {
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
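With the ReLU special case removed, Launch always reduces to the plain cuDNN activation call shown above. A minimal standalone sketch of that call pattern; the NCHW shape and dtype are illustrative, and CHECK_CUDNN is a local stand-in for the project's CHECK_CUDNN_RET_WITH_EXCEPT:

```cpp
#include <cstdlib>
#include <cudnn.h>

#define CHECK_CUDNN(expr)                             \
  do {                                                \
    if ((expr) != CUDNN_STATUS_SUCCESS) std::abort(); \
  } while (0)

// y = sigmoid(x) for an n*c*h*w float tensor already resident on the GPU.
void SigmoidForward(cudnnHandle_t handle, const float *x, float *y, int n, int c, int h, int w) {
  cudnnActivationDescriptor_t act_desc;
  cudnnTensorDescriptor_t data_desc;
  CHECK_CUDNN(cudnnCreateActivationDescriptor(&act_desc));
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&data_desc));
  CHECK_CUDNN(cudnnSetActivationDescriptor(act_desc, CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0.0));
  CHECK_CUDNN(cudnnSetTensor4dDescriptor(data_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
  const float alpha = 1;
  const float beta = 0;  // same alpha/beta scaling as the kernel above
  CHECK_CUDNN(cudnnActivationForward(handle, act_desc, &alpha, data_desc, x, &beta, data_desc, y));
  CHECK_CUDNN(cudnnDestroyTensorDescriptor(data_desc));
  CHECK_CUDNN(cudnnDestroyActivationDescriptor(act_desc));
}
```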
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -18,6 +18,10 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  ReluGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  ActivationGradGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
  ReluGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@@ -26,12 +30,21 @@ MS_REG_GPU_KERNEL_ONE(
  ReluGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  ActivationGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  ActivationGradGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  ActivationGradGpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
  ActivationGradGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(
  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
  ActivationGradGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
  ActivationGradGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(
  ReLU6Grad,
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
#include <vector>
#include <map>
@@ -23,7 +23,6 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
namespace mindspore {
namespace kernel {
@@ -52,18 +51,13 @@ class ActivationGradGpuKernel : public GpuKernel {
    }
    T *dx = GetDeviceAddress<T>(outputs, 0);
    if (mode_ == CUDNN_ACTIVATION_RELU) {
      const int size = input_size_ / sizeof(T);
      CalReLUGrad(size, dy, y, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      const float alpha = 1;
      const float beta = 0;
      CHECK_CUDNN_RET_WITH_EXCEPT(
        kernel_node_,
        cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
                                data_descriptor_, y, &beta, data_descriptor_, dx),
        "cudnnActivationBackward failed");
    }
    const float alpha = 1;
    const float beta = 0;
    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
                              data_descriptor_, y, &beta, data_descriptor_, dx),
      "cudnnActivationBackward failed");
    return true;
  }
@@ -179,4 +173,4 @@ class ActivationGradGpuKernel : public GpuKernel {
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
@@ -0,0 +1,38 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
                      ReLUGpuFwdKernel, double)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      ReLUGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      ReLUGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      ReLUGpuFwdKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      ReLUGpuFwdKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      ReLUGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReLUGpuFwdKernel,
                      int8_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                      ReLUGpuFwdKernel, uint8_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,98 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
#include <vector>
#include <map>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ReLUGpuFwdKernel : public GpuKernel {
 public:
  ReLUGpuFwdKernel() { ResetResource(); }
  ~ReLUGpuFwdKernel() override {}
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    T *input = GetDeviceAddress<T>(inputs, 0);
    T *output = GetDeviceAddress<T>(outputs, 0);
    const int size = input_size_ / sizeof(T);
    CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 1.";
      return false;
    }
    auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
    is_null_input_ = CHECK_NULL_INPUT(input_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "ReLUGpuFwdKernel input is null.";
    }
    size_t size = 1;
    for (size_t i = 0; i < input_shape.size(); i++) {
      size *= input_shape[i];
    }
    input_size_ = size * sizeof(T);
    InitSizeLists();
    return true;
  }
  void ResetResource() noexcept override {
    is_null_input_ = false;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
    input_size_ = 0;
    workspace_size_ = 0;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(input_size_);
    workspace_size_list_.push_back(workspace_size_);
  }

 private:
  bool is_null_input_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  size_t input_size_;
  size_t workspace_size_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
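For a float tensor, Launch above amounts to converting the byte size to an element count and handing off to the cuda_impl launcher. A sketch of the equivalent host-side call; the include path and the CalReLU signature come from the hunks above, while the wrapper itself is hypothetical:

```cpp
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"

// device_in and device_out point to GPU memory of input_size_bytes bytes each.
void RunRelu(float *device_in, float *device_out, size_t input_size_bytes, cudaStream_t stream) {
  const int size = input_size_bytes / sizeof(float);  // same element-count computation as Launch()
  CalReLU(size, device_in, device_out, stream);
  // The caller (or the framework's stream management) synchronizes before reading device_out.
}
```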
@@ -18,6 +18,10 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64),
  ReluGradV2GpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
@@ -26,6 +30,13 @@ MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
  ReluGradV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8),
  ReluGradV2GpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16),
  ReluGradV2GpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
@@ -34,5 +45,10 @@ MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
  ReluGradV2GpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8),
  ReluGradV2GpuKernel, uint8_t)
}  // namespace kernel
}  // namespace mindspore
@@ -18,6 +18,10 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
@@ -26,12 +30,20 @@ MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, uint8_t)
}  // namespace kernel
}  // namespace mindspore

@@ -79,4 +79,4 @@ class ReluV2GpuKernel : public GpuKernel {
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_MASK_GPU_KERNEL_H_
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
| @@ -1,84 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| class NetReluGrad(nn.Cell): | |||
| def __init__(self): | |||
| super(NetReluGrad, self).__init__() | |||
| self.rekuGrad = G.ReluGrad() | |||
| def construct(self, x, dy): | |||
| return self.rekuGrad(dy, x) | |||
| def relu_grad_base(dtype): | |||
| x = Tensor(np.array([[[[-1, 1, 1], | |||
| [1, -1, 1], | |||
| [1, 1, -1]]]]).astype(dtype)) | |||
| dy = Tensor(np.array([[[[1, 0, 1], | |||
| [0, 1, 0], | |||
| [1, 1, 1]]]]).astype(dtype)) | |||
| expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(np.dtype) | |||
| error = np.ones(shape=[3, 3]) * 1.0e-6 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| relu_grad = NetReluGrad() | |||
| output = relu_grad(x, dy) | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| assert output.asnumpy().dtype == dtype | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_relu_grad_float16(): | |||
| relu_grad_base(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_relu_grad_float32(): | |||
| relu_grad_base(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_relu_grad_int8(): | |||
| relu_grad_base(np.int8) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_relu_grad_int32(): | |||
| relu_grad_base(np.int32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_relu_grad_int64(): | |||
| relu_grad_base(np.int64) | |||