From 5ac5c4b650525cd6e5358f1a60cd8c9b47dd8623 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Fri, 30 Oct 2020 18:17:31 +0800 Subject: [PATCH] fix gpu maxpoolgrad --- .../gpu/nn/pooling_grad_gpu_kernel.h | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h index d271a23b97..c8354087b1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h @@ -178,22 +178,29 @@ class PoolingGradGpuKernel : public GpuKernel { auto window = GetAttr>(kernel_node, "ksize"); int window_height = window[2]; int window_width = window[3]; + int stride_h = stride_[2]; + int stride_w = stride_[3]; + if (data_format_ == kOpFormat_NHWC) { + window_height = window[1]; + window_width = window[2]; + stride_h = stride_[1]; + stride_w = stride_[2]; + } int windowDimA[2] = {window_height, window_width}; int paddingA[2] = {0, 0}; - int strideA[2] = {stride_[2], stride_[3]}; + int strideA[2] = {stride_h, stride_w}; if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { pad_height_ = - std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) - : (old_height_ / stride_[2]) + 1) - + std::max(0, (((old_height_ / stride_h) * stride_h == old_height_ ? (old_height_ / stride_h) + : (old_height_ / stride_h) + 1) - 1) * - stride_[2] + + stride_h + window_height - old_height_); - pad_width_ = - std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) - : (old_width_ / stride_[3]) + 1) - - 1) * - stride_[3] + - window_width - old_width_); + pad_width_ = std::max( + 0, (((old_width_ / stride_w) * stride_w == old_width_ ? (old_width_ / stride_w) : (old_width_ / stride_w) + 1) - + 1) * + stride_w + + window_width - old_width_); pad_top_ = pad_height_ / 2; pad_left_ = pad_width_ / 2; paddingA[0] = pad_top_;