| @@ -37,6 +37,39 @@ __global__ void Pad(const size_t size, const T* input, const int num, const int | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void PadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | |||
| const int channels, const int padded_height, const int padded_width, const int pad_top, | |||
| const int pad_left, float pad_value, T* output) { | |||
| T pad_value_ = static_cast<T>(pad_value); | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| int block_num = pos / channels / padded_width / padded_height; | |||
| const int padded_w = pos / channels % padded_width; | |||
| const int padded_h = pos / channels / padded_width % padded_height; | |||
| if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height || | |||
| padded_w - pad_left >= old_width) { | |||
| output[pos] = pad_value_; | |||
| } else { | |||
| output[pos] = input[((block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left) | |||
| *channels + pos % channels]; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void PadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width, | |||
| const int channels, const int padded_height, const int padded_width, const int pad_top, | |||
| const int pad_left, T* dx) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| int block_num = pos / channels / old_width / old_height; | |||
| const int padded_w = pos / channels % old_width + pad_left; | |||
| const int padded_h = pos / channels / old_width % old_height + pad_top; | |||
| dx[pos] = dy[((block_num * padded_height + padded_h) * padded_width + padded_w)*channels+pos%channels]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void PadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, | |||
| const int old_width, const int padded_height, const int padded_width, const int pad_top, | |||
| @@ -60,6 +93,24 @@ void CalPad(const size_t size, const T* input, const int num, const int channels | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | |||
| const int channels, const int padded_height, const int padded_width, const int pad_top, | |||
| const int pad_left, const float pad_value, T* output, cudaStream_t cuda_stream) { | |||
| PadNHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, old_height, old_width, channels, | |||
| padded_height, padded_width, pad_top, pad_left, pad_value, output); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalPadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width, | |||
| const int channels, const int padded_height, const int padded_width, const int pad_top, | |||
| const int pad_left, T* dx, cudaStream_t cuda_stream) { | |||
| PadGradNHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, num, old_height, old_width, channels, | |||
| padded_height, padded_width, pad_top, pad_left, dx); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, | |||
| const int old_width, const int padded_height, const int padded_width, const int pad_top, | |||
| @@ -85,3 +136,19 @@ template void CalPadGrad<half>(const size_t size, const half* dy, const int num, | |||
| const int old_height, const int old_width, const int padded_height, | |||
| const int padded_width, const int pad_top, const int pad_left, half* dx, | |||
| cudaStream_t cuda_stream); | |||
| template void CalPadNHWC<float>(const size_t size, const float* input, const int num, const int old_height, | |||
| const int old_width, const int channels, const int padded_height, | |||
| const int padded_width, const int pad_top, const int pad_left, float pad_value, | |||
| float* output, cudaStream_t cuda_stream); | |||
| template void CalPadNHWC<half>(const size_t size, const half* input, const int num, const int old_height, | |||
| const int old_width, const int channels, const int padded_height, | |||
| const int padded_width, const int pad_top, const int pad_left, float pad_value, | |||
| half* output, cudaStream_t cuda_stream); | |||
| template void CalPadGradNHWC<float>(const size_t size, const float* dy, const int num, const int old_height, | |||
| const int old_width, const int channels, const int padded_height, | |||
| const int padded_width, const int pad_top, const int pad_left, float* dx, | |||
| cudaStream_t cuda_stream); | |||
| template void CalPadGradNHWC<half>(const size_t size, const half* dy, const int num, const int old_height, | |||
| const int old_width, const int channels, const int padded_height, | |||
| const int padded_width, const int pad_top, const int pad_left, half* dx, | |||
| cudaStream_t cuda_stream); | |||
| @@ -27,5 +27,13 @@ template <typename T> | |||
| void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, | |||
| const int old_width, const int padded_height, const int padded_width, const int pad_top, | |||
| const int pad_left, T* dx, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | |||
| const int channels, const int padded_height, const int padded_width, const int pad_top, const int pad_left, | |||
| float pad_value, T* output, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalPadGradNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width, | |||
| const int channels, const int padded_height, const int padded_width, const int pad_top, | |||
| const int pad_left, T* output, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ | |||
| @@ -21,6 +21,7 @@ | |||
| #include <cudnn.h> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| #include "runtime/device/gpu/gpu_device_manager.h" | |||
| @@ -73,6 +74,59 @@ class GpuKernel : public KernelMod { | |||
| dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1])); | |||
| } | |||
| // transpose shape: NCHW To NHWC | |||
| void ShapeNCHW2NHWC(std::vector<size_t> *shape) { | |||
| std::swap((*shape)[1], (*shape)[3]); | |||
| std::swap((*shape)[2], (*shape)[1]); | |||
| } | |||
| void SetDimA(const std::vector<size_t> &shape, int *dimA, const std::string &format) { | |||
| if (format == "NCHW" || format == "DefaultFormat") { | |||
| dimA[0] = SizeToInt(shape[0]); | |||
| dimA[1] = SizeToInt(shape[1]); | |||
| dimA[2] = SizeToInt(shape[2]); | |||
| dimA[3] = SizeToInt(shape[3]); | |||
| } else if (format == "NHWC") { | |||
| dimA[0] = SizeToInt(shape[0]); | |||
| dimA[1] = SizeToInt(shape[3]); | |||
| dimA[2] = SizeToInt(shape[1]); | |||
| dimA[3] = SizeToInt(shape[2]); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data format " << format; | |||
| } | |||
| } | |||
| void SetStrideA(const std::vector<size_t> &shape, int *strideA, const std::string &format) { | |||
| if (format == "NCHW" || format == "DefaultFormat") { | |||
| strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]); | |||
| strideA[1] = SizeToInt(shape[2] * shape[3]); | |||
| strideA[2] = SizeToInt(shape[3]); | |||
| strideA[3] = 1; | |||
| } else if (format == "NHWC") { | |||
| strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]); | |||
| strideA[1] = 1; | |||
| strideA[2] = SizeToInt(shape[2] * shape[3]); | |||
| strideA[3] = SizeToInt(shape[3]); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data format " << format; | |||
| } | |||
| } | |||
| void SetNCHW(const std::vector<size_t> &shape, int *n, int *c, int *h, int *w, const std::string &format) { | |||
| if (format == "NCHW" || format == "DefaultFormat") { | |||
| *n = SizeToInt(shape[0]); | |||
| *c = SizeToInt(shape[1]); | |||
| *h = SizeToInt(shape[2]); | |||
| *w = SizeToInt(shape[3]); | |||
| } else if (format == "NHWC") { | |||
| *n = SizeToInt(shape[0]); | |||
| *c = SizeToInt(shape[3]); | |||
| *h = SizeToInt(shape[1]); | |||
| *w = SizeToInt(shape[2]); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data format " << format; | |||
| } | |||
| } | |||
| inline void CheckBroadcast4TensorOp(const std::vector<int> &A, const std::vector<int> &B, | |||
| const std::vector<int> &Out) { | |||
| if (A != Out && B != Out) { | |||
| @@ -38,6 +38,7 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||
| conv_desc_(nullptr), | |||
| padded_desc_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| compute_format_(CUDNN_TENSOR_NCHW), | |||
| old_height_(0), | |||
| old_width_(0), | |||
| pad_height_(0), | |||
| @@ -76,9 +77,15 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||
| const float beta = 0; | |||
| if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { | |||
| T *padded_addr = GetDeviceAddress<T>(workspace, 1); | |||
| CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (data_format_ == "NHWC") { | |||
| CalPadNHWC(padded_size_ / sizeof(T), input_addr, n_, old_height_, old_width_, c_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnConvolutionForward(cudnn_handle_, &alpha, padded_desc_, padded_addr, filter_desc_, filter_addr, conv_desc_, | |||
| conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), | |||
| @@ -97,15 +104,21 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||
| if (!CheckParam(kernel_node)) { | |||
| return false; | |||
| } | |||
| auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); | |||
| auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(in_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "Conv2dGpuFwdKernel input is null."; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_); | |||
| if (data_format_ == "NHWC") { | |||
| compute_format_ = CUDNN_TENSOR_NHWC; | |||
| } | |||
| Set4DDesc(in_shape, filter_shape, output_shape); | |||
| group_ = GetAttr<int>(kernel_node, "group"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); | |||
| @@ -116,17 +129,55 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||
| pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode"); | |||
| SetStrideAndDilation(kernel_node); | |||
| cudnnTensorDescriptor_t input_descriptor_real = nullptr; | |||
| int padA[2]; | |||
| int strideA[2] = {stride_[2], stride_[3]}; | |||
| int dilaA[2] = {dilation_[2], dilation_[3]}; | |||
| if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) { | |||
| SetPad(in_shape, kernel_node); | |||
| pad_height_ = pad_list[0] + pad_list[1]; | |||
| pad_width_ = pad_list[2] + pad_list[3]; | |||
| pad_top_ = pad_list[0]; | |||
| pad_left_ = pad_list[2]; | |||
| // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad. | |||
| if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { | |||
| use_pad_ = false; | |||
| } | |||
| int dimA[4]; | |||
| int strideApadded[4]; | |||
| if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") { | |||
| auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_), | |||
| IntToSize(old_width_ + pad_width_)}; | |||
| SetDimA(padded_shape, dimA, data_format_); | |||
| SetStrideA(padded_shape, strideApadded, data_format_); | |||
| } else if (data_format_ == "NHWC") { | |||
| auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_), | |||
| IntToSize(c_)}; | |||
| SetDimA(padded_shape, dimA, data_format_); | |||
| SetStrideA(padded_shape, strideApadded, data_format_); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(padded_desc_, cudnn_data_type_, 4, dimA, strideApadded), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| if (use_pad_) { | |||
| padA[0] = 0; | |||
| padA[1] = 0; | |||
| } else { | |||
| padA[0] = pad_top_; | |||
| padA[1] = pad_left_; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolutionNdDescriptor failed"); | |||
| input_descriptor_real = use_pad_ ? padded_desc_ : input_desc_; | |||
| } else { | |||
| if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| } | |||
| padA[0] = pad_height_; | |||
| padA[1] = pad_width_; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2], | |||
| dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolution2dDescriptor failed"); | |||
| input_descriptor_real = input_desc_; | |||
| } | |||
| @@ -193,13 +244,11 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); | |||
| } | |||
| bool CheckParam(const CNodePtr &kernel_node) { | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but conv2d needs 2 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but conv2d needs 1 output."; | |||
| @@ -207,46 +256,29 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||
| } | |||
| return true; | |||
| } | |||
| void SetPad(const std::vector<size_t> &in_shape, const CNodePtr &kernel_node) { | |||
| auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list"); | |||
| n_ = SizeToInt(in_shape[0]); | |||
| c_ = SizeToInt(in_shape[1]); | |||
| old_height_ = SizeToInt(in_shape[2]); | |||
| old_width_ = SizeToInt(in_shape[3]); | |||
| pad_height_ = pad_list[0] + pad_list[1]; | |||
| pad_width_ = pad_list[2] + pad_list[3]; | |||
| pad_top_ = pad_list[0]; | |||
| pad_left_ = pad_list[2]; | |||
| // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad. | |||
| if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { | |||
| use_pad_ = false; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_, | |||
| old_height_ + pad_height_, old_width_ + pad_width_), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( | |||
| conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3], | |||
| dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolution2dDescriptor failed"); | |||
| } | |||
| void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<size_t> &filter_shape, | |||
| const std::vector<size_t> &output_shape) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), | |||
| SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| int nbDims = 4; | |||
| int dimA[4]; | |||
| int strideAin[4]; | |||
| int dimAout[4]; | |||
| int strideAout[4]; | |||
| SetDimA(in_shape, dimA, data_format_); | |||
| SetStrideA(in_shape, strideAin, data_format_); | |||
| SetDimA(output_shape, dimAout, data_format_); | |||
| SetStrideA(output_shape, strideAout, data_format_); | |||
| int filterDimA[4]; | |||
| // OHWI for NHWC; OIHW for NCHW | |||
| SetDimA(filter_shape, filterDimA, data_format_); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(input_desc_, cudnn_data_type_, nbDims, dimA, strideAin), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(filter_shape[0]), | |||
| SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])), | |||
| cudnnSetFilterNdDescriptor(filter_desc_, cudnn_data_type_, compute_format_, nbDims, filterDimA), | |||
| "cudnnSetFilter4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), | |||
| SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(output_desc_, cudnn_data_type_, nbDims, dimAout, strideAout), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| } | |||
| void SelectAlgorithm(cudnnTensorDescriptor_t input_descriptor_real) { | |||
| if (group_ > 1 || CUDNN_MAJOR < 7) { | |||
| @@ -292,11 +324,13 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||
| cudnnConvolutionDescriptor_t conv_desc_; | |||
| cudnnTensorDescriptor_t padded_desc_; | |||
| std::string pad_mode_; | |||
| std::string data_format_ = "NCHW"; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| const float pad_value_ = 0.0; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| cudnnTensorFormat_t compute_format_; | |||
| int old_height_; | |||
| int old_width_; | |||
| int pad_height_; | |||
| @@ -38,6 +38,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||
| x_desc_(nullptr), | |||
| padded_descriptor_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| compute_format_(CUDNN_TENSOR_NCHW), | |||
| old_height_(0), | |||
| old_width_(0), | |||
| pad_height_(0), | |||
| @@ -75,12 +76,18 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { | |||
| T *padded = GetDeviceAddress<T>(workspace, 1); | |||
| CalPad(padded_size_ / sizeof(T), x, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (data_format_ == "NHWC") { | |||
| CalPadNHWC(padded_size_ / sizeof(T), x, n_, old_height_, old_width_, c_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CalPad(padded_size_ / sizeof(T), x, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_, | |||
| algo_, work_space, workspace_size_, &beta, dw_desc_, dw), | |||
| @@ -99,16 +106,21 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||
| return false; | |||
| } | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "ConvGradFilterGpuBkwKernel input is null."; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| std::vector<int> filter_shape; | |||
| data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); | |||
| std::vector<size_t> filter_shape; | |||
| GetFilterShape(kernel_node, &filter_shape); | |||
| if (data_format_ == "NHWC") { | |||
| compute_format_ = CUDNN_TENSOR_NHWC; | |||
| } | |||
| SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_); | |||
| Set4DDesc(dy_shape, filter_shape, in_shape); | |||
| group_ = GetAttr<int>(kernel_node, "group"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); | |||
| @@ -120,18 +132,54 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||
| pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode"); | |||
| SetStrideAndDilation(kernel_node); | |||
| cudnnTensorDescriptor_t x_desc_real = nullptr; | |||
| int padA[2]; | |||
| int strideA[2] = {stride_[0], stride_[1]}; | |||
| int dilaA[2] = {dilation_[0], dilation_[1]}; | |||
| if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) { | |||
| SetPad(in_shape, kernel_node); | |||
| pad_height_ = pad_list[0] + pad_list[1]; | |||
| pad_width_ = pad_list[2] + pad_list[3]; | |||
| pad_top_ = pad_list[0]; | |||
| pad_left_ = pad_list[2]; | |||
| if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { | |||
| use_pad_ = false; | |||
| } | |||
| int dimA[4]; | |||
| int strideApadded[4]; | |||
| if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") { | |||
| auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_), | |||
| IntToSize(old_width_ + pad_width_)}; | |||
| SetDimA(padded_shape, dimA, data_format_); | |||
| SetStrideA(padded_shape, strideApadded, data_format_); | |||
| } else if (data_format_ == "NHWC") { | |||
| auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_), | |||
| IntToSize(c_)}; | |||
| SetDimA(padded_shape, dimA, data_format_); | |||
| SetStrideA(padded_shape, strideApadded, data_format_); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| if (use_pad_) { | |||
| padA[0] = 0; | |||
| padA[1] = 0; | |||
| } else { | |||
| padA[0] = pad_top_; | |||
| padA[1] = pad_left_; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolutionNdDescriptor failed"); | |||
| x_desc_real = use_pad_ ? padded_descriptor_ : x_desc_; | |||
| } else { | |||
| if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| } | |||
| padA[0] = pad_height_; | |||
| padA[1] = pad_width_; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], | |||
| dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "GetConvolution2dDescriptor failed"); | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolution2dDescriptor failed"); | |||
| x_desc_real = x_desc_; | |||
| } | |||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||
| @@ -208,27 +256,6 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||
| } | |||
| return true; | |||
| } | |||
| void SetPad(const std::vector<size_t> &in_shape, const CNodePtr &kernel_node) { | |||
| auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list"); | |||
| n_ = SizeToInt(in_shape[0]); | |||
| c_ = SizeToInt(in_shape[1]); | |||
| old_height_ = SizeToInt(in_shape[2]); | |||
| old_width_ = SizeToInt(in_shape[3]); | |||
| pad_height_ = pad_list[0] + pad_list[1]; | |||
| pad_width_ = pad_list[2] + pad_list[3]; | |||
| pad_top_ = pad_list[0]; | |||
| pad_left_ = pad_list[2]; | |||
| if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { | |||
| use_pad_ = false; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, | |||
| c_, old_height_ + pad_height_, old_width_ + pad_width_), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( | |||
| conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1], | |||
| dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolution2dDescriptor failed"); | |||
| } | |||
| void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) { | |||
| if (group_ > 1 || CUDNN_MAJOR < 7) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| @@ -249,27 +276,33 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||
| algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; | |||
| } | |||
| } | |||
| void GetFilterShape(const CNodePtr &kernel_node, std::vector<int> *filter_shape) { | |||
| void GetFilterShape(const CNodePtr &kernel_node, std::vector<size_t> *filter_shape) { | |||
| auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast<ValueTuplePtr>()->value(); | |||
| (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape), | |||
| [](const ValuePtr &e) -> int { return e->cast<Int32ImmPtr>()->value(); }); | |||
| [](const ValuePtr &e) -> size_t { return e->cast<Int32ImmPtr>()->value(); }); | |||
| } | |||
| void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<int> &filter_shape, | |||
| void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &filter_shape, | |||
| const std::vector<size_t> &in_shape) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), | |||
| SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), | |||
| "SetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetFilter4dDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), filter_shape[1], | |||
| filter_shape[2], filter_shape[3]), | |||
| "SetFilter4dDescriptor failed"); | |||
| int nbDims = 4; | |||
| int dimA[4]; | |||
| int strideAin[4]; | |||
| int dimAdy[4]; | |||
| int strideAdy[4]; | |||
| SetDimA(in_shape, dimA, data_format_); | |||
| SetStrideA(in_shape, strideAin, data_format_); | |||
| SetDimA(dy_shape, dimAdy, data_format_); | |||
| SetStrideA(dy_shape, strideAdy, data_format_); | |||
| // filter shape always keep OIHW. | |||
| int filterDimA[4] = {SizeToInt(filter_shape[0]), SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), | |||
| SizeToInt(filter_shape[3])}; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), | |||
| SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), | |||
| "SetTensor4dDescriptor failed"); | |||
| cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, compute_format_, nbDims, filterDimA), | |||
| "cudnnSetFilterNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(x_desc_, cudnn_data_type_, nbDims, dimA, strideAin), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| } | |||
| void SetStrideAndDilation(const CNodePtr &kernel_node) { | |||
| stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride"); | |||
| @@ -292,11 +325,13 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||
| cudnnTensorDescriptor_t padded_descriptor_; | |||
| cudnnConvolutionBwdFilterAlgo_t algo_; | |||
| std::string pad_mode_; | |||
| std::string data_format_ = "NCHW"; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| const float pad_value_ = 0.0; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| cudnnTensorFormat_t compute_format_; | |||
| int old_height_; | |||
| int old_width_; | |||
| int pad_height_; | |||
| @@ -319,4 +354,4 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ | |||
| #endif // MINDePORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ | |||
| @@ -38,6 +38,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| dx_desc_(nullptr), | |||
| padded_descriptor_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| compute_format_(CUDNN_TENSOR_NCHW), | |||
| old_height_(0), | |||
| old_width_(0), | |||
| pad_height_(0), | |||
| @@ -75,7 +76,6 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { | |||
| T *padded = GetDeviceAddress<T>(workspace, 1); | |||
| @@ -83,8 +83,13 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, | |||
| workspace_size_, &beta, padded_descriptor_, padded), | |||
| "ConvolutionBackwardData failed"); | |||
| CalPadGrad(output_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (data_format_ == "NHWC") { | |||
| CalPadGradNHWC(output_size_ / sizeof(T), padded, n_, old_height_, old_width_, c_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CalPadGrad(output_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, | |||
| @@ -99,16 +104,23 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| return false; | |||
| } | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); | |||
| auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| is_null_input_ = CHECK_NULL_INPUT(dy_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "ConvGradInputGpuBkwKernel input is null."; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| std::vector<int> input_shape; | |||
| std::vector<size_t> input_shape; | |||
| GetInputShape(kernel_node, &input_shape); | |||
| if (data_format_ == "NHWC") { | |||
| compute_format_ = CUDNN_TENSOR_NHWC; | |||
| ShapeNCHW2NHWC(&input_shape); | |||
| } | |||
| SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_); | |||
| Set4DDesc(dy_shape, input_shape, filter_shape); | |||
| group_ = GetAttr<int>(kernel_node, "group"); | |||
| @@ -121,17 +133,53 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode"); | |||
| SetStrideAndDilation(kernel_node); | |||
| cudnnTensorDescriptor_t dx_desc_real = nullptr; | |||
| int padA[2]; | |||
| int strideA[2] = {stride_[0], stride_[1]}; | |||
| int dilaA[2] = {dilation_[0], dilation_[1]}; | |||
| if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) { | |||
| SetPad(input_shape, kernel_node); | |||
| pad_height_ = pad_list[0] + pad_list[1]; | |||
| pad_width_ = pad_list[2] + pad_list[3]; | |||
| pad_top_ = pad_list[0]; | |||
| pad_left_ = pad_list[2]; | |||
| if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { | |||
| use_pad_ = false; | |||
| } | |||
| int dimA[4]; | |||
| int strideApadded[4]; | |||
| if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") { | |||
| auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_), | |||
| IntToSize(old_width_ + pad_width_)}; | |||
| SetDimA(padded_shape, dimA, data_format_); | |||
| SetStrideA(padded_shape, strideApadded, data_format_); | |||
| } else if (data_format_ == "NHWC") { | |||
| auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_), | |||
| IntToSize(c_)}; | |||
| SetDimA(padded_shape, dimA, data_format_); | |||
| SetStrideA(padded_shape, strideApadded, data_format_); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| if (use_pad_) { | |||
| padA[0] = 0; | |||
| padA[1] = 0; | |||
| } else { | |||
| padA[0] = pad_top_; | |||
| padA[1] = pad_left_; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolutionNdDescriptor failed"); | |||
| dx_desc_real = use_pad_ ? padded_descriptor_ : dx_desc_; | |||
| } else { | |||
| if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| } | |||
| padA[0] = pad_height_; | |||
| padA[1] = pad_width_; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], | |||
| dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolution2dDescriptor failed"); | |||
| dx_desc_real = dx_desc_; | |||
| } | |||
| @@ -208,24 +256,6 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| } | |||
| void SetPad(const std::vector<int> &input_shape, const CNodePtr &kernel_node) { | |||
| auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list"); | |||
| n_ = input_shape[0]; | |||
| c_ = input_shape[1]; | |||
| old_height_ = input_shape[2]; | |||
| old_width_ = input_shape[3]; | |||
| pad_height_ = pad_list[0] + pad_list[1]; | |||
| pad_width_ = pad_list[2] + pad_list[3]; | |||
| pad_top_ = pad_list[0]; | |||
| pad_left_ = pad_list[2]; | |||
| if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { | |||
| use_pad_ = false; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, | |||
| c_, old_height_ + pad_height_, old_width_ + pad_width_), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( | |||
| conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1], | |||
| dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolution2dDescriptor failed"); | |||
| } | |||
| void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) { | |||
| if (group_ > 1 || CUDNN_MAJOR < 7) { | |||
| @@ -247,25 +277,32 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; | |||
| } | |||
| } | |||
| void GetInputShape(const CNodePtr &kernel_node, std::vector<int> *input_shape) { | |||
| void GetInputShape(const CNodePtr &kernel_node, std::vector<size_t> *input_shape) { | |||
| auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast<ValueTuplePtr>()->value(); | |||
| (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape), | |||
| [](const ValuePtr &e) -> int { return e->cast<Int32ImmPtr>()->value(); }); | |||
| [](const ValuePtr &e) -> size_t { return e->cast<Int32ImmPtr>()->value(); }); | |||
| } | |||
| void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<int> &input_shape, | |||
| void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &input_shape, | |||
| const std::vector<size_t> &filter_shape) { | |||
| int nbDims = 4; | |||
| int dimA[4]; | |||
| int strideAin[4]; | |||
| int dimAdy[4]; | |||
| int strideAdy[4]; | |||
| int filterDimA[4]; | |||
| SetDimA(input_shape, dimA, data_format_); | |||
| SetStrideA(input_shape, strideAin, data_format_); | |||
| SetDimA(dy_shape, dimAdy, data_format_); | |||
| SetStrideA(dy_shape, strideAdy, data_format_); | |||
| SetDimA(filter_shape, filterDimA, data_format_); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetFilter4dDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), | |||
| SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])), | |||
| "SetFilter4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), | |||
| SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), | |||
| "SetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, input_shape[0], input_shape[1], | |||
| input_shape[2], input_shape[3]), | |||
| "SetTensor4dDescriptor failed"); | |||
| cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, compute_format_, nbDims, filterDimA), | |||
| "cudnnSetFilterNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dx_desc_, cudnn_data_type_, nbDims, dimA, strideAin), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| } | |||
| void SetStrideAndDilation(const CNodePtr &kernel_node) { | |||
| stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride"); | |||
| @@ -288,10 +325,12 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| cudnnTensorDescriptor_t padded_descriptor_; | |||
| cudnnConvolutionBwdDataAlgo_t algo_; | |||
| std::string pad_mode_; | |||
| std::string data_format_ = "NCHW"; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| cudnnTensorFormat_t compute_format_; | |||
| int old_height_; | |||
| int old_width_; | |||
| int pad_height_; | |||
| @@ -35,9 +35,9 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| input_descriptor_(nullptr), | |||
| output_descriptor_(nullptr), | |||
| pooling_descriptor_(nullptr), | |||
| padded_descriptor_(nullptr), | |||
| pooling_mode_(CUDNN_POOLING_MAX), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| compute_format_(CUDNN_TENSOR_NCHW), | |||
| old_height_(0), | |||
| old_width_(0), | |||
| pad_height_(0), | |||
| @@ -50,9 +50,7 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| is_null_input_(false), | |||
| input_size_(0), | |||
| output_size_(0), | |||
| padded_size_(0), | |||
| workspace_size_(0), | |||
| use_pad_(true) {} | |||
| workspace_size_(0) {} | |||
| ~PoolingGpuFwdKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -67,20 +65,10 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| T *output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { | |||
| T *padded_addr = reinterpret_cast<T *>(workspace[0]->addr); | |||
| CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, padded_descriptor_, | |||
| padded_addr, &beta, output_descriptor_, output_addr), | |||
| "cudnnPoolingForward failed"); | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_, | |||
| input_addr, &beta, output_descriptor_, output_addr), | |||
| "cudnnPoolingForward failed"); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_, | |||
| input_addr, &beta, output_descriptor_, output_addr), | |||
| "cudnnPoolingForward failed"); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| @@ -89,39 +77,64 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| return false; | |||
| } | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); | |||
| auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "PoolingGpuFwdKernel input is null."; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_); | |||
| int nbDims = 4; | |||
| int dimA[4]; | |||
| int strideAin[4]; | |||
| int dimAout[4]; | |||
| int strideAout[4]; | |||
| SetDimA(input_shape, dimA, data_format_); | |||
| SetStrideA(input_shape, strideAin, data_format_); | |||
| SetDimA(output_shape, dimAout, data_format_); | |||
| SetStrideA(output_shape, strideAout, data_format_); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), | |||
| SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), | |||
| cudnnSetTensorNdDescriptor(input_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), | |||
| SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), | |||
| cudnnSetTensorNdDescriptor(output_descriptor_, cudnn_data_type_, nbDims, dimAout, strideAout), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize")); | |||
| int window_height = window[2]; | |||
| int window_width = window[3]; | |||
| stride_ = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides")); | |||
| SetPoolingMode(kernel_node); | |||
| int windowDimA[2] = {window_height, window_width}; | |||
| int paddingA[2] = {0, 0}; | |||
| int strideA[2] = {stride_[2], stride_[3]}; | |||
| if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { | |||
| SetPad(input_shape, window_height, window_width); | |||
| pad_height_ = | |||
| std::max<int>(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) | |||
| : (old_height_ / stride_[2]) + 1) - | |||
| 1) * | |||
| stride_[2] + | |||
| window_height - old_height_); | |||
| pad_width_ = | |||
| std::max<int>(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) | |||
| : (old_width_ / stride_[3]) + 1) - | |||
| 1) * | |||
| stride_[3] + | |||
| window_width - old_width_); | |||
| pad_top_ = pad_height_ / 2; | |||
| pad_left_ = pad_width_ / 2; | |||
| paddingA[0] = pad_top_; | |||
| paddingA[1] = pad_left_; | |||
| } else { | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height, | |||
| window_width, pad_height_, pad_width_, stride_[2], stride_[3]), | |||
| "cudnnSetPooling2dDescriptor failed"); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, | |||
| 2, windowDimA, paddingA, strideA), | |||
| "cudnnSetPoolingNdDescriptor failed"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| @@ -131,7 +144,6 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_), | |||
| "cudnnCreatePoolingDescriptor failed"); | |||
| } | |||
| @@ -146,15 +158,6 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| } | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast<size_t *>(&padded_size_)), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| workspace_size_list_.push_back(padded_size_); | |||
| if (padded_size_ == 0) { | |||
| MS_LOG(EXCEPTION) << "Padded size is 0."; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| @@ -167,36 +170,7 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| } | |||
| return true; | |||
| } | |||
| void SetPad(const std::vector<size_t> &input_shape, const int &window_height, const int &window_width) { | |||
| n_ = SizeToInt(input_shape[0]); | |||
| c_ = SizeToInt(input_shape[1]); | |||
| old_height_ = SizeToInt(input_shape[2]); | |||
| old_width_ = SizeToInt(input_shape[3]); | |||
| pad_height_ = | |||
| std::max<int>(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) | |||
| : (old_height_ / stride_[2]) + 1) - | |||
| 1) * | |||
| stride_[2] + | |||
| window_height - old_height_); | |||
| pad_width_ = | |||
| std::max<int>(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) | |||
| : (old_width_ / stride_[3]) + 1) - | |||
| 1) * | |||
| stride_[3] + | |||
| window_width - old_width_); | |||
| pad_top_ = pad_height_ / 2; | |||
| pad_left_ = pad_width_ / 2; | |||
| if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { | |||
| use_pad_ = false; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, | |||
| c_, old_height_ + pad_height_, old_width_ + pad_width_), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, | |||
| window_height, window_width, use_pad_ ? 0 : pad_top_, | |||
| use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]), | |||
| "cudnnSetPooling2dDescriptor failed"); | |||
| } | |||
| void SetPoolingMode(const CNodePtr &kernel_node) { | |||
| pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding")); | |||
| mode_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| @@ -211,7 +185,6 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| void DestroyResource() noexcept { | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), | |||
| "cudnnDestroyPoolingDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed"); | |||
| } | |||
| @@ -220,16 +193,16 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| cudnnTensorDescriptor_t input_descriptor_; | |||
| cudnnTensorDescriptor_t output_descriptor_; | |||
| cudnnPoolingDescriptor_t pooling_descriptor_; | |||
| cudnnTensorDescriptor_t padded_descriptor_; | |||
| cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; | |||
| std::vector<int> stride_; | |||
| std::string mode_; | |||
| std::string pad_mode_; | |||
| std::string data_format_ = "NCHW"; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| cudnnTensorFormat_t compute_format_; | |||
| int old_height_; | |||
| int old_width_; | |||
| int pad_height_; | |||
| @@ -242,9 +215,7 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||
| bool is_null_input_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t padded_size_; | |||
| size_t workspace_size_; | |||
| bool use_pad_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -37,9 +37,9 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| dy_descriptor_(nullptr), | |||
| x_descriptor_(nullptr), | |||
| dx_descriptor_(nullptr), | |||
| padded_descriptor_(nullptr), | |||
| pooling_mode_(CUDNN_POOLING_MAX), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| compute_format_(CUDNN_TENSOR_NCHW), | |||
| old_height_(0), | |||
| old_width_(0), | |||
| pad_height_(0), | |||
| @@ -52,9 +52,7 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| is_null_input_(false), | |||
| input_size_(0), | |||
| output_size_(0), | |||
| padded_size_(0), | |||
| workspace_size_(0), | |||
| use_pad_(true) {} | |||
| workspace_size_(0) {} | |||
| ~PoolingGradGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -72,27 +70,10 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { | |||
| T *padded = GetDeviceAddress<T>(workspace, 0); | |||
| T *padded_dx = GetDeviceAddress<T>(workspace, 1); | |||
| CalPad(padded_size_ / sizeof(T), x_data, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, | |||
| padded_descriptor_, padded, &beta, padded_descriptor_, padded_dx), | |||
| "cudnnPoolingBackward failed"); | |||
| CalPadGrad(output_size_ / sizeof(T), padded_dx, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, | |||
| x_descriptor_, x_data, &beta, dx_descriptor_, dx), | |||
| "cudnnPoolingBackward failed"); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, | |||
| x_descriptor_, x_data, &beta, dx_descriptor_, dx), | |||
| "cudnnPoolingBackward failed"); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| @@ -104,46 +85,73 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| int window_height = window[2]; | |||
| int window_width = window[3]; | |||
| SetPoolingMode(kernel_node); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); | |||
| auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | |||
| data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "PoolingGradGpuKernel input is null."; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_); | |||
| int windowDimA[2] = {window_height, window_width}; | |||
| int paddingA[2] = {0, 0}; | |||
| int strideA[2] = {stride_[2], stride_[3]}; | |||
| int nbDims = 4; | |||
| int dimA[4]; | |||
| int strideAin[4]; | |||
| int dimAy[4]; | |||
| int strideAiny[4]; | |||
| int dimAdy[4]; | |||
| int strideAdy[4]; | |||
| int dimAout[4]; | |||
| int strideAout[4]; | |||
| SetDimA(input_shape, dimA, data_format_); | |||
| SetStrideA(input_shape, strideAin, data_format_); | |||
| SetDimA(input_mask, dimAy, data_format_); | |||
| SetStrideA(input_mask, strideAiny, data_format_); | |||
| SetDimA(dout_shape, dimAdy, data_format_); | |||
| SetStrideA(dout_shape, strideAdy, data_format_); | |||
| SetDimA(output_shape, dimAout, data_format_); | |||
| SetStrideA(output_shape, strideAout, data_format_); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(y_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_mask[0]), | |||
| SizeToInt(input_mask[1]), SizeToInt(input_mask[2]), SizeToInt(input_mask[3])), | |||
| "cudnnSetTensor4dDescriptor"); | |||
| auto dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(dy_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dout_shape[0]), | |||
| SizeToInt(dout_shape[1]), SizeToInt(dout_shape[2]), SizeToInt(dout_shape[3])), | |||
| "cudnnSetTensor4dDescriptor"); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(dx_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), | |||
| SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), | |||
| cudnnSetTensorNdDescriptor(dx_descriptor_, cudnn_data_type_, nbDims, dimAout, strideAout), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(x_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { | |||
| SetPad(input_shape, window_height, window_width); | |||
| pad_height_ = | |||
| std::max<int>(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) | |||
| : (old_height_ / stride_[2]) + 1) - | |||
| 1) * | |||
| stride_[2] + | |||
| window_height - old_height_); | |||
| pad_width_ = | |||
| std::max<int>(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) | |||
| : (old_width_ / stride_[3]) + 1) - | |||
| 1) * | |||
| stride_[3] + | |||
| window_width - old_width_); | |||
| pad_top_ = pad_height_ / 2; | |||
| pad_left_ = pad_width_ / 2; | |||
| paddingA[0] = pad_top_; | |||
| paddingA[1] = pad_left_; | |||
| } else { | |||
| if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height, | |||
| window_width, pad_height_, pad_width_, stride_[2], stride_[3]), | |||
| "cudnnSetPooling2dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), | |||
| SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), | |||
| "cudnnSetTensor4dDescriptor"); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, | |||
| 2, windowDimA, paddingA, strideA), | |||
| "cudnnSetPoolingNdDescriptor failed"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| @@ -155,7 +163,6 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_descriptor_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_descriptor_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_descriptor_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_), | |||
| "cudnnCreatePoolingDescriptor failed"); | |||
| } | |||
| @@ -179,16 +186,6 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| } | |||
| input_size_list_.push_back(input_size_); | |||
| if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| if (padded_size_ == 0) { | |||
| MS_LOG(EXCEPTION) << "Padded size is 0."; | |||
| } | |||
| workspace_size_list_.push_back(padded_size_); | |||
| workspace_size_list_.push_back(padded_size_); | |||
| } | |||
| return; | |||
| } | |||
| @@ -206,35 +203,6 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| c_ = SizeToInt(input_shape[1]); | |||
| old_height_ = SizeToInt(input_shape[2]); | |||
| old_width_ = SizeToInt(input_shape[3]); | |||
| pad_height_ = | |||
| std::max<int>(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) | |||
| : (old_height_ / stride_[2]) + 1) - | |||
| 1) * | |||
| stride_[2] + | |||
| window_height - old_height_); | |||
| pad_width_ = | |||
| std::max<int>(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) | |||
| : (old_width_ / stride_[3]) + 1) - | |||
| 1) * | |||
| stride_[3] + | |||
| window_width - old_width_); | |||
| pad_top_ = pad_height_ / 2; | |||
| pad_left_ = pad_width_ / 2; | |||
| if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { | |||
| use_pad_ = false; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, | |||
| c_, old_height_ + pad_height_, old_width_ + pad_width_), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), | |||
| SizeToInt(input_shape[1]), SizeToInt(input_shape[2]) + (use_pad_ ? pad_height_ : 0), | |||
| SizeToInt(input_shape[3]) + (use_pad_ ? pad_width_ : 0)), | |||
| "cudnnSetTensor4dDescriptor"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, | |||
| window_height, window_width, use_pad_ ? 0 : pad_top_, | |||
| use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]), | |||
| "cudnnSetPooling2dDescriptor failed"); | |||
| } | |||
| void SetPoolingMode(const CNodePtr &kernel_node) { | |||
| pad_mode_ = GetAttr<std::string>(kernel_node, "padding"); | |||
| @@ -252,7 +220,6 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| void DestroyResource() noexcept { | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), | |||
| "cudnnDestroyPoolingDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_descriptor_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_descriptor_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_descriptor_), "cudnnDestroyTensorDescriptor failed"); | |||
| @@ -265,7 +232,6 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| cudnnTensorDescriptor_t dy_descriptor_; | |||
| cudnnTensorDescriptor_t x_descriptor_; | |||
| cudnnTensorDescriptor_t dx_descriptor_; | |||
| cudnnTensorDescriptor_t padded_descriptor_; | |||
| cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; | |||
| std::vector<int> stride_; | |||
| std::vector<size_t> input_size_list_; | |||
| @@ -273,7 +239,9 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| std::vector<size_t> workspace_size_list_; | |||
| std::string mode_; | |||
| std::string pad_mode_; | |||
| std::string data_format_ = "NCHW"; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| cudnnTensorFormat_t compute_format_; | |||
| int old_height_; | |||
| int old_width_; | |||
| int pad_height_; | |||
| @@ -286,9 +254,7 @@ class PoolingGradGpuKernel : public GpuKernel { | |||
| bool is_null_input_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t padded_size_; | |||
| size_t workspace_size_; | |||
| bool use_pad_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -45,23 +45,23 @@ def test_maxpool2d_grad(): | |||
| [24, 25, 26, 27, 28, 29], | |||
| [30, 31, 32, 33, 34, 35] | |||
| ]]]).astype(np.float32)) | |||
| a = Tensor(np.array([[[ | |||
| d = Tensor(np.array([[[ | |||
| [3, 3, 3], | |||
| [3, 3, 3], | |||
| [3, 3, 3] | |||
| ]]]).astype(np.float32)) | |||
| d = Tensor(np.array([[[ | |||
| a = Tensor(np.array([[[ | |||
| [7, 9, 11], | |||
| [19, 21, 23], | |||
| [31, 33, 35] | |||
| ]]]).astype(np.float32)) | |||
| expect_result = (np.array([[[ | |||
| [0, 0, 0, 0, 0, 0], | |||
| [0, 7, 0, 9, 0, 11], | |||
| [0, 3, 0, 3, 0, 3], | |||
| [0, 0, 0, 0, 0, 0], | |||
| [0, 19, 0, 21, 0, 23], | |||
| [0, 3, 0, 3, 0, 3], | |||
| [0, 0, 0, 0, 0, 0], | |||
| [0, 31, 0, 33, 0, 35] | |||
| [0, 3, 0, 3, 0, 3] | |||
| ]]])) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||