Browse Source

!1759 Gpu Dropout kernel fix

Merge pull request !1759 from chenweifeng/dropout
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
fd1994b66a
8 changed files with 26 additions and 24 deletions
  1. +5
    -5
      mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu
  2. +2
    -2
      mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh
  3. +3
    -3
      mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
  4. +1
    -1
      mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h
  5. +3
    -3
      mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc
  6. +1
    -1
      mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h
  7. +1
    -1
      mindspore/ops/_grad/grad_nn_ops.py
  8. +10
    -8
      mindspore/ops/operations/nn_ops.py

+ 5
- 5
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu View File

@@ -19,10 +19,10 @@
#include "include/cuda_runtime.h" #include "include/cuda_runtime.h"


__global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count, __global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count,
float drop_prob) {
float scale = 1.f / drop_prob;
float keep_prob) {
float scale = 1.f / keep_prob;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
mask[i] = mask[i] > drop_prob;
mask[i] = mask[i] <= keep_prob;
output[i] = scale * input[i] * mask[i]; output[i] = scale * input[i] * mask[i];
} }
} }
@@ -34,8 +34,8 @@ void DropoutForward(const float *input, float *mask, float *output, size_t num_c
} }


__global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count, __global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count,
float drop_prob) {
float scale = 1.f / (1.f - drop_prob);
float keep_prob) {
float scale = 1.f / keep_prob;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
dx[i] = scale * dy[i] * mask[i]; dx[i] = scale * dy[i] * mask[i];
} }


+ 2
- 2
mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh View File

@@ -18,9 +18,9 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_


#include "device/gpu/cuda_common.h" #include "device/gpu/cuda_common.h"
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob,
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float keep_prob,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float keep_prob,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);


#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_

+ 3
- 3
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc View File

@@ -23,7 +23,7 @@ DropoutGpuFwdKernel::DropoutGpuFwdKernel()
: cudnn_handle_(nullptr), : cudnn_handle_(nullptr),
is_null_input_(false), is_null_input_(false),
num_count_(0), num_count_(0),
drop_prob_(0.0),
keep_prob_(0.0),
states_init_(false), states_init_(false),
mask_generator_(nullptr) {} mask_generator_(nullptr) {}


@@ -54,7 +54,7 @@ bool DropoutGpuFwdKernel::Init(const CNodePtr &kernel_node) {
for (size_t x : input_shape) { for (size_t x : input_shape) {
num_count_ *= x; num_count_ *= x;
} }
drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob"));
keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob"));


InitSizeLists(); InitSizeLists();
return true; return true;
@@ -92,7 +92,7 @@ bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const st
} }


curandGenerateUniform(mask_generator_, mask, num_count_); curandGenerateUniform(mask_generator_, mask, num_count_);
DropoutForward(input, mask, output, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
DropoutForward(input, mask, output, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));


return true; return true;
} }


+ 1
- 1
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h View File

@@ -52,7 +52,7 @@ class DropoutGpuFwdKernel : public GpuKernel {
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
bool is_null_input_; bool is_null_input_;
size_t num_count_; size_t num_count_;
float drop_prob_;
float keep_prob_;
bool states_init_; bool states_init_;
curandGenerator_t mask_generator_; curandGenerator_t mask_generator_;
std::vector<size_t> input_size_list_; std::vector<size_t> input_size_list_;


+ 3
- 3
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc View File

@@ -20,7 +20,7 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel() DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel()
: cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), drop_prob_(0.0) {}
: cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {}


DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); } DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); }


@@ -50,7 +50,7 @@ bool DropoutGradGpuFwdKernel::Init(const CNodePtr &kernel_node) {
for (size_t x : input_shape) { for (size_t x : input_shape) {
num_count_ *= x; num_count_ *= x;
} }
drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob"));
keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob"));


InitSizeLists(); InitSizeLists();
return true; return true;
@@ -84,7 +84,7 @@ bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, cons
auto *mask = reinterpret_cast<float *>(inputs[1]->addr); auto *mask = reinterpret_cast<float *>(inputs[1]->addr);
auto *dx = reinterpret_cast<float *>(outputs[0]->addr); auto *dx = reinterpret_cast<float *>(outputs[0]->addr);


DropoutBackward(dy, mask, dx, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));


return true; return true;
} }


+ 1
- 1
mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h View File

@@ -45,7 +45,7 @@ class DropoutGradGpuFwdKernel : public GpuKernel {
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
bool is_null_input_; bool is_null_input_;
size_t num_count_; size_t num_count_;
float drop_prob_;
float keep_prob_;
std::vector<size_t> input_size_list_; std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_; std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_; std::vector<size_t> workspace_size_list_;


+ 1
- 1
mindspore/ops/_grad/grad_nn_ops.py View File

@@ -675,7 +675,7 @@ def get_bprop_binary_cross_entropy(self):
@bprop_getters.register(P.Dropout) @bprop_getters.register(P.Dropout)
def get_bprop_dropout(self): def get_bprop_dropout(self):
"""Grad definition for `Dropout` operation.""" """Grad definition for `Dropout` operation."""
grad = P.DropoutGrad(self.drop_prob)
grad = P.DropoutGrad(self.keep_prob)


def bprop(x, out, dout): def bprop(x, out, dout):
_, mask = out _, mask = out


+ 10
- 8
mindspore/ops/operations/nn_ops.py View File

@@ -3227,7 +3227,8 @@ class Dropout(PrimitiveWithInfer):
During training, randomly zeroes some of the elements of the input tensor with probability. During training, randomly zeroes some of the elements of the input tensor with probability.


Args: Args:
drop_prob (float): probability of an element to be zeroed. Default: 0.
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
means dropping out 10% of input units.


Inputs: Inputs:
- **shape** (tuple[int]) - The shape of target mask. - **shape** (tuple[int]) - The shape of target mask.
@@ -3236,14 +3237,14 @@ class Dropout(PrimitiveWithInfer):
Tensor, the value of generated mask for input shape. Tensor, the value of generated mask for input shape.


Examples: Examples:
>>> dropout = P.Dropout(drop_prob=0.5)
>>> dropout = P.Dropout(keep_prob=0.5)
>>> in = Tensor((20, 16, 50, 50)) >>> in = Tensor((20, 16, 50, 50))
>>> out = dropout(in) >>> out = dropout(in)
""" """


@prim_attr_register @prim_attr_register
def __init__(self, drop_prob=0):
self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
def __init__(self, keep_prob=0.5):
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)


def infer_shape(self, x_shape): def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
@@ -3262,7 +3263,8 @@ class DropoutGrad(PrimitiveWithInfer):
of the input tensor with probability. of the input tensor with probability.


Args: Args:
drop_prob (float): probability of an element to be zeroed. Default: 0.
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
means dropping out 10% of input units.


Inputs: Inputs:
- **shape** (tuple[int]) - The shape of target mask. - **shape** (tuple[int]) - The shape of target mask.
@@ -3271,14 +3273,14 @@ class DropoutGrad(PrimitiveWithInfer):
Tensor, the value of generated mask for input shape. Tensor, the value of generated mask for input shape.


Examples: Examples:
>>> dropout_grad = P.DropoutGrad(drop_prob=0.5)
>>> dropout_grad = P.DropoutGrad(keep_prob=0.5)
>>> in = Tensor((20, 16, 50, 50)) >>> in = Tensor((20, 16, 50, 50))
>>> out = dropout_grad(in) >>> out = dropout_grad(in)
""" """


@prim_attr_register @prim_attr_register
def __init__(self, drop_prob=0):
self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
def __init__(self, keep_prob=0.5):
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)


def infer_shape(self, dy_shape, mask_shape): def infer_shape(self, dy_shape, mask_shape):
return dy_shape return dy_shape


Loading…
Cancel
Save