diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cu deleted file mode 100644 index 5471ffb5d9..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cu +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/tanh_impl.cuh" -#include - -template -__global__ void TanhKernel(const size_t size, const T* x_addr, T* y_addr) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - y_addr[pos] = tanh(x_addr[pos]); - } -} - -template -__global__ void TanhGradKernel(const size_t size, const T* y_addr, const T* dy_addr, T* dx_addr) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - dx_addr[pos] = dy_addr[pos] * (1 - y_addr[pos] * y_addr[pos]); - } -} - -template -void Tanh(const size_t size, const T* x_addr, T* y_addr, cudaStream_t cuda_stream) { - TanhKernel<<>>(size, x_addr, y_addr); -} - -template -void TanhGrad(const size_t size, const T* y_addr, const T* dy_addr, T* dx_addr, cudaStream_t cuda_stream) { - TanhGradKernel<<>>(size, y_addr, dy_addr, dx_addr); -} - -template void Tanh(const size_t size, const float* x_addr, float* y_addr, cudaStream_t cuda_stream); -template void TanhGrad(const size_t size, const float* y_addr, const float* dy_addr, - float* dx_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cuh deleted file mode 100644 index 71fc4be4dd..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cuh +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TAN_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TAN_H_ - -#include "device/gpu/cuda_common.h" - -template -void Tanh(const size_t size, const T* x_addr, T* y_addr, cudaStream_t cuda_stream); - -template -void TanhGrad(const size_t size, const T* y_addr, const T* dy_addr, T* dx_addr, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TAN_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc similarity index 65% rename from mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.cc rename to mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc index d4cefc73ca..0246707d16 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc @@ -14,13 +14,18 @@ * limitations under the License. */ -#include "kernel/gpu/nn/relu_gpu_kernel.h" +#include "kernel/gpu/nn/activation_gpu_kernel.h" namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReLUGpuFwdKernel, float) + ActivationGpuFwdKernel, float) MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ReLUGpuFwdKernel, half) + ActivationGpuFwdKernel, half) + +MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGpuFwdKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h similarity index 81% rename from mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.h rename to mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h index bcc8819fde..bf6cfa7b23 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h @@ -18,6 +18,8 @@ #define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ #include +#include +#include #include "kernel/gpu/gpu_kernel.h" #include "kernel/gpu/gpu_kernel_factory.h" #include "kernel/gpu/kernel_constants.h" @@ -25,9 +27,9 @@ namespace mindspore { namespace kernel { template -class ReLUGpuFwdKernel : public GpuKernel { +class ActivationGpuFwdKernel : public GpuKernel { public: - ReLUGpuFwdKernel() + ActivationGpuFwdKernel() : cudnn_handle_(nullptr), activation_desc_(nullptr), mode_(CUDNN_ACTIVATION_RELU), @@ -37,7 +39,7 @@ class ReLUGpuFwdKernel : public GpuKernel { input_size_(0), output_size_(0), workspace_size_(0) {} - ~ReLUGpuFwdKernel() override { DestroyResource(); } + ~ActivationGpuFwdKernel() override { DestroyResource(); } const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } @@ -54,33 +56,39 @@ class ReLUGpuFwdKernel : public GpuKernel { const float beta = 0; CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input, &beta, data_descriptor_, output), - "ReLUGpuFwdKernel failed"); + "cudnnActivationForward failed"); return true; } bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kernel_map.find(node_name); + if (iter == kernel_map.end()) { + MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support."; + } + mode_ = iter->second; + InitResource(); cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 1) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 1."; + MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1."; return false; } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); is_null_input_ = CHECK_NULL_INPUT(input_shape); if (is_null_input_) { - MS_LOG(WARNING) << "ReLUGpuFwdKernel input is null."; + MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null."; InitSizeLists(); return true; } - mode_ = CUDNN_ACTIVATION_RELU; std::vector shape; ShapeNdTo4d(input_shape, &shape); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0), - "SetActivationDescriptor failed"); + "cudnnSetActivationDescriptor failed"); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, shape[0], shape[1], shape[2], shape[3]), - "SetTensor4dDescriptor failed"); + "cudnnSetTensor4dDescriptor failed"); InitSizeLists(); return true; } @@ -110,6 +118,11 @@ class ReLUGpuFwdKernel : public GpuKernel { CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); } + std::map kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU}, + {"Tanh", CUDNN_ACTIVATION_TANH}, + {"ELU", CUDNN_ACTIVATION_ELU}, + {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}}; + cudnnHandle_t cudnn_handle_; cudnnActivationDescriptor_t activation_desc_; cudnnActivationMode_t mode_; diff --git a/mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc similarity index 67% rename from mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.cc rename to mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc index 9e2897a6d1..506d2268f7 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc @@ -14,17 +14,26 @@ * limitations under the License. */ -#include "kernel/gpu/nn/relu_grad_kernel.h" +#include "kernel/gpu/nn/activation_grad_kernel.h" namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE( ReluGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReluGradGpuKernel, float) + ActivationGradGpuKernel, float) MS_REG_GPU_KERNEL_ONE( ReluGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ReluGradGpuKernel, half) + ActivationGradGpuKernel, half) + +MS_REG_GPU_KERNEL_ONE( + TanhGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + TanhGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGradGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h similarity index 79% rename from mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.h rename to mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h index 899f0752f7..38e34eb752 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h @@ -18,6 +18,8 @@ #define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ #include +#include +#include #include "kernel/gpu/gpu_kernel.h" #include "kernel/gpu/gpu_kernel_factory.h" #include "kernel/gpu/kernel_constants.h" @@ -25,9 +27,9 @@ namespace mindspore { namespace kernel { template -class ReluGradGpuKernel : public GpuKernel { +class ActivationGradGpuKernel : public GpuKernel { public: - ReluGradGpuKernel() + ActivationGradGpuKernel() : cudnn_handle_(nullptr), activation_desc_(nullptr), mode_(CUDNN_ACTIVATION_RELU), @@ -35,7 +37,7 @@ class ReluGradGpuKernel : public GpuKernel { is_null_input_(false), cudnn_data_type_(CUDNN_DATA_FLOAT), input_size_(0) {} - ~ReluGradGpuKernel() override { DestroyResource(); } + ~ActivationGradGpuKernel() override { DestroyResource(); } const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } @@ -45,8 +47,15 @@ class ReluGradGpuKernel : public GpuKernel { if (is_null_input_) { return true; } - T *y = GetDeviceAddress(inputs, 1); - T *dy = GetDeviceAddress(inputs, 0); + T *dy = nullptr; + T *y = nullptr; + if (mode_ == CUDNN_ACTIVATION_RELU || mode_ == CUDNN_ACTIVATION_ELU) { + dy = GetDeviceAddress(inputs, 0); + y = GetDeviceAddress(inputs, 1); + } else { + y = GetDeviceAddress(inputs, 0); + dy = GetDeviceAddress(inputs, 1); + } T *dx = GetDeviceAddress(outputs, 0); const float alpha = 1; @@ -59,18 +68,24 @@ class ReluGradGpuKernel : public GpuKernel { return true; } bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kernel_map.find(node_name); + if (iter == kernel_map.end()) { + MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support."; + } + mode_ = iter->second; + InitResource(); cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 2) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReluGradGpuKernel needs 2."; + MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGradGpuKernel needs 2."; return false; } auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - mode_ = CUDNN_ACTIVATION_RELU; is_null_input_ = CHECK_NULL_INPUT(input_shape); if (is_null_input_) { - MS_LOG(WARNING) << "ReluGradGpuKernel input is null."; + MS_LOG(WARNING) << "ActivationGradGpuKernel input is null."; InitSizeLists(); return true; } @@ -110,6 +125,10 @@ class ReluGradGpuKernel : public GpuKernel { CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); } + std::map kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU}, + {"TanhGrad", CUDNN_ACTIVATION_TANH}, + {"ELUGrad", CUDNN_ACTIVATION_ELU}, + {"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}}; cudnnHandle_t cudnn_handle_; cudnnActivationDescriptor_t activation_desc_; cudnnActivationMode_t mode_; diff --git a/mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.cc deleted file mode 100644 index 727dffeedb..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.cc +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/tanh_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TanhGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.h deleted file mode 100644 index 7060ad1792..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/tanh_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class TanhGpuKernel : public GpuKernel { - public: - TanhGpuKernel() : input_size_(0) {} - ~TanhGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - auto x_addr = GetDeviceAddress(inputs, 0); - auto y_addr = GetDeviceAddress(outputs, 0); - - Tanh(input_size_ / sizeof(T), x_addr, y_addr, reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - - input_size_ = sizeof(T); - for (auto dim : input_shape) { - input_size_ *= dim; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.cc deleted file mode 100644 index 97176680d0..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/tanh_grad_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - TanhGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TanhGradKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.h deleted file mode 100644 index b5b52d0acf..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GRAD_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/tanh_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class TanhGradKernel : public GpuKernel { - public: - TanhGradKernel() : input_size_(0) {} - ~TanhGradKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - auto y_addr = GetDeviceAddress(inputs, 0); - auto dy_addr = GetDeviceAddress(inputs, 1); - auto dx_addr = GetDeviceAddress(outputs, 0); - - TanhGrad(input_size_ / sizeof(T), y_addr, dy_addr, dx_addr, reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - - input_size_ = sizeof(T); - for (auto dim : input_shape) { - input_size_ *= dim; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GRAD_KERNEL_H_ diff --git a/tests/st/ops/gpu/test_tanh_op.py b/tests/st/ops/gpu/test_tanh_op.py index 2e9fa8811d..065bf50f08 100644 --- a/tests/st/ops/gpu/test_tanh_op.py +++ b/tests/st/ops/gpu/test_tanh_op.py @@ -72,3 +72,40 @@ def test_Tanh(): [1.78391056, 0.44159236, 0.33690308, 0.16800483, -0.13651318, -0.63878956, 0.18175511, 0.65280384]] assert np.allclose(output[0].asnumpy(), expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_Tanh_fp16(): + np.random.seed(42) + x_np = np.random.randn(5, 3, 6).astype(np.float16) + dy_np = np.random.randn(5, 3, 6).astype(np.float16) + + x_ms = Tensor(x_np) + dy_ms = Tensor(dy_np) + + net = TanhNet() + grad = Grad(net) + output = grad(x_ms, dy_ms) + + expect = [[[0.0766, 0.95, -0.474, -0.0568, -0.3713, -1.387], + [0.04626, 0.1521, 0.004135, -0.1771, -1.149, -0.341], + [-0.3235, -0.0666, -0.01921, 0.299, 0.7764, 0.1583]], + + [[0.124, -0.0157, -0.3682, -0.0252, 0.05997, 0.51], + [-0.145, 0.2979, -0.01145, -1.019, 0.8125, 0.6914], + [0.562, -0.0848, 1.402, -0.5386, 0.318, 0.645]], + + [[-0.9487, -0.04343, 0.02448, -0.4844, -0.939, 0.0666], + [-1.049, 0.433, -0.1724, 0.9604, -0.6377, -0.1241], + [0.7246, -0.1364, 0.2051, 1.132, -1.049, 0.1298]], + + [[0.104, 0.3643, -0.6562, -1.202, 0.4688, 0.1294], + [0.2008, 0.3347, -0.2418, 0.07135, 0.1611, -0.1667], + [1.856, 0.1979, -1.048, 0.4443, -0.8574, 0.1329]], + + [[1.156, -0.1322, 0.02069, 0.2241, 0.8164, 1.736], + [-0.2433, -0.05484, -0.848, -0.7197, -0.01453, 0.2637], + [0.1528, 0.6494, 0.006195, 1.307, -0.2024, 2.113]]] + + assert np.allclose(output[0].asnumpy(), expect, rtol=1e-3, atol=1e-3)