| @@ -85,10 +85,10 @@ class PoolingGradGpuKernel : public GpuKernel { | |||||
| auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | ||||
| auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); | auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); | ||||
| auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | ||||
| data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); | |||||
| auto format_attr = GetAttr<std::string>(kernel_node, "data_format"); | |||||
| if (format_attr == kOpFormat_NHWC) { | |||||
| data_format_ = kOpFormat_NHWC; | |||||
| auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0); | |||||
| format_attr_ = GetAttr<std::string>(kernel_node, "data_format"); | |||||
| if (format_attr_ == kOpFormat_NHWC) { | |||||
| data_format = kOpFormat_NHWC; | |||||
| } | } | ||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | ||||
| is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); | is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); | ||||
| @@ -97,7 +97,7 @@ class PoolingGradGpuKernel : public GpuKernel { | |||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| } | } | ||||
| SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_); | |||||
| SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format); | |||||
| const int nbDims = 4; | const int nbDims = 4; | ||||
| int dimA[4]; | int dimA[4]; | ||||
| int strideAin[4]; | int strideAin[4]; | ||||
| @@ -107,14 +107,14 @@ class PoolingGradGpuKernel : public GpuKernel { | |||||
| int strideAdy[4]; | int strideAdy[4]; | ||||
| int dimAout[4]; | int dimAout[4]; | ||||
| int strideAout[4]; | int strideAout[4]; | ||||
| SetDimA(input_shape, dimA, 4, data_format_); | |||||
| SetStrideA(input_shape, strideAin, 4, data_format_); | |||||
| SetDimA(input_mask, dimAy, 4, data_format_); | |||||
| SetStrideA(input_mask, strideAiny, 4, data_format_); | |||||
| SetDimA(dout_shape, dimAdy, 4, data_format_); | |||||
| SetStrideA(dout_shape, strideAdy, 4, data_format_); | |||||
| SetDimA(output_shape, dimAout, 4, data_format_); | |||||
| SetStrideA(output_shape, strideAout, 4, data_format_); | |||||
| SetDimA(input_shape, dimA, 4, data_format); | |||||
| SetStrideA(input_shape, strideAin, 4, data_format); | |||||
| SetDimA(input_mask, dimAy, 4, data_format); | |||||
| SetStrideA(input_mask, strideAiny, 4, data_format); | |||||
| SetDimA(dout_shape, dimAdy, 4, data_format); | |||||
| SetStrideA(dout_shape, strideAdy, 4, data_format); | |||||
| SetDimA(output_shape, dimAout, 4, data_format); | |||||
| SetStrideA(output_shape, strideAout, 4, data_format); | |||||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny), | CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny), | ||||
| "cudnnSetTensor4dDescriptor failed"); | "cudnnSetTensor4dDescriptor failed"); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy), | CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy), | ||||
| @@ -180,7 +180,7 @@ class PoolingGradGpuKernel : public GpuKernel { | |||||
| int window_width = window[3]; | int window_width = window[3]; | ||||
| int stride_h = stride_[2]; | int stride_h = stride_[2]; | ||||
| int stride_w = stride_[3]; | int stride_w = stride_[3]; | ||||
| if (data_format_ == kOpFormat_NHWC) { | |||||
| if (format_attr_ == kOpFormat_NHWC) { | |||||
| window_height = window[1]; | window_height = window[1]; | ||||
| window_width = window[2]; | window_width = window[2]; | ||||
| stride_h = stride_[1]; | stride_h = stride_[1]; | ||||
| @@ -247,7 +247,7 @@ class PoolingGradGpuKernel : public GpuKernel { | |||||
| std::vector<size_t> workspace_size_list_; | std::vector<size_t> workspace_size_list_; | ||||
| std::string mode_; | std::string mode_; | ||||
| std::string pad_mode_; | std::string pad_mode_; | ||||
| std::string data_format_ = kOpFormat_NCHW; | |||||
| std::string format_attr_ = kOpFormat_NCHW; | |||||
| cudnnDataType_t cudnn_data_type_; | cudnnDataType_t cudnn_data_type_; | ||||
| cudnnTensorFormat_t compute_format_; | cudnnTensorFormat_t compute_format_; | ||||
| int old_height_; | int old_height_; | ||||