Browse Source

!10410 add dtype supports[fp64/fp32/fp16/int8/int16/int32/int64/uint8] for relu/reluv2/relugradv2

From: @yuan_shen_zhou
Reviewed-by: @wilfchen,@liangchenghui,@linqingke
Signed-off-by: @liangchenghui
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f9dbebd958
14 changed files with 225 additions and 196 deletions
  1. +0
    -37
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu
  2. +0
    -23
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh
  3. +17
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
  4. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
  5. +1
    -10
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
  6. +13
    -18
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
  7. +14
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
  8. +11
    -17
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
  9. +38
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.cc
  10. +98
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h
  11. +16
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.cc
  12. +14
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.cc
  13. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h
  14. +0
    -84
      tests/st/ops/gpu/test_relu_grad_op.py

+ 0
- 37
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu View File

@@ -1,37 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"

template <typename T>
__global__ void CalReLUGradKernel(int size, T *dy, T *y, T *dx) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
dx[pos] = y[pos] > static_cast<T>(0) ? dy[pos] : static_cast<T>(0);
}
}

template <typename T>
void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
CalReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, y, dx);
return;
}

template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);

+ 0
- 23
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh View File

@@ -1,23 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_

+ 17
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu View File

@@ -31,11 +31,14 @@ void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream)
return;
}

template void CalReLU(int size, double *input_addr, double *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int16_t *input_addr, int16_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, uint8_t *input_addr, uint8_t *output_addr, cudaStream_t cuda_stream);

template <typename T>
__global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) {
@@ -69,14 +72,26 @@ void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cuda
ReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, dy, mask, dx);
}

template void ReluV2(const size_t num, const double *x, double *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const float *x, float *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const half *x, half *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int8_t *x, int8_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int16_t *x, int16_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const uint8_t *x, uint8_t *y, uint32_t *mask, cudaStream_t cuda_stream);

template void ReluGradV2(const size_t num, const double *dy, const uint32_t *mask, double *dx,
cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const float *dy, const uint32_t *mask, float *dx, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const half *dy, const uint32_t *mask, half *dx, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int8_t *dy, const uint32_t *mask, int8_t *dx,
cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int16_t *dy, const uint32_t *mask, int16_t *dx,
cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *mask, int32_t *dx,
cudaStream_t cuda_stream);
cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx,
cudaStream_t cuda_stream);
cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
cudaStream_t cuda_stream);

+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h View File

@@ -46,7 +46,8 @@ static constexpr float kSignedMinFloat = -3.402823466e+38F;
static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {
{"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, {"kNumberTypeFloat16", CUDNN_DATA_HALF},
{"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32},
{"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8}};
{"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8},
{"kNumberTypeUInt8", CUDNN_DATA_UINT8}};
// Used by mixprecision, cuda dtype select
static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},
{"kNumberTypeFloat16", CUDA_R_16F}};


+ 1
- 10
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,15 +18,6 @@

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
ActivationGpuFwdKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ActivationGpuFwdKernel, int32_t)

MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),


+ 13
- 18
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_

#include <vector>
#include <map>
@@ -44,17 +44,12 @@ class ActivationGpuFwdKernel : public GpuKernel {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);

if (mode_ == CUDNN_ACTIVATION_RELU) {
const int size = input_size_ / sizeof(T);
CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_,
input, &beta, data_descriptor_, output),
"cudnnActivationForward failed");
}
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
&beta, data_descriptor_, output),
"cudnnActivationForward failed");

return true;
}
@@ -125,7 +120,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
activation_desc_ = nullptr;
mode_ = CUDNN_ACTIVATION_RELU;
mode_ = CUDNN_ACTIVATION_SIGMOID;
data_descriptor_ = nullptr;
is_null_input_ = false;
input_size_list_.clear();
@@ -154,11 +149,11 @@ class ActivationGpuFwdKernel : public GpuKernel {
}
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
}

private:
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
{"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
{"Tanh", CUDNN_ACTIVATION_TANH},
{"Elu", CUDNN_ACTIVATION_ELU},
{"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
@@ -179,4 +174,4 @@ class ActivationGpuFwdKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_

+ 14
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,6 +18,10 @@

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ActivationGradGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@@ -26,12 +30,21 @@ MS_REG_GPU_KERNEL_ONE(
ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ActivationGradGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ActivationGradGpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
ActivationGradGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
ActivationGradGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ActivationGradGpuKernel, uint8_t)

MS_REG_GPU_KERNEL_ONE(
ReLU6Grad,


+ 11
- 17
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_

#include <vector>
#include <map>
@@ -23,7 +23,6 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"

namespace mindspore {
namespace kernel {
@@ -52,18 +51,13 @@ class ActivationGradGpuKernel : public GpuKernel {
}
T *dx = GetDeviceAddress<T>(outputs, 0);

if (mode_ == CUDNN_ACTIVATION_RELU) {
const int size = input_size_ / sizeof(T);
CalReLUGrad(size, dy, y, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
data_descriptor_, y, &beta, data_descriptor_, dx),
"cudnnActivationBackward failed");
}
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
data_descriptor_, y, &beta, data_descriptor_, dx),
"cudnnActivationBackward failed");

return true;
}
@@ -179,4 +173,4 @@ class ActivationGradGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_

+ 38
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.cc View File

@@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ReLUGpuFwdKernel, double)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReLUGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ReLUGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ReLUGpuFwdKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReLUGpuFwdKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
ReLUGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReLUGpuFwdKernel,
int8_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ReLUGpuFwdKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

+ 98
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h View File

@@ -0,0 +1,98 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_

#include <vector>
#include <map>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class ReLUGpuFwdKernel : public GpuKernel {
public:
ReLUGpuFwdKernel() { ResetResource(); }
~ReLUGpuFwdKernel() override {}
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);

const int size = input_size_ / sizeof(T);
CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 1.";
return false;
}
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "ReLUGpuFwdKernel input is null.";
}
size_t size = 1;
for (size_t i = 0; i < input_shape.size(); i++) {
size *= input_shape[i];
}
input_size_ = size * sizeof(T);

InitSizeLists();
return true;
}

void ResetResource() noexcept override {
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
input_size_ = 0;
workspace_size_ = 0;
}

protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
workspace_size_list_.push_back(workspace_size_);
}

private:
bool is_null_input_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t input_size_;
size_t workspace_size_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_

+ 16
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.cc View File

@@ -18,6 +18,10 @@

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64),
ReluGradV2GpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
@@ -26,6 +30,13 @@ MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
ReluGradV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
ReluGradV2, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8),
ReluGradV2GpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16),
ReluGradV2GpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
@@ -34,5 +45,10 @@ MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
ReluGradV2GpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8),
ReluGradV2GpuKernel, uint8_t)

} // namespace kernel
} // namespace mindspore

+ 14
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.cc View File

@@ -18,6 +18,10 @@

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
ReLUV2,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
ReLUV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
@@ -26,12 +30,20 @@ MS_REG_GPU_KERNEL_ONE(
ReLUV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
ReLUV2,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h View File

@@ -79,4 +79,4 @@ class ReluV2GpuKernel : public GpuKernel {
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_MASK_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_

+ 0
- 84
tests/st/ops/gpu/test_relu_grad_op.py View File

@@ -1,84 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G


class NetReluGrad(nn.Cell):
def __init__(self):
super(NetReluGrad, self).__init__()
self.rekuGrad = G.ReluGrad()

def construct(self, x, dy):
return self.rekuGrad(dy, x)


def relu_grad_base(dtype):
x = Tensor(np.array([[[[-1, 1, 1],
[1, -1, 1],
[1, 1, -1]]]]).astype(dtype))
dy = Tensor(np.array([[[[1, 0, 1],
[0, 1, 0],
[1, 1, 1]]]]).astype(dtype))
expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(np.dtype)
error = np.ones(shape=[3, 3]) * 1.0e-6

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
relu_grad = NetReluGrad()
output = relu_grad(x, dy)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert output.asnumpy().dtype == dtype


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_float16():
relu_grad_base(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_float32():
relu_grad_base(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_int8():
relu_grad_base(np.int8)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_int32():
relu_grad_base(np.int32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_int64():
relu_grad_base(np.int64)

Loading…
Cancel
Save