Merge pull request !1759 from chenweifeng/dropout (tags/v0.5.0-beta)
| @@ -19,10 +19,10 @@ | |||||
| #include "include/cuda_runtime.h" | #include "include/cuda_runtime.h" | ||||
| __global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count, | __global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count, | ||||
| float drop_prob) { | |||||
| float scale = 1.f / drop_prob; | |||||
| float keep_prob) { | |||||
| float scale = 1.f / keep_prob; | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { | ||||
| mask[i] = mask[i] > drop_prob; | |||||
| mask[i] = mask[i] <= keep_prob; | |||||
| output[i] = scale * input[i] * mask[i]; | output[i] = scale * input[i] * mask[i]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -34,8 +34,8 @@ void DropoutForward(const float *input, float *mask, float *output, size_t num_c | |||||
| } | } | ||||
| __global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count, | __global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count, | ||||
| float drop_prob) { | |||||
| float scale = 1.f / (1.f - drop_prob); | |||||
| float keep_prob) { | |||||
| float scale = 1.f / keep_prob; | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { | ||||
| dx[i] = scale * dy[i] * mask[i]; | dx[i] = scale * dy[i] * mask[i]; | ||||
| } | } | ||||
| @@ -18,9 +18,9 @@ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ | #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ | ||||
| #include "device/gpu/cuda_common.h" | #include "device/gpu/cuda_common.h" | ||||
| void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob, | |||||
| void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float keep_prob, | |||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob, | |||||
| void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float keep_prob, | |||||
| cudaStream_t cuda_stream); | cudaStream_t cuda_stream); | ||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ | ||||
| @@ -23,7 +23,7 @@ DropoutGpuFwdKernel::DropoutGpuFwdKernel() | |||||
| : cudnn_handle_(nullptr), | : cudnn_handle_(nullptr), | ||||
| is_null_input_(false), | is_null_input_(false), | ||||
| num_count_(0), | num_count_(0), | ||||
| drop_prob_(0.0), | |||||
| keep_prob_(0.0), | |||||
| states_init_(false), | states_init_(false), | ||||
| mask_generator_(nullptr) {} | mask_generator_(nullptr) {} | ||||
| @@ -54,7 +54,7 @@ bool DropoutGpuFwdKernel::Init(const CNodePtr &kernel_node) { | |||||
| for (size_t x : input_shape) { | for (size_t x : input_shape) { | ||||
| num_count_ *= x; | num_count_ *= x; | ||||
| } | } | ||||
| drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob")); | |||||
| keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob")); | |||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| @@ -92,7 +92,7 @@ bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const st | |||||
| } | } | ||||
| curandGenerateUniform(mask_generator_, mask, num_count_); | curandGenerateUniform(mask_generator_, mask, num_count_); | ||||
| DropoutForward(input, mask, output, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| DropoutForward(input, mask, output, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -52,7 +52,7 @@ class DropoutGpuFwdKernel : public GpuKernel { | |||||
| cudnnHandle_t cudnn_handle_; | cudnnHandle_t cudnn_handle_; | ||||
| bool is_null_input_; | bool is_null_input_; | ||||
| size_t num_count_; | size_t num_count_; | ||||
| float drop_prob_; | |||||
| float keep_prob_; | |||||
| bool states_init_; | bool states_init_; | ||||
| curandGenerator_t mask_generator_; | curandGenerator_t mask_generator_; | ||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| @@ -20,7 +20,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel() | DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel() | ||||
| : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), drop_prob_(0.0) {} | |||||
| : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {} | |||||
| DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); } | DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); } | ||||
| @@ -50,7 +50,7 @@ bool DropoutGradGpuFwdKernel::Init(const CNodePtr &kernel_node) { | |||||
| for (size_t x : input_shape) { | for (size_t x : input_shape) { | ||||
| num_count_ *= x; | num_count_ *= x; | ||||
| } | } | ||||
| drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob")); | |||||
| keep_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("keep_prob")); | |||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| @@ -84,7 +84,7 @@ bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, cons | |||||
| auto *mask = reinterpret_cast<float *>(inputs[1]->addr); | auto *mask = reinterpret_cast<float *>(inputs[1]->addr); | ||||
| auto *dx = reinterpret_cast<float *>(outputs[0]->addr); | auto *dx = reinterpret_cast<float *>(outputs[0]->addr); | ||||
| DropoutBackward(dy, mask, dx, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -45,7 +45,7 @@ class DropoutGradGpuFwdKernel : public GpuKernel { | |||||
| cudnnHandle_t cudnn_handle_; | cudnnHandle_t cudnn_handle_; | ||||
| bool is_null_input_; | bool is_null_input_; | ||||
| size_t num_count_; | size_t num_count_; | ||||
| float drop_prob_; | |||||
| float keep_prob_; | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| std::vector<size_t> workspace_size_list_; | std::vector<size_t> workspace_size_list_; | ||||
| @@ -675,7 +675,7 @@ def get_bprop_binary_cross_entropy(self): | |||||
| @bprop_getters.register(P.Dropout) | @bprop_getters.register(P.Dropout) | ||||
| def get_bprop_dropout(self): | def get_bprop_dropout(self): | ||||
| """Grad definition for `Dropout` operation.""" | """Grad definition for `Dropout` operation.""" | ||||
| grad = P.DropoutGrad(self.drop_prob) | |||||
| grad = P.DropoutGrad(self.keep_prob) | |||||
| def bprop(x, out, dout): | def bprop(x, out, dout): | ||||
| _, mask = out | _, mask = out | ||||
| @@ -3227,7 +3227,8 @@ class Dropout(PrimitiveWithInfer): | |||||
| During training, randomly zeroes some of the elements of the input tensor with probability. | During training, randomly zeroes some of the elements of the input tensor with probability. | ||||
| Args: | Args: | ||||
| drop_prob (float): probability of an element to be zeroed. Default: 0. | |||||
| keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9, | |||||
| means dropping out 10% of input units. | |||||
| Inputs: | Inputs: | ||||
| - **shape** (tuple[int]) - The shape of target mask. | - **shape** (tuple[int]) - The shape of target mask. | ||||
| @@ -3236,14 +3237,14 @@ class Dropout(PrimitiveWithInfer): | |||||
| Tensor, the value of generated mask for input shape. | Tensor, the value of generated mask for input shape. | ||||
| Examples: | Examples: | ||||
| >>> dropout = P.Dropout(drop_prob=0.5) | |||||
| >>> dropout = P.Dropout(keep_prob=0.5) | |||||
| >>> in = Tensor((20, 16, 50, 50)) | >>> in = Tensor((20, 16, 50, 50)) | ||||
| >>> out = dropout(in) | >>> out = dropout(in) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, drop_prob=0): | |||||
| self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name) | |||||
| def __init__(self, keep_prob=0.5): | |||||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) | validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) | ||||
| @@ -3262,7 +3263,8 @@ class DropoutGrad(PrimitiveWithInfer): | |||||
| of the input tensor with probability. | of the input tensor with probability. | ||||
| Args: | Args: | ||||
| drop_prob (float): probability of an element to be zeroed. Default: 0. | |||||
| keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9, | |||||
| means dropping out 10% of input units. | |||||
| Inputs: | Inputs: | ||||
| - **shape** (tuple[int]) - The shape of target mask. | - **shape** (tuple[int]) - The shape of target mask. | ||||
| @@ -3271,14 +3273,14 @@ class DropoutGrad(PrimitiveWithInfer): | |||||
| Tensor, the value of generated mask for input shape. | Tensor, the value of generated mask for input shape. | ||||
| Examples: | Examples: | ||||
| >>> dropout_grad = P.DropoutGrad(drop_prob=0.5) | |||||
| >>> dropout_grad = P.DropoutGrad(keep_prob=0.5) | |||||
| >>> in = Tensor((20, 16, 50, 50)) | >>> in = Tensor((20, 16, 50, 50)) | ||||
| >>> out = dropout_grad(in) | >>> out = dropout_grad(in) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, drop_prob=0): | |||||
| self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name) | |||||
| def __init__(self, keep_prob=0.5): | |||||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) | |||||
| def infer_shape(self, dy_shape, mask_shape): | def infer_shape(self, dy_shape, mask_shape): | ||||
| return dy_shape | return dy_shape | ||||