Merge pull request !2150 from chenweifeng/tanh-fp16
@@ -1,46 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "kernel/gpu/cuda_impl/tanh_impl.cuh"
-#include <cuda_runtime.h>
-template<typename T>
-__global__ void TanhKernel(const size_t size, const T* x_addr, T* y_addr) {
-  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    y_addr[pos] = tanh(x_addr[pos]);
-  }
-}
-template<typename T>
-__global__ void TanhGradKernel(const size_t size, const T* y_addr, const T* dy_addr, T* dx_addr) {
-  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    dx_addr[pos] = dy_addr[pos] * (1 - y_addr[pos] * y_addr[pos]);
-  }
-}
-template<typename T>
-void Tanh(const size_t size, const T* x_addr, T* y_addr, cudaStream_t cuda_stream) {
-  TanhKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, x_addr, y_addr);
-}
-template<typename T>
-void TanhGrad(const size_t size, const T* y_addr, const T* dy_addr, T* dx_addr, cudaStream_t cuda_stream) {
-  TanhGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, y_addr, dy_addr, dx_addr);
-}
-template void Tanh(const size_t size, const float* x_addr, float* y_addr, cudaStream_t cuda_stream);
-template void TanhGrad(const size_t size, const float* y_addr, const float* dy_addr,
-                       float* dx_addr, cudaStream_t cuda_stream);
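The file removed above held the hand-written CUDA kernels for Tanh: a grid-stride forward kernel computing y = tanh(x) and a backward kernel computing dx = dy * (1 - y * y), instantiated for float only. This PR deletes them and routes Tanh and TanhGrad through the cuDNN-backed activation kernels changed below, which is also what brings float16 support. As a reference for the math the deleted kernels implemented, here is a small host-side C++ sketch; the function and variable names are illustrative, not taken from the diff:

#include <cmath>
#include <cstddef>
#include <vector>

// Reference for the deleted kernels: y = tanh(x) and dx = dy * (1 - y * y).
// Handy for checking the cuDNN-backed replacement against the same element-wise math.
void TanhForwardBackwardReference(const std::vector<float> &x, const std::vector<float> &dy,
                                  std::vector<float> *y, std::vector<float> *dx) {
  y->resize(x.size());
  dx->resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*y)[i] = std::tanh(x[i]);
    (*dx)[i] = dy[i] * (1.0f - (*y)[i] * (*y)[i]);
  }
}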
@@ -1,28 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TAN_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TAN_H_
-#include "device/gpu/cuda_common.h"
-template<typename T>
-void Tanh(const size_t size, const T* x_addr, T* y_addr, cudaStream_t cuda_stream);
-template<typename T>
-void TanhGrad(const size_t size, const T* y_addr, const T* dy_addr, T* dx_addr, cudaStream_t cuda_stream);
-#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TAN_H_
@@ -14,13 +14,18 @@
  * limitations under the License.
  */
-#include "kernel/gpu/nn/relu_gpu_kernel.h"
+#include "kernel/gpu/nn/activation_gpu_kernel.h"
 namespace mindspore {
 namespace kernel {
 MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      ReLUGpuFwdKernel, float)
+                      ActivationGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      ReLUGpuFwdKernel, half)
+                      ActivationGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ActivationGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      ActivationGpuFwdKernel, half)
 } // namespace kernel
 } // namespace mindspore
@@ -18,6 +18,8 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_
 #include <vector>
+#include <map>
+#include <string>
 #include "kernel/gpu/gpu_kernel.h"
 #include "kernel/gpu/gpu_kernel_factory.h"
 #include "kernel/gpu/kernel_constants.h"
@@ -25,9 +27,9 @@
 namespace mindspore {
 namespace kernel {
 template <typename T>
-class ReLUGpuFwdKernel : public GpuKernel {
+class ActivationGpuFwdKernel : public GpuKernel {
  public:
-  ReLUGpuFwdKernel()
+  ActivationGpuFwdKernel()
       : cudnn_handle_(nullptr),
         activation_desc_(nullptr),
         mode_(CUDNN_ACTIVATION_RELU),
@@ -37,7 +39,7 @@ class ReLUGpuFwdKernel : public GpuKernel {
         input_size_(0),
         output_size_(0),
         workspace_size_(0) {}
-  ~ReLUGpuFwdKernel() override { DestroyResource(); }
+  ~ActivationGpuFwdKernel() override { DestroyResource(); }
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
@@ -54,33 +56,39 @@ class ReLUGpuFwdKernel : public GpuKernel {
     const float beta = 0;
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
                                                        &beta, data_descriptor_, output),
-                                "ReLUGpuFwdKernel failed");
+                                "cudnnActivationForward failed");
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
+    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+    auto iter = kernel_map.find(node_name);
+    if (iter == kernel_map.end()) {
+      MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support.";
+    }
+    mode_ = iter->second;
     InitResource();
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 1) {
-      MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 1.";
+      MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1.";
       return false;
     }
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     is_null_input_ = CHECK_NULL_INPUT(input_shape);
     if (is_null_input_) {
-      MS_LOG(WARNING) << "ReLUGpuFwdKernel input is null.";
+      MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null.";
       InitSizeLists();
       return true;
     }
-    mode_ = CUDNN_ACTIVATION_RELU;
     std::vector<int> shape;
     ShapeNdTo4d(input_shape, &shape);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0),
-                                "SetActivationDescriptor failed");
+                                "cudnnSetActivationDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
                                                            shape[0], shape[1], shape[2], shape[3]),
-                                "SetTensor4dDescriptor failed");
+                                "cudnnSetTensor4dDescriptor failed");
     InitSizeLists();
     return true;
   }
@@ -110,6 +118,11 @@ class ReLUGpuFwdKernel : public GpuKernel {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
   }
+  std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
+                                                             {"Tanh", CUDNN_ACTIVATION_TANH},
+                                                             {"ELU", CUDNN_ACTIVATION_ELU},
+                                                             {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
   cudnnActivationMode_t mode_;
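With this change, ActivationGpuFwdKernel (formerly ReLUGpuFwdKernel) picks the cuDNN activation mode from the node name — ReLU, Tanh, ELU, or Sigmoid — and leaves the actual computation to cuDNN; cudnn_data_type_ is derived from the input dtype, so the same path serves the new float16 registrations. Below is a minimal standalone sketch of the call sequence the kernel configures for Tanh, assuming a valid cudnnHandle_t and fp32 NCHW data already on the device, with status checks omitted for brevity (names are illustrative, not the kernel's own):

#include <cudnn.h>

// Forward Tanh through cuDNN, mirroring what ActivationGpuFwdKernel sets up:
// an activation descriptor with CUDNN_ACTIVATION_TANH and a 4-D NCHW tensor descriptor.
void TanhForwardCudnnSketch(cudnnHandle_t handle, const float *x, float *y, int n, int c, int h, int w) {
  cudnnActivationDescriptor_t act_desc;
  cudnnTensorDescriptor_t data_desc;
  cudnnCreateActivationDescriptor(&act_desc);
  cudnnCreateTensorDescriptor(&data_desc);
  cudnnSetActivationDescriptor(act_desc, CUDNN_ACTIVATION_TANH, CUDNN_NOT_PROPAGATE_NAN, 0.0);
  cudnnSetTensor4dDescriptor(data_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w);
  const float alpha = 1.0f, beta = 0.0f;
  cudnnActivationForward(handle, act_desc, &alpha, data_desc, x, &beta, data_desc, y);
  cudnnDestroyTensorDescriptor(data_desc);
  cudnnDestroyActivationDescriptor(act_desc);
}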
@@ -14,17 +14,26 @@
  * limitations under the License.
  */
-#include "kernel/gpu/nn/relu_grad_kernel.h"
+#include "kernel/gpu/nn/activation_grad_kernel.h"
 namespace mindspore {
 namespace kernel {
 MS_REG_GPU_KERNEL_ONE(
   ReluGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  ReluGradGpuKernel, float)
+  ActivationGradGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   ReluGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  ReluGradGpuKernel, half)
+  ActivationGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  TanhGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ActivationGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  TanhGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ActivationGradGpuKernel, half)
 } // namespace kernel
 } // namespace mindspore
@@ -18,6 +18,8 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_
 #include <vector>
+#include <map>
+#include <string>
 #include "kernel/gpu/gpu_kernel.h"
 #include "kernel/gpu/gpu_kernel_factory.h"
 #include "kernel/gpu/kernel_constants.h"
@@ -25,9 +27,9 @@
 namespace mindspore {
 namespace kernel {
 template <typename T>
-class ReluGradGpuKernel : public GpuKernel {
+class ActivationGradGpuKernel : public GpuKernel {
  public:
-  ReluGradGpuKernel()
+  ActivationGradGpuKernel()
       : cudnn_handle_(nullptr),
         activation_desc_(nullptr),
         mode_(CUDNN_ACTIVATION_RELU),
@@ -35,7 +37,7 @@ class ReluGradGpuKernel : public GpuKernel {
         is_null_input_(false),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
         input_size_(0) {}
-  ~ReluGradGpuKernel() override { DestroyResource(); }
+  ~ActivationGradGpuKernel() override { DestroyResource(); }
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
@@ -45,8 +47,15 @@ class ReluGradGpuKernel : public GpuKernel {
     if (is_null_input_) {
       return true;
     }
-    T *y = GetDeviceAddress<T>(inputs, 1);
-    T *dy = GetDeviceAddress<T>(inputs, 0);
+    T *dy = nullptr;
+    T *y = nullptr;
+    if (mode_ == CUDNN_ACTIVATION_RELU || mode_ == CUDNN_ACTIVATION_ELU) {
+      dy = GetDeviceAddress<T>(inputs, 0);
+      y = GetDeviceAddress<T>(inputs, 1);
+    } else {
+      y = GetDeviceAddress<T>(inputs, 0);
+      dy = GetDeviceAddress<T>(inputs, 1);
+    }
     T *dx = GetDeviceAddress<T>(outputs, 0);
     const float alpha = 1;
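One subtlety the refactor has to absorb: the grad ops do not all order their inputs the same way. ReluGrad (and, per the branch, ELU's grad) receives (dy, y), while TanhGrad and SigmoidGrad receive (y, dy) — the old ReluGradGpuKernel and the deleted TanhGradKernel show exactly these orderings, and the branch on mode_ above normalizes them before calling into cuDNN. For tanh the backward formula is dx = dy * (1 - y * y), so only y and dy matter mathematically. A hedged sketch of a cudnnActivationBackward call in that spirit follows; the descriptors are assumed to be set up as in the forward kernel, and passing y into the API's x slot is an assumption for illustration, not something shown in this diff:

#include <cudnn.h>

// Backward Tanh through cuDNN. `act_desc` is assumed configured with CUDNN_ACTIVATION_TANH
// and `desc` with the common 4-D shape of y, dy, and dx.
void TanhBackwardCudnnSketch(cudnnHandle_t handle, cudnnActivationDescriptor_t act_desc,
                             cudnnTensorDescriptor_t desc, const float *y, const float *dy, float *dx) {
  const float alpha = 1.0f, beta = 0.0f;
  cudnnActivationBackward(handle, act_desc, &alpha,
                          desc, y,   // y = tanh(x)
                          desc, dy,  // incoming gradient
                          desc, y,   // x slot; tanh backward does not use it (reusing y is an assumption)
                          &beta, desc, dx);
}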
@@ -59,18 +68,24 @@
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
+    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+    auto iter = kernel_map.find(node_name);
+    if (iter == kernel_map.end()) {
+      MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support.";
+    }
+    mode_ = iter->second;
     InitResource();
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 2) {
-      MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReluGradGpuKernel needs 2.";
+      MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGradGpuKernel needs 2.";
       return false;
     }
     auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    mode_ = CUDNN_ACTIVATION_RELU;
     is_null_input_ = CHECK_NULL_INPUT(input_shape);
     if (is_null_input_) {
-      MS_LOG(WARNING) << "ReluGradGpuKernel input is null.";
+      MS_LOG(WARNING) << "ActivationGradGpuKernel input is null.";
       InitSizeLists();
       return true;
     }
@@ -110,6 +125,10 @@ class ReluGradGpuKernel : public GpuKernel {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
   }
+  std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU},
+                                                             {"TanhGrad", CUDNN_ACTIVATION_TANH},
+                                                             {"ELUGrad", CUDNN_ACTIVATION_ELU},
+                                                             {"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}};
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
   cudnnActivationMode_t mode_;
@@ -1,24 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "kernel/gpu/nn/tanh_gpu_kernel.h"
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      TanhGpuKernel, float)
-} // namespace kernel
-} // namespace mindspore
@@ -1,75 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GPU_KERNEL_H_
-#include <cuda_runtime_api.h>
-#include <vector>
-#include <memory>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/cuda_impl/tanh_impl.cuh"
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class TanhGpuKernel : public GpuKernel {
- public:
-  TanhGpuKernel() : input_size_(0) {}
-  ~TanhGpuKernel() override = default;
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    auto x_addr = GetDeviceAddress<T>(inputs, 0);
-    auto y_addr = GetDeviceAddress<T>(outputs, 0);
-    Tanh(input_size_ / sizeof(T), x_addr, y_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    input_size_ = sizeof(T);
-    for (auto dim : input_shape) {
-      input_size_ *= dim;
-    }
-    InitSizeLists();
-    return true;
-  }
- protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input_size_);
-    input_size_list_.push_back(input_size_);
-    output_size_list_.push_back(input_size_);
-  }
- private:
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-  size_t input_size_;
-};
-} // namespace kernel
-} // namespace mindspore
-#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_
@@ -1,26 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "kernel/gpu/nn/tanh_grad_kernel.h"
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(
-  TanhGrad,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  TanhGradKernel, float)
-} // namespace kernel
-} // namespace mindspore
@@ -1,76 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GRAD_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GRAD_KERNEL_H_
-#include <cuda_runtime_api.h>
-#include <vector>
-#include <memory>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/cuda_impl/tanh_impl.cuh"
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class TanhGradKernel : public GpuKernel {
- public:
-  TanhGradKernel() : input_size_(0) {}
-  ~TanhGradKernel() override = default;
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    auto y_addr = GetDeviceAddress<T>(inputs, 0);
-    auto dy_addr = GetDeviceAddress<T>(inputs, 1);
-    auto dx_addr = GetDeviceAddress<T>(outputs, 0);
-    TanhGrad(input_size_ / sizeof(T), y_addr, dy_addr, dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    input_size_ = sizeof(T);
-    for (auto dim : input_shape) {
-      input_size_ *= dim;
-    }
-    InitSizeLists();
-    return true;
-  }
- protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input_size_);
-    input_size_list_.push_back(input_size_);
-    output_size_list_.push_back(input_size_);
-  }
- private:
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-  size_t input_size_;
-};
-} // namespace kernel
-} // namespace mindspore
-#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GRAD_KERNEL_H_
@@ -72,3 +72,40 @@ def test_Tanh():
               [1.78391056, 0.44159236, 0.33690308, 0.16800483, -0.13651318, -0.63878956, 0.18175511, 0.65280384]]
     assert np.allclose(output[0].asnumpy(), expect)
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_Tanh_fp16():
+    np.random.seed(42)
+    x_np = np.random.randn(5, 3, 6).astype(np.float16)
+    dy_np = np.random.randn(5, 3, 6).astype(np.float16)
+    x_ms = Tensor(x_np)
+    dy_ms = Tensor(dy_np)
+    net = TanhNet()
+    grad = Grad(net)
+    output = grad(x_ms, dy_ms)
+    expect = [[[0.0766, 0.95, -0.474, -0.0568, -0.3713, -1.387],
+               [0.04626, 0.1521, 0.004135, -0.1771, -1.149, -0.341],
+               [-0.3235, -0.0666, -0.01921, 0.299, 0.7764, 0.1583]],
+              [[0.124, -0.0157, -0.3682, -0.0252, 0.05997, 0.51],
+               [-0.145, 0.2979, -0.01145, -1.019, 0.8125, 0.6914],
+               [0.562, -0.0848, 1.402, -0.5386, 0.318, 0.645]],
+              [[-0.9487, -0.04343, 0.02448, -0.4844, -0.939, 0.0666],
+               [-1.049, 0.433, -0.1724, 0.9604, -0.6377, -0.1241],
+               [0.7246, -0.1364, 0.2051, 1.132, -1.049, 0.1298]],
+              [[0.104, 0.3643, -0.6562, -1.202, 0.4688, 0.1294],
+               [0.2008, 0.3347, -0.2418, 0.07135, 0.1611, -0.1667],
+               [1.856, 0.1979, -1.048, 0.4443, -0.8574, 0.1329]],
+              [[1.156, -0.1322, 0.02069, 0.2241, 0.8164, 1.736],
+               [-0.2433, -0.05484, -0.848, -0.7197, -0.01453, 0.2637],
+               [0.1528, 0.6494, 0.006195, 1.307, -0.2024, 2.113]]]
+    assert np.allclose(output[0].asnumpy(), expect, rtol=1e-3, atol=1e-3)