From 01ef6db9e71971f66a154741e7eb9c02bfab4d7b Mon Sep 17 00:00:00 2001 From: VectorSL Date: Wed, 16 Sep 2020 14:49:33 +0800 Subject: [PATCH] fix reviewbot:function has more than 50 lines --- .../gpu/nn/pooling_gpu_kernel.h | 67 ++++++++-------- .../gpu/nn/pooling_grad_gpu_kernel.h | 77 +++++++++---------- 2 files changed, 72 insertions(+), 72 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h index 42406fd040..6bebca4e81 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h @@ -103,38 +103,8 @@ class PoolingGpuFwdKernel : public GpuKernel { CHECK_CUDNN_RET_WITH_EXCEPT( cudnnSetTensorNdDescriptor(output_descriptor_, cudnn_data_type_, nbDims, dimAout, strideAout), "cudnnSetTensor4dDescriptor failed"); - auto window = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize")); - int window_height = window[2]; - int window_width = window[3]; - stride_ = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides")); SetPoolingMode(kernel_node); - int windowDimA[2] = {window_height, window_width}; - int paddingA[2] = {0, 0}; - int strideA[2] = {stride_[2], stride_[3]}; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { - pad_height_ = - std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) - : (old_height_ / stride_[2]) + 1) - - 1) * - stride_[2] + - window_height - old_height_); - pad_width_ = - std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) - : (old_width_ / stride_[3]) + 1) - - 1) * - stride_[3] + - window_width - old_width_); - pad_top_ = pad_height_ / 2; - pad_left_ = pad_width_ / 2; - paddingA[0] = pad_top_; - paddingA[1] = pad_left_; - } else { - pad_height_ = 0; - pad_width_ = 0; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, - 2, windowDimA, paddingA, strideA), - "cudnnSetPoolingNdDescriptor failed"); + SetPad(kernel_node); InitSizeLists(); return true; } @@ -172,7 +142,6 @@ class PoolingGpuFwdKernel : public GpuKernel { } void SetPoolingMode(const CNodePtr &kernel_node) { - pad_mode_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding")); mode_ = AnfAlgo::GetCNodeName(kernel_node); if (mode_ == "AvgPool") { pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; @@ -182,6 +151,40 @@ class PoolingGpuFwdKernel : public GpuKernel { pad_value_ = kSignedMinFloat; } } + void SetPad(const CNodePtr &kernel_node) { + pad_mode_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding")); + auto window = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize")); + int window_height = window[2]; + int window_width = window[3]; + stride_ = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides")); + int windowDimA[2] = {window_height, window_width}; + int paddingA[2] = {0, 0}; + int strideA[2] = {stride_[2], stride_[3]}; + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + pad_height_ = + std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) + : (old_height_ / stride_[2]) + 1) - + 1) * + stride_[2] + + window_height - old_height_); + pad_width_ = + std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) + : (old_width_ / stride_[3]) + 1) - + 1) * + stride_[3] + + window_width - old_width_); + pad_top_ = pad_height_ / 2; + pad_left_ = pad_width_ / 2; + paddingA[0] = pad_top_; + paddingA[1] = pad_left_; + } else { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, + 2, windowDimA, paddingA, strideA), + "cudnnSetPoolingNdDescriptor failed"); + } void DestroyResource() noexcept { CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), "cudnnDestroyPoolingDescriptor failed"); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h index ddc71e7dab..ac78d2cd2a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h @@ -81,10 +81,6 @@ class PoolingGradGpuKernel : public GpuKernel { if (!CheckParam(kernel_node)) { return false; } - auto window = GetAttr>(kernel_node, "ksize"); - int window_height = window[2]; - int window_width = window[3]; - SetPoolingMode(kernel_node); auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1); auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); @@ -97,9 +93,6 @@ class PoolingGradGpuKernel : public GpuKernel { return true; } SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_); - int windowDimA[2] = {window_height, window_width}; - int paddingA[2] = {0, 0}; - int strideA[2] = {stride_[2], stride_[3]}; const int nbDims = 4; int dimA[4]; int strideAin[4]; @@ -126,32 +119,8 @@ class PoolingGradGpuKernel : public GpuKernel { "cudnnSetTensor4dDescriptor failed"); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(x_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin), "cudnnSetTensor4dDescriptor failed"); - if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { - pad_height_ = - std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) - : (old_height_ / stride_[2]) + 1) - - 1) * - stride_[2] + - window_height - old_height_); - pad_width_ = - std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) - : (old_width_ / stride_[3]) + 1) - - 1) * - stride_[3] + - window_width - old_width_); - pad_top_ = pad_height_ / 2; - pad_left_ = pad_width_ / 2; - paddingA[0] = pad_top_; - paddingA[1] = pad_left_; - } else { - if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { - pad_height_ = 0; - pad_width_ = 0; - } - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, - 2, windowDimA, paddingA, strideA), - "cudnnSetPoolingNdDescriptor failed"); + SetPoolingMode(kernel_node); + SetPad(kernel_node); InitSizeLists(); return true; } @@ -198,15 +167,43 @@ class PoolingGradGpuKernel : public GpuKernel { } return true; } - void SetPad(const std::vector &input_shape, const int &window_height, const int &window_width) { - n_ = SizeToInt(input_shape[0]); - c_ = SizeToInt(input_shape[1]); - old_height_ = SizeToInt(input_shape[2]); - old_width_ = SizeToInt(input_shape[3]); - } - void SetPoolingMode(const CNodePtr &kernel_node) { + void SetPad(const CNodePtr &kernel_node) { pad_mode_ = GetAttr(kernel_node, "padding"); stride_ = GetAttr>(kernel_node, "strides"); + auto window = GetAttr>(kernel_node, "ksize"); + int window_height = window[2]; + int window_width = window[3]; + int windowDimA[2] = {window_height, window_width}; + int paddingA[2] = {0, 0}; + int strideA[2] = {stride_[2], stride_[3]}; + if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { + pad_height_ = + std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) + : (old_height_ / stride_[2]) + 1) - + 1) * + stride_[2] + + window_height - old_height_); + pad_width_ = + std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) + : (old_width_ / stride_[3]) + 1) - + 1) * + stride_[3] + + window_width - old_width_); + pad_top_ = pad_height_ / 2; + pad_left_ = pad_width_ / 2; + paddingA[0] = pad_top_; + paddingA[1] = pad_left_; + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, + 2, windowDimA, paddingA, strideA), + "cudnnSetPoolingNdDescriptor failed"); + } + void SetPoolingMode(const CNodePtr &kernel_node) { cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); mode_ = AnfAlgo::GetCNodeName(kernel_node); if (mode_ == "AvgPoolGradGpu") {