Browse Source

!2150 Gpu Tanh kernel support fp16

Merge pull request !2150 from chenweifeng/tanh-fp16
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
19e66f06e2
11 changed files with 106 additions and 298 deletions
  1. +0
    -46
      mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cu
  2. +0
    -28
      mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cuh
  3. +8
    -3
      mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc
  4. +22
    -9
      mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h
  5. +12
    -3
      mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc
  6. +27
    -8
      mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h
  7. +0
    -24
      mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.cc
  8. +0
    -75
      mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.h
  9. +0
    -26
      mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.cc
  10. +0
    -76
      mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.h
  11. +37
    -0
      tests/st/ops/gpu/test_tanh_op.py

+ 0
- 46
mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cu View File

@@ -1,46 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel/gpu/cuda_impl/tanh_impl.cuh"
#include <cuda_runtime.h>

template<typename T>
__global__ void TanhKernel(const size_t size, const T* x_addr, T* y_addr) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
y_addr[pos] = tanh(x_addr[pos]);
}
}

template<typename T>
__global__ void TanhGradKernel(const size_t size, const T* y_addr, const T* dy_addr, T* dx_addr) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
dx_addr[pos] = dy_addr[pos] * (1 - y_addr[pos] * y_addr[pos]);
}
}

template<typename T>
void Tanh(const size_t size, const T* x_addr, T* y_addr, cudaStream_t cuda_stream) {
TanhKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, x_addr, y_addr);
}

template<typename T>
void TanhGrad(const size_t size, const T* y_addr, const T* dy_addr, T* dx_addr, cudaStream_t cuda_stream) {
TanhGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, y_addr, dy_addr, dx_addr);
}

template void Tanh(const size_t size, const float* x_addr, float* y_addr, cudaStream_t cuda_stream);
template void TanhGrad(const size_t size, const float* y_addr, const float* dy_addr,
float* dx_addr, cudaStream_t cuda_stream);

+ 0
- 28
mindspore/ccsrc/kernel/gpu/cuda_impl/tanh_impl.cuh View File

@@ -1,28 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TAN_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TAN_H_

#include "device/gpu/cuda_common.h"

template<typename T>
void Tanh(const size_t size, const T* x_addr, T* y_addr, cudaStream_t cuda_stream);

template<typename T>
void TanhGrad(const size_t size, const T* y_addr, const T* dy_addr, T* dx_addr, cudaStream_t cuda_stream);

#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TAN_H_

mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.cc → mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc View File

@@ -14,13 +14,18 @@
* limitations under the License. * limitations under the License.
*/ */


#include "kernel/gpu/nn/relu_gpu_kernel.h"
#include "kernel/gpu/nn/activation_gpu_kernel.h"


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReLUGpuFwdKernel, float)
ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ReLUGpuFwdKernel, half)
ActivationGpuFwdKernel, half)

MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGpuFwdKernel, half)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.h → mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h View File

@@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_


#include <vector> #include <vector>
#include <map>
#include <string>
#include "kernel/gpu/gpu_kernel.h" #include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h" #include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/kernel_constants.h" #include "kernel/gpu/kernel_constants.h"
@@ -25,9 +27,9 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T> template <typename T>
class ReLUGpuFwdKernel : public GpuKernel {
class ActivationGpuFwdKernel : public GpuKernel {
public: public:
ReLUGpuFwdKernel()
ActivationGpuFwdKernel()
: cudnn_handle_(nullptr), : cudnn_handle_(nullptr),
activation_desc_(nullptr), activation_desc_(nullptr),
mode_(CUDNN_ACTIVATION_RELU), mode_(CUDNN_ACTIVATION_RELU),
@@ -37,7 +39,7 @@ class ReLUGpuFwdKernel : public GpuKernel {
input_size_(0), input_size_(0),
output_size_(0), output_size_(0),
workspace_size_(0) {} workspace_size_(0) {}
~ReLUGpuFwdKernel() override { DestroyResource(); }
~ActivationGpuFwdKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
@@ -54,33 +56,39 @@ class ReLUGpuFwdKernel : public GpuKernel {
const float beta = 0; const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input, CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
&beta, data_descriptor_, output), &beta, data_descriptor_, output),
"ReLUGpuFwdKernel failed");
"cudnnActivationForward failed");


return true; return true;
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kernel_map.find(node_name);
if (iter == kernel_map.end()) {
MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support.";
}
mode_ = iter->second;

InitResource(); InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) { if (input_num != 1) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 1.";
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1.";
return false; return false;
} }
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape); is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) { if (is_null_input_) {
MS_LOG(WARNING) << "ReLUGpuFwdKernel input is null.";
MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null.";
InitSizeLists(); InitSizeLists();
return true; return true;
} }
mode_ = CUDNN_ACTIVATION_RELU;
std::vector<int> shape; std::vector<int> shape;
ShapeNdTo4d(input_shape, &shape); ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0),
"SetActivationDescriptor failed");
"cudnnSetActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]), shape[0], shape[1], shape[2], shape[3]),
"SetTensor4dDescriptor failed");
"cudnnSetTensor4dDescriptor failed");
InitSizeLists(); InitSizeLists();
return true; return true;
} }
@@ -110,6 +118,11 @@ class ReLUGpuFwdKernel : public GpuKernel {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
} }


std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
{"Tanh", CUDNN_ACTIVATION_TANH},
{"ELU", CUDNN_ACTIVATION_ELU},
{"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};

cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
cudnnActivationDescriptor_t activation_desc_; cudnnActivationDescriptor_t activation_desc_;
cudnnActivationMode_t mode_; cudnnActivationMode_t mode_;

mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.cc → mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc View File

@@ -14,17 +14,26 @@
* limitations under the License. * limitations under the License.
*/ */


#include "kernel/gpu/nn/relu_grad_kernel.h"
#include "kernel/gpu/nn/activation_grad_kernel.h"


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
ReluGrad, ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReluGradGpuKernel, float)
ActivationGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
ReluGrad, ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ReluGradGpuKernel, half)
ActivationGradGpuKernel, half)

MS_REG_GPU_KERNEL_ONE(
TanhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
TanhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGradGpuKernel, half)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.h → mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h View File

@@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_


#include <vector> #include <vector>
#include <map>
#include <string>
#include "kernel/gpu/gpu_kernel.h" #include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h" #include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/kernel_constants.h" #include "kernel/gpu/kernel_constants.h"
@@ -25,9 +27,9 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T> template <typename T>
class ReluGradGpuKernel : public GpuKernel {
class ActivationGradGpuKernel : public GpuKernel {
public: public:
ReluGradGpuKernel()
ActivationGradGpuKernel()
: cudnn_handle_(nullptr), : cudnn_handle_(nullptr),
activation_desc_(nullptr), activation_desc_(nullptr),
mode_(CUDNN_ACTIVATION_RELU), mode_(CUDNN_ACTIVATION_RELU),
@@ -35,7 +37,7 @@ class ReluGradGpuKernel : public GpuKernel {
is_null_input_(false), is_null_input_(false),
cudnn_data_type_(CUDNN_DATA_FLOAT), cudnn_data_type_(CUDNN_DATA_FLOAT),
input_size_(0) {} input_size_(0) {}
~ReluGradGpuKernel() override { DestroyResource(); }
~ActivationGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
@@ -45,8 +47,15 @@ class ReluGradGpuKernel : public GpuKernel {
if (is_null_input_) { if (is_null_input_) {
return true; return true;
} }
T *y = GetDeviceAddress<T>(inputs, 1);
T *dy = GetDeviceAddress<T>(inputs, 0);
T *dy = nullptr;
T *y = nullptr;
if (mode_ == CUDNN_ACTIVATION_RELU || mode_ == CUDNN_ACTIVATION_ELU) {
dy = GetDeviceAddress<T>(inputs, 0);
y = GetDeviceAddress<T>(inputs, 1);
} else {
y = GetDeviceAddress<T>(inputs, 0);
dy = GetDeviceAddress<T>(inputs, 1);
}
T *dx = GetDeviceAddress<T>(outputs, 0); T *dx = GetDeviceAddress<T>(outputs, 0);


const float alpha = 1; const float alpha = 1;
@@ -59,18 +68,24 @@ class ReluGradGpuKernel : public GpuKernel {
return true; return true;
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kernel_map.find(node_name);
if (iter == kernel_map.end()) {
MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support.";
}
mode_ = iter->second;

InitResource(); InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) { if (input_num != 2) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReluGradGpuKernel needs 2.";
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGradGpuKernel needs 2.";
return false; return false;
} }
auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
mode_ = CUDNN_ACTIVATION_RELU;
is_null_input_ = CHECK_NULL_INPUT(input_shape); is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) { if (is_null_input_) {
MS_LOG(WARNING) << "ReluGradGpuKernel input is null.";
MS_LOG(WARNING) << "ActivationGradGpuKernel input is null.";
InitSizeLists(); InitSizeLists();
return true; return true;
} }
@@ -110,6 +125,10 @@ class ReluGradGpuKernel : public GpuKernel {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
} }


std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU},
{"TanhGrad", CUDNN_ACTIVATION_TANH},
{"ELUGrad", CUDNN_ACTIVATION_ELU},
{"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}};
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
cudnnActivationDescriptor_t activation_desc_; cudnnActivationDescriptor_t activation_desc_;
cudnnActivationMode_t mode_; cudnnActivationMode_t mode_;

+ 0
- 24
mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.cc View File

@@ -1,24 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel/gpu/nn/tanh_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TanhGpuKernel, float)
} // namespace kernel
} // namespace mindspore

+ 0
- 75
mindspore/ccsrc/kernel/gpu/nn/tanh_gpu_kernel.h View File

@@ -1,75 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GPU_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include <memory>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/tanh_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class TanhGpuKernel : public GpuKernel {
public:
TanhGpuKernel() : input_size_(0) {}
~TanhGpuKernel() override = default;

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto x_addr = GetDeviceAddress<T>(inputs, 0);
auto y_addr = GetDeviceAddress<T>(outputs, 0);

Tanh(input_size_ / sizeof(T), x_addr, y_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);

input_size_ = sizeof(T);
for (auto dim : input_shape) {
input_size_ *= dim;
}

InitSizeLists();
return true;
}

protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
input_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
}

private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t input_size_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_

+ 0
- 26
mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.cc View File

@@ -1,26 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel/gpu/nn/tanh_grad_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
TanhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TanhGradKernel, float)
} // namespace kernel
} // namespace mindspore

+ 0
- 76
mindspore/ccsrc/kernel/gpu/nn/tanh_grad_kernel.h View File

@@ -1,76 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GRAD_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include <memory>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/tanh_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class TanhGradKernel : public GpuKernel {
public:
TanhGradKernel() : input_size_(0) {}
~TanhGradKernel() override = default;

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto y_addr = GetDeviceAddress<T>(inputs, 0);
auto dy_addr = GetDeviceAddress<T>(inputs, 1);
auto dx_addr = GetDeviceAddress<T>(outputs, 0);

TanhGrad(input_size_ / sizeof(T), y_addr, dy_addr, dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);

input_size_ = sizeof(T);
for (auto dim : input_shape) {
input_size_ *= dim;
}

InitSizeLists();
return true;
}

protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
input_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
}

private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t input_size_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_TANH_GRAD_KERNEL_H_

+ 37
- 0
tests/st/ops/gpu/test_tanh_op.py View File

@@ -72,3 +72,40 @@ def test_Tanh():
[1.78391056, 0.44159236, 0.33690308, 0.16800483, -0.13651318, -0.63878956, 0.18175511, 0.65280384]] [1.78391056, 0.44159236, 0.33690308, 0.16800483, -0.13651318, -0.63878956, 0.18175511, 0.65280384]]


assert np.allclose(output[0].asnumpy(), expect) assert np.allclose(output[0].asnumpy(), expect)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_Tanh_fp16():
np.random.seed(42)
x_np = np.random.randn(5, 3, 6).astype(np.float16)
dy_np = np.random.randn(5, 3, 6).astype(np.float16)

x_ms = Tensor(x_np)
dy_ms = Tensor(dy_np)

net = TanhNet()
grad = Grad(net)
output = grad(x_ms, dy_ms)

expect = [[[0.0766, 0.95, -0.474, -0.0568, -0.3713, -1.387],
[0.04626, 0.1521, 0.004135, -0.1771, -1.149, -0.341],
[-0.3235, -0.0666, -0.01921, 0.299, 0.7764, 0.1583]],

[[0.124, -0.0157, -0.3682, -0.0252, 0.05997, 0.51],
[-0.145, 0.2979, -0.01145, -1.019, 0.8125, 0.6914],
[0.562, -0.0848, 1.402, -0.5386, 0.318, 0.645]],

[[-0.9487, -0.04343, 0.02448, -0.4844, -0.939, 0.0666],
[-1.049, 0.433, -0.1724, 0.9604, -0.6377, -0.1241],
[0.7246, -0.1364, 0.2051, 1.132, -1.049, 0.1298]],

[[0.104, 0.3643, -0.6562, -1.202, 0.4688, 0.1294],
[0.2008, 0.3347, -0.2418, 0.07135, 0.1611, -0.1667],
[1.856, 0.1979, -1.048, 0.4443, -0.8574, 0.1329]],

[[1.156, -0.1322, 0.02069, 0.2241, 0.8164, 1.736],
[-0.2433, -0.05484, -0.848, -0.7197, -0.01453, 0.2637],
[0.1528, 0.6494, 0.006195, 1.307, -0.2024, 2.113]]]

assert np.allclose(output[0].asnumpy(), expect, rtol=1e-3, atol=1e-3)

Loading…
Cancel
Save