| @@ -178,22 +178,29 @@ class PoolingGradGpuKernel : public GpuKernel { | |||||
| auto window = GetAttr<std::vector<int>>(kernel_node, "ksize"); | auto window = GetAttr<std::vector<int>>(kernel_node, "ksize"); | ||||
| int window_height = window[2]; | int window_height = window[2]; | ||||
| int window_width = window[3]; | int window_width = window[3]; | ||||
| int stride_h = stride_[2]; | |||||
| int stride_w = stride_[3]; | |||||
| if (data_format_ == kOpFormat_NHWC) { | |||||
| window_height = window[1]; | |||||
| window_width = window[2]; | |||||
| stride_h = stride_[1]; | |||||
| stride_w = stride_[2]; | |||||
| } | |||||
| int windowDimA[2] = {window_height, window_width}; | int windowDimA[2] = {window_height, window_width}; | ||||
| int paddingA[2] = {0, 0}; | int paddingA[2] = {0, 0}; | ||||
| int strideA[2] = {stride_[2], stride_[3]}; | |||||
| int strideA[2] = {stride_h, stride_w}; | |||||
| if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { | if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { | ||||
| pad_height_ = | pad_height_ = | ||||
| std::max<int>(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) | |||||
| : (old_height_ / stride_[2]) + 1) - | |||||
| std::max<int>(0, (((old_height_ / stride_h) * stride_h == old_height_ ? (old_height_ / stride_h) | |||||
| : (old_height_ / stride_h) + 1) - | |||||
| 1) * | 1) * | ||||
| stride_[2] + | |||||
| stride_h + | |||||
| window_height - old_height_); | window_height - old_height_); | ||||
| pad_width_ = | |||||
| std::max<int>(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) | |||||
| : (old_width_ / stride_[3]) + 1) - | |||||
| 1) * | |||||
| stride_[3] + | |||||
| window_width - old_width_); | |||||
| pad_width_ = std::max<int>( | |||||
| 0, (((old_width_ / stride_w) * stride_w == old_width_ ? (old_width_ / stride_w) : (old_width_ / stride_w) + 1) - | |||||
| 1) * | |||||
| stride_w + | |||||
| window_width - old_width_); | |||||
| pad_top_ = pad_height_ / 2; | pad_top_ = pad_height_ / 2; | ||||
| pad_left_ = pad_width_ / 2; | pad_left_ = pad_width_ / 2; | ||||
| paddingA[0] = pad_top_; | paddingA[0] = pad_top_; | ||||