From: @tom__chen Reviewed-by: @robingrosman, @mikef Signed-off-by: pull/16022/MERGE
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -129,7 +129,7 @@ __global__ void Pad3d(const size_t size, const T* input, const int num, const in | |||
| const int pos_w = pos % padded_width; | |||
| const int block_num = pos / padded_dhw; | |||
| if (pos_d - pad_head < 0 || pos_h - pad_top < 0 || pos_w - pad_left < 0 || pos_h - pad_head >= old_depth || | |||
| if (pos_d - pad_head < 0 || pos_h - pad_top < 0 || pos_w - pad_left < 0 || pos_d - pad_head >= old_depth || | |||
| pos_h - pad_top >= old_height || pos_w - pad_left >= old_width) { | |||
| output[pos] = pad_value_; | |||
| } else { | |||
| @@ -140,6 +140,23 @@ __global__ void Pad3d(const size_t size, const T* input, const int num, const in | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void PadGrad3d(const size_t size, const T* dy, const int num, const int channels, const int old_depth, | |||
| const int old_height, const int old_width, const int old_dhw, const int old_hw, | |||
| const int padded_depth, const int padded_height, const int padded_width, | |||
| const int padded_dhw, const int padded_hw, const int pad_head, const int pad_top, | |||
| const int pad_left, T* dx) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||
| const int block_num = pos / old_dhw; | |||
| const int pos_d = pos / old_hw % old_depth + pad_head; | |||
| const int pos_h = pos / old_width % old_height + pad_top; | |||
| const int pos_w = pos % old_width + pad_left; | |||
| const int index = block_num * padded_dhw + pos_d * padded_hw + pos_h * padded_width + pos_w; | |||
| dx[pos] = dy[index]; | |||
| } | |||
| return; | |||
| } | |||
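Editor's note: the new PadGrad3d kernel is the inverse of Pad3d — every element of the unpadded gradient dx is read from the padded gradient dy at the position offset by (pad_head, pad_top, pad_left), so gradient that falls on the padding is dropped. The Conv3dGradInput kernel below relies on this to crop the data gradient computed on the explicitly padded input back to the original shape. As a reading aid only, a host-side sketch of the same index mapping (names are illustrative, not part of the patch):

    #include <vector>

    // CPU reference for PadGrad3d: copy the interior of the padded gradient `dy`
    // back into the unpadded gradient `dx`. `block_num_total` is num * channels.
    template <typename T>
    void PadGrad3dReference(const std::vector<T> &dy, std::vector<T> *dx, int block_num_total, int old_depth,
                            int old_height, int old_width, int padded_depth, int padded_height, int padded_width,
                            int pad_head, int pad_top, int pad_left) {
      const int old_hw = old_height * old_width;
      const int old_dhw = old_depth * old_hw;
      const int padded_hw = padded_height * padded_width;
      const int padded_dhw = padded_depth * padded_hw;
      for (int pos = 0; pos < block_num_total * old_dhw; ++pos) {
        const int block_num = pos / old_dhw;                       // flattened N * C index
        const int pos_d = pos / old_hw % old_depth + pad_head;     // depth index inside the padded tensor
        const int pos_h = pos / old_width % old_height + pad_top;  // height index inside the padded tensor
        const int pos_w = pos % old_width + pad_left;              // width index inside the padded tensor
        (*dx)[pos] = dy[block_num * padded_dhw + pos_d * padded_hw + pos_h * padded_width + pos_w];
      }
    }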
| template <typename T> | |||
| void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, | |||
| const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, | |||
| @@ -204,6 +221,22 @@ void CalPad3d(const size_t size, const T* input, const int num, const int channe | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalPadGrad3d(const size_t size, const T* dy, const int num, const int channels, const int old_depth, | |||
| const int old_height, const int old_width, const int padded_depth, const int padded_height, | |||
| const int padded_width, const int pad_head, const int pad_top, const int pad_left, T* dx, | |||
| cudaStream_t cuda_stream) { | |||
| const int old_hw = old_height * old_width; | |||
| const int old_dhw = old_depth * old_hw; | |||
| const int padded_hw = padded_height * padded_width; | |||
| const int padded_dhw = padded_depth * padded_hw; | |||
| PadGrad3d<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, num, channels, old_depth, old_height, | |||
| old_width, old_dhw, old_hw, padded_depth, padded_height, | |||
| padded_width, padded_dhw, padded_hw, pad_head, pad_top, | |||
| pad_left, dx); | |||
| return; | |||
| } | |||
| template void CalPad<float>(const size_t size, const float* input, const int num, const int channels, | |||
| const int old_height, const int old_width, const int padded_height, const int padded_width, | |||
| const int pad_top, const int pad_left, float pad_value, float* output, | |||
| @@ -259,3 +292,13 @@ template void CalPad3d<half>(const size_t size, const half* input, const int num | |||
| const int old_depth, const int old_height, const int old_width, const int padded_depth, | |||
| const int padded_height, const int padded_width, const int pad_head, const int pad_top, | |||
| const int pad_left, const float pad_value, half* output, cudaStream_t cuda_stream); | |||
| template void CalPadGrad3d<float>(const size_t size, const float* dy, const int num, const int channels, | |||
| const int old_depth, const int old_height, const int old_width, | |||
| const int padded_depth, const int padded_height, const int padded_width, | |||
| const int pad_head, const int pad_top, const int pad_left, float* dx, | |||
| cudaStream_t cuda_stream); | |||
| template void CalPadGrad3d<half>(const size_t size, const half* dy, const int num, const int channels, | |||
| const int old_depth, const int old_height, const int old_width, | |||
| const int padded_depth, const int padded_height, const int padded_width, | |||
| const int pad_head, const int pad_top, const int pad_left, half* dx, | |||
| cudaStream_t cuda_stream); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -45,4 +45,9 @@ void CalPad3d(const size_t size, const T* input, const int num, const int channe | |||
| const int old_height, const int old_width, const int padded_depth, const int padded_height, | |||
| const int padded_width, const int pad_head, const int pad_top, const int pad_left, const float pad_value, | |||
| T* output, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalPadGrad3d(const size_t size, const T* dy, const int num, const int channels, const int old_depth, | |||
| const int old_height, const int old_width, const int padded_depth, const int padded_height, | |||
| const int padded_width, const int pad_head, const int pad_top, const int pad_left, T* dx, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ | |||
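Editor's note: CalPadGrad3d is the host-side launcher consumed by the new Conv3dGradInput kernel. A minimal, illustrative call (made-up shapes; `padded_dy`, `dx`, and `stream` are assumed to be valid device pointers and a valid stream):

    #include <cuda_runtime.h>
    #include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"

    // Illustrative only: crop a padded float gradient of shape [N=1, C=2, D=5, H=6, W=7]
    // (one cell of padding at head/top/left) back to the original [1, 2, 4, 5, 6] gradient.
    void CropPaddedGradExample(const float *padded_dy, float *dx, cudaStream_t stream) {
      const int n = 1, c = 2;
      const int old_d = 4, old_h = 5, old_w = 6;
      const int padded_d = 5, padded_h = 6, padded_w = 7;
      const size_t dx_elements = static_cast<size_t>(n) * c * old_d * old_h * old_w;
      CalPadGrad3d(dx_elements, padded_dy, n, c, old_d, old_h, old_w, padded_d, padded_h, padded_w,
                   /*pad_head=*/1, /*pad_top=*/1, /*pad_left=*/1, dx, stream);
    }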
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Conv3DBackpropFilter, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| Conv3dGradFilterGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Conv3DBackpropFilter, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), | |||
| Conv3dGradFilterGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
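Editor's note: the float16 registration above still declares a float32 output. Conv3dGradFilterGpuKernel computes dw into a half-precision workspace buffer and widens it with the shared Cast helper before returning it (see Launch() in the header below), presumably so the weight gradient comes back in full precision. Conceptually the widening step is a plain element-wise cast; an illustrative sketch, not the actual Cast implementation:

    #include <cuda_fp16.h>

    // Illustrative element-wise widening of a half-precision gradient buffer
    // into a float32 output buffer (parameter names are hypothetical).
    __global__ void WidenHalfToFloat(size_t n, const half *dw_half, float *dw_float32) {
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
        dw_float32[i] = __half2float(dw_half[i]);
      }
    }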
| @@ -0,0 +1,418 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_FILTER_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_FILTER_GPU_KERNEL_H_ | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class Conv3dGradFilterGpuKernel : public GpuKernel { | |||
| public: | |||
| Conv3dGradFilterGpuKernel() { ResetResource(); } | |||
| ~Conv3dGradFilterGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| T *x = GetDeviceAddress<T>(inputs, 0); | |||
| T *dy = GetDeviceAddress<T>(inputs, 1); | |||
| T *work_space = nullptr; | |||
| if (workspace_size_ != 0) { | |||
| work_space = GetDeviceAddress<T>(workspace, 0); | |||
| } | |||
| T *dw = nullptr; | |||
| float *dw_float32 = nullptr; | |||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||
| dw = GetDeviceAddress<T>(workspace, 1); | |||
| dw_float32 = GetDeviceAddress<float>(outputs, 0); | |||
| } else { | |||
| dw = GetDeviceAddress<T>(outputs, 0); | |||
| } | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| if (use_pad_) { | |||
| T *padded = GetDeviceAddress<T>(workspace, 1); | |||
| CalPad3d(padded_size_ / sizeof(T), x, n_, c_, old_depth_, old_height_, old_width_, old_depth_ + pad_depth_, | |||
| old_height_ + pad_height_, old_width_ + pad_width_, pad_head_, pad_top_, pad_left_, pad_value_, padded, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_, | |||
| algo_, work_space, workspace_size_, &beta, dw_desc_, dw), | |||
| "ConvolutionBackwardFilter failed"); | |||
| return true; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, x_desc_, x, dy_desc_, dy, conv_desc_, algo_, work_space, | |||
| workspace_size_, &beta, dw_desc_, dw), | |||
| "ConvolutionBackwardFilter failed"); | |||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||
| Cast(num_output_elements_, dw, dw_float32, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| InitResource(); | |||
| if (!CheckParam(kernel_node)) { | |||
| return false; | |||
| } | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "Conv3dGradFilterGpuKernel input is null."; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| CHECK_TENSOR_SIZE(in_shape); | |||
| data_format_ = kOpFormat_NCDHW; | |||
| std::vector<size_t> filter_shape; | |||
| GetFilterShape(kernel_node, &filter_shape); | |||
| num_output_elements_ = 1; | |||
| for (auto x : filter_shape) { | |||
| num_output_elements_ *= x; | |||
| } | |||
| compute_format_ = CUDNN_TENSOR_NCHW; | |||
| n_ = SizeToInt(in_shape[0]); | |||
| c_ = SizeToInt(in_shape[1]); | |||
| old_depth_ = SizeToInt(in_shape[2]); | |||
| old_height_ = SizeToInt(in_shape[3]); | |||
| old_width_ = SizeToInt(in_shape[4]); | |||
| SetNDDesc(dy_shape, filter_shape, in_shape); | |||
| group_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "group")); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, group_), | |||
| "cudnnSetConvGroupCount failed"); | |||
| std::vector<int> pad_list; | |||
| std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list"); | |||
| (void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| pad_depth_ = pad_list[0]; | |||
| pad_height_ = pad_list[2]; | |||
| pad_width_ = pad_list[4]; | |||
| use_pad_ = !((pad_depth_ == pad_list[1]) && (pad_height_ == pad_list[3]) && (pad_width_ == pad_list[5])); | |||
| pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode"); | |||
| SetStrideAndDilation(kernel_node); | |||
| cudnnTensorDescriptor_t x_desc_real = nullptr; | |||
| const int kNumDims = 5; | |||
| const int kConvDims = 3; | |||
| int padA[kConvDims]; | |||
| int strideA[kConvDims] = {stride_[2], stride_[3], stride_[4]}; | |||
| int dilaA[kConvDims] = {dilation_[2], dilation_[3], dilation_[4]}; | |||
| if (use_pad_) { | |||
| pad_depth_ = pad_list[0] + pad_list[1]; | |||
| pad_height_ = pad_list[2] + pad_list[3]; | |||
| pad_width_ = pad_list[4] + pad_list[5]; | |||
| pad_head_ = pad_list[0]; | |||
| pad_top_ = pad_list[2]; | |||
| pad_left_ = pad_list[4]; | |||
| int dimA[kNumDims]; | |||
| int strideApadded[kNumDims]; | |||
| if (data_format_ == kOpFormat_NCDHW) { | |||
| auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_depth_ + pad_depth_), | |||
| IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)}; | |||
| SetDimA(padded_shape, dimA, kNumDims, data_format_); | |||
| SetStrideA(padded_shape, strideApadded, kNumDims, data_format_); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel only support NCDHW format right now."; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, kNumDims, dimA, strideApadded), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| padA[0] = 0; | |||
| padA[1] = 0; | |||
| padA[2] = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA, | |||
| CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolutionNdDescriptor failed"); | |||
| x_desc_real = padded_descriptor_; | |||
| } else { | |||
| if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { | |||
| pad_depth_ = 0; | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| } | |||
| padA[0] = pad_depth_; | |||
| padA[1] = pad_height_; | |||
| padA[2] = pad_width_; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA, | |||
| CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolutionNdDescriptor failed"); | |||
| x_desc_real = x_desc_; | |||
| } | |||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), | |||
| "cudnnSetConvolutionMathType failed.") | |||
| } | |||
| SelectAlgorithm(x_desc_real); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
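Editor's note: Init() handles asymmetric padding the same way the existing 2D kernels do — cudnnSetConvolutionNdDescriptor only takes a single pad value per spatial dimension, so whenever pad_list differs between the two sides of any dimension the input is padded explicitly with CalPad3d and zero padding is passed to cuDNN. A condensed sketch of that decision (pad_list layout {head, tail, top, bottom, left, right}, as the indices above imply):

    #include <vector>

    // Sketch of the use_pad_ decision: explicit padding is only needed when the
    // requested padding is asymmetric in at least one spatial dimension.
    bool NeedsExplicitPad(const std::vector<int> &pad_list) {
      bool symmetric = pad_list[0] == pad_list[1] &&  // depth:  head == tail
                       pad_list[2] == pad_list[3] &&  // height: top == bottom
                       pad_list[4] == pad_list[5];    // width:  left == right
      return !symmetric;
    }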
| void ResetResource() noexcept override { | |||
| cudnn_handle_ = nullptr; | |||
| dw_desc_ = nullptr; | |||
| conv_desc_ = nullptr; | |||
| dy_desc_ = nullptr; | |||
| x_desc_ = nullptr; | |||
| padded_descriptor_ = nullptr; | |||
| cudnn_data_type_ = CUDNN_DATA_FLOAT; | |||
| compute_format_ = CUDNN_TENSOR_NCHW; | |||
| old_depth_ = 0; | |||
| old_height_ = 0; | |||
| old_width_ = 0; | |||
| pad_depth_ = 0; | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| pad_head_ = 0; | |||
| pad_top_ = 0; | |||
| pad_left_ = 0; | |||
| n_ = 0; | |||
| c_ = 0; | |||
| group_ = 1; | |||
| is_null_input_ = false; | |||
| input_size_ = 0; | |||
| dy_size_ = 0; | |||
| output_size_ = 0; | |||
| padded_size_ = 0; | |||
| workspace_size_ = 0; | |||
| use_pad_ = true; | |||
| num_output_elements_ = 1; | |||
| } | |||
| void DestroyResource() noexcept override { | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_), | |||
| "cudnnDestroyConvolutionDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyFilterDescriptor(dw_desc_), | |||
| "cudnnDestroyFilterDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(padded_descriptor_), | |||
| "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), | |||
| "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), | |||
| "cudnnDestroyTensorDescriptor failed"); | |||
| } | |||
| protected: | |||
| void InitResource() override { | |||
| cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), | |||
| "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), | |||
| "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&padded_descriptor_), | |||
| "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateFilterDescriptor(&dw_desc_), | |||
| "cudnnCreateFilterDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateConvolutionDescriptor(&conv_desc_), | |||
| "cudnnCreateConvolutionDescriptor failed"); | |||
| } | |||
| void InitSizeLists() override { | |||
| if (!is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnGetTensorSizeInBytes(dy_desc_, reinterpret_cast<size_t *>(&dy_size_)), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnGetTensorSizeInBytes(x_desc_, reinterpret_cast<size_t *>(&input_size_)), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnGetFilterSizeInBytes(dw_desc_, reinterpret_cast<size_t *>(&output_size_)), | |||
| "cudnnGetFilterSizeInBytes failed"); | |||
| } | |||
| input_size_list_.push_back(dy_size_); | |||
| input_size_list_.push_back(input_size_); | |||
| if (use_pad_ && !is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast<size_t *>(&padded_size_)), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, padded_descriptor_, dy_desc_, conv_desc_, | |||
| dw_desc_, algo_, reinterpret_cast<size_t *>(&workspace_size_)), | |||
| "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); | |||
| workspace_size_list_.push_back(padded_size_); | |||
| } else { | |||
| if (!is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, x_desc_, dy_desc_, conv_desc_, dw_desc_, algo_, | |||
| reinterpret_cast<size_t *>(&workspace_size_)), | |||
| "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); | |||
| } | |||
| } | |||
| (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); | |||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||
| workspace_size_list_.push_back(output_size_); | |||
| output_size_list_.push_back(num_output_elements_ * sizeof(float)); | |||
| } else { | |||
| output_size_list_.push_back(output_size_); | |||
| } | |||
| } | |||
| private: | |||
| bool CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but Conv3dGradFilterGpuKernel needs 2 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but Conv3dGradFilterGpuKernel needs 1 output."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) { | |||
| const int requested_algo_count = 1; | |||
| int returned_algo_count = 0; | |||
| cudnnConvolutionBwdFilterAlgoPerf_t perf_results; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_, | |||
| requested_algo_count, &returned_algo_count, &perf_results), | |||
| "GetConvolutionBackwardFilterAlgorithm failed"); | |||
| algo_ = perf_results.algo; | |||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||
| algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; | |||
| } | |||
| } | |||
| void GetFilterShape(const CNodePtr &kernel_node, std::vector<size_t> *filter_shape) { | |||
| auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_size")->cast<ValueTuplePtr>()->value(); | |||
| (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape), | |||
| [](const ValuePtr &e) -> size_t { return static_cast<int>(e->cast<Int64ImmPtr>()->value()); }); | |||
| } | |||
| void SetNDDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &filter_shape, | |||
| const std::vector<size_t> &in_shape) { | |||
| const int kDims = 5; | |||
| int dimA[kDims]; | |||
| int strideAin[kDims]; | |||
| int dimAdy[kDims]; | |||
| int strideAdy[kDims]; | |||
| int filterDimA[kDims]; | |||
| SetDimA(in_shape, dimA, kDims, data_format_); | |||
| SetStrideA(in_shape, strideAin, kDims, data_format_); | |||
| SetDimA(dy_shape, dimAdy, kDims, data_format_); | |||
| SetStrideA(dy_shape, strideAdy, kDims, data_format_); | |||
| SetDimA(filter_shape, filterDimA, kDims, data_format_); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, kDims, dimAdy, strideAdy), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, compute_format_, kDims, filterDimA), | |||
| "cudnnSetFilterNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnSetTensorNdDescriptor(x_desc_, cudnn_data_type_, kDims, dimA, strideAin), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| } | |||
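Editor's note: SetNDDesc() fills the 5-D cuDNN descriptors from the NCDHW shapes via the shared SetDimA/SetStrideA helpers. Assuming (as for the other GPU kernels) that SetStrideA produces packed strides for this format, the stride computation is equivalent to the following sketch:

    #include <vector>

    // Packed strides for an NCDHW shape, e.g. {N, C, D, H, W} ->
    // {C*D*H*W, D*H*W, H*W, W, 1}; a sketch of what SetStrideA is assumed to produce.
    void PackedStrides(const std::vector<size_t> &shape, int *strideA, int nb_dims) {
      strideA[nb_dims - 1] = 1;
      for (int i = nb_dims - 2; i >= 0; --i) {
        strideA[i] = strideA[i + 1] * static_cast<int>(shape[i + 1]);
      }
    }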
| void SetStrideAndDilation(const CNodePtr &kernel_node) { | |||
| std::vector<int64_t> stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "strides"); | |||
| std::vector<int64_t> dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "dilations"); | |||
| (void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| if (stride_.size() != 5) { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel stride must be 5d, but got " << stride_.size(); | |||
| } | |||
| if (stride_[0] != 1 || stride_[1] != 1) { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel stride only support 1 in N axis and C axis!"; | |||
| } | |||
| if (dilation_.size() != 5) { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel dilation must be 5d!"; | |||
| } | |||
| if (dilation_[0] != 1 || dilation_[1] != 1) { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel dilation only support 1 in N axis and C axis!"; | |||
| } | |||
| } | |||
| cudnnHandle_t cudnn_handle_; | |||
| cudnnFilterDescriptor_t dw_desc_; | |||
| cudnnConvolutionDescriptor_t conv_desc_; | |||
| cudnnTensorDescriptor_t dy_desc_; | |||
| cudnnTensorDescriptor_t x_desc_; | |||
| cudnnTensorDescriptor_t padded_descriptor_; | |||
| cudnnConvolutionBwdFilterAlgo_t algo_; | |||
| std::string pad_mode_; | |||
| std::string data_format_ = kOpFormat_NCDHW; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| const float pad_value_ = 0.0; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| cudnnTensorFormat_t compute_format_; | |||
| int old_depth_; | |||
| int old_height_; | |||
| int old_width_; | |||
| int pad_depth_; | |||
| int pad_height_; | |||
| int pad_width_; | |||
| int pad_head_; | |||
| int pad_top_; | |||
| int pad_left_; | |||
| int n_; | |||
| int c_; | |||
| std::vector<int> stride_; | |||
| std::vector<int> dilation_; | |||
| int group_; | |||
| bool is_null_input_; | |||
| size_t input_size_; | |||
| size_t dy_size_; | |||
| size_t output_size_; | |||
| size_t padded_size_; | |||
| size_t workspace_size_; | |||
| bool use_pad_; | |||
| size_t num_output_elements_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_FILTER_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Conv3DBackpropInput, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| Conv3dGradInputGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| Conv3DBackpropInput, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| Conv3dGradInputGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,397 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_INPUT_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_INPUT_GPU_KERNEL_H_ | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class Conv3dGradInputGpuKernel : public GpuKernel { | |||
| public: | |||
| Conv3dGradInputGpuKernel() { ResetResource(); } | |||
| ~Conv3dGradInputGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| T *w = GetDeviceAddress<T>(inputs, 0); | |||
| T *dy = GetDeviceAddress<T>(inputs, 1); | |||
| T *dx = GetDeviceAddress<T>(outputs, 0); | |||
| T *work_space = nullptr; | |||
| if (workspace_size_ != 0) { | |||
| work_space = GetDeviceAddress<T>(workspace, 0); | |||
| } | |||
| const float alpha = 1; | |||
| if (use_pad_) { | |||
| T *padded = GetDeviceAddress<T>(workspace, 1); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, | |||
| workspace_size_, &beta_, padded_descriptor_, padded), | |||
| "ConvolutionBackwardData failed"); | |||
| CalPadGrad3d(output_size_ / sizeof(T), padded, n_, c_, old_depth_, old_height_, old_width_, | |||
| old_depth_ + pad_depth_, old_height_ + pad_height_, old_width_ + pad_width_, pad_head_, pad_top_, | |||
| pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, | |||
| workspace_size_, &beta_, dx_desc_, dx), | |||
| "ConvolutionBackwardData failed"); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| InitResource(); | |||
| if (!CheckParam(kernel_node)) { | |||
| return false; | |||
| } | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| data_format_ = kOpFormat_NCDHW; | |||
| auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| is_null_input_ = CHECK_NULL_INPUT(dy_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "Conv3dGradInputGpuKernel input is null."; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| std::vector<size_t> input_shape; | |||
| GetInputShape(kernel_node, &input_shape); | |||
| compute_format_ = CUDNN_TENSOR_NCHW; | |||
| CHECK_TENSOR_SIZE(input_shape); | |||
| n_ = SizeToInt(input_shape[0]); | |||
| c_ = SizeToInt(input_shape[1]); | |||
| old_depth_ = SizeToInt(input_shape[2]); | |||
| old_height_ = SizeToInt(input_shape[3]); | |||
| old_width_ = SizeToInt(input_shape[4]); | |||
| SetNDDesc(dy_shape, input_shape, filter_shape); | |||
| group_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "group")); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, group_), | |||
| "cudnnSetConvGroupCount failed"); | |||
| std::vector<int> pad_list; | |||
| std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list"); | |||
| (void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| pad_depth_ = pad_list[0]; | |||
| pad_height_ = pad_list[2]; | |||
| pad_width_ = pad_list[4]; | |||
| use_pad_ = !((pad_depth_ == pad_list[1]) && (pad_height_ == pad_list[3]) && (pad_width_ == pad_list[5])); | |||
| pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode"); | |||
| SetStrideAndDilation(kernel_node); | |||
| cudnnTensorDescriptor_t dx_desc_real = nullptr; | |||
| const int kNumDims = 5; | |||
| const int kConvDims = 3; | |||
| int padA[kConvDims]; | |||
| int strideA[kConvDims] = {stride_[2], stride_[3], stride_[4]}; | |||
| int dilaA[kConvDims] = {dilation_[2], dilation_[3], dilation_[4]}; | |||
| if (use_pad_) { | |||
| pad_depth_ = pad_list[0] + pad_list[1]; | |||
| pad_height_ = pad_list[2] + pad_list[3]; | |||
| pad_width_ = pad_list[4] + pad_list[5]; | |||
| pad_head_ = pad_list[0]; | |||
| pad_top_ = pad_list[2]; | |||
| pad_left_ = pad_list[4]; | |||
| int dimA[kNumDims]; | |||
| int strideApadded[kNumDims]; | |||
| if (data_format_ == kOpFormat_NCDHW) { | |||
| auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_depth_ + pad_depth_), | |||
| IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)}; | |||
| SetDimA(padded_shape, dimA, kNumDims, data_format_); | |||
| SetStrideA(padded_shape, strideApadded, kNumDims, data_format_); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel only support NCDHW format right now."; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, kNumDims, dimA, strideApadded), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| padA[0] = 0; | |||
| padA[1] = 0; | |||
| padA[2] = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA, | |||
| CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolutionNdDescriptor failed"); | |||
| dx_desc_real = padded_descriptor_; | |||
| } else { | |||
| if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { | |||
| pad_depth_ = 0; | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| } | |||
| padA[0] = pad_depth_; | |||
| padA[1] = pad_height_; | |||
| padA[2] = pad_width_; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA, | |||
| CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolution2dDescriptor failed"); | |||
| dx_desc_real = dx_desc_; | |||
| } | |||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), | |||
| "cudnnSetConvolutionMathType failed.") | |||
| } | |||
| SelectAlgorithm(dx_desc_real); | |||
| beta_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| cudnn_handle_ = nullptr; | |||
| w_desc_ = nullptr; | |||
| conv_desc_ = nullptr; | |||
| dy_desc_ = nullptr; | |||
| dx_desc_ = nullptr; | |||
| padded_descriptor_ = nullptr; | |||
| cudnn_data_type_ = CUDNN_DATA_FLOAT; | |||
| compute_format_ = CUDNN_TENSOR_NCHW; | |||
| old_depth_ = 0; | |||
| old_height_ = 0; | |||
| old_width_ = 0; | |||
| pad_depth_ = 0; | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| pad_head_ = 0; | |||
| pad_top_ = 0; | |||
| pad_left_ = 0; | |||
| n_ = 0; | |||
| c_ = 0; | |||
| group_ = 1; | |||
| is_null_input_ = false; | |||
| dy_size_ = 0; | |||
| w_size_ = 0; | |||
| output_size_ = 0; | |||
| padded_size_ = 0; | |||
| workspace_size_ = 0; | |||
| use_pad_ = true; | |||
| beta_ = 0; | |||
| } | |||
| void DestroyResource() noexcept override { | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_), | |||
| "cudnnDestroyConvolutionDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyFilterDescriptor(w_desc_), | |||
| "cudnnDestroyFilterDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(padded_descriptor_), | |||
| "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), | |||
| "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), | |||
| "cudnnDestroyTensorDescriptor failed"); | |||
| } | |||
| protected: | |||
| void InitResource() override { | |||
| cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), | |||
| "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), | |||
| "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&padded_descriptor_), | |||
| "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateFilterDescriptor(&w_desc_), | |||
| "cudnnCreateFilterDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateConvolutionDescriptor(&conv_desc_), | |||
| "cudnnCreateConvolutionDescriptor failed"); | |||
| } | |||
| void InitSizeLists() override { | |||
| if (!is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dy_desc_, &dy_size_), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetFilterSizeInBytes(w_desc_, &w_size_), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dx_desc_, &output_size_), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| } | |||
| input_size_list_.push_back(dy_size_); | |||
| input_size_list_.push_back(w_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| if (use_pad_ && !is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, padded_descriptor_, | |||
| algo_, &workspace_size_), | |||
| "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); | |||
| workspace_size_list_.push_back(padded_size_); | |||
| } else { | |||
| if (!is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnGetConvolutionBackwardDataWorkspaceSize( | |||
| cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_, algo_, &workspace_size_), | |||
| "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); | |||
| } | |||
| } | |||
| (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); | |||
| } | |||
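Editor's note: because workspace_size_ is inserted at the front of workspace_size_list_ only after the optional padded buffer has been pushed, Launch() can rely on a fixed ordering — workspace[0] is always the cuDNN scratch space and workspace[1], when present, is the padded dx buffer. A condensed sketch of that ordering logic:

    #include <cstddef>
    #include <vector>

    // Mirrors the ordering produced by InitSizeLists(): the cuDNN scratch size is
    // inserted at the front, so it is always workspace index 0, and the padded
    // buffer (if any) ends up at index 1.
    std::vector<size_t> BuildWorkspaceSizes(size_t cudnn_workspace_bytes, bool use_pad, size_t padded_bytes) {
      std::vector<size_t> sizes;
      if (use_pad) {
        sizes.push_back(padded_bytes);                     // later addressed as workspace[1]
      }
      sizes.insert(sizes.begin(), cudnn_workspace_bytes);  // always workspace[0]
      return sizes;
    }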
| private: | |||
| bool CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but Conv3dGradInputGpuKernel needs 2 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but Conv3dGradInputGpuKernel needs 1 output."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void SetPad(const std::vector<int> &input_shape, const CNodePtr &kernel_node) { | |||
| std::vector<int> pad_list; | |||
| std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list"); | |||
| (void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| } | |||
| void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) { | |||
| const int requested_algo_count = 1; | |||
| int returned_algo_count = 0; | |||
| cudnnConvolutionBwdDataAlgoPerf_t perf_results; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real, | |||
| requested_algo_count, &returned_algo_count, &perf_results), | |||
| "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed"); | |||
| algo_ = perf_results.algo; | |||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||
| algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; | |||
| } | |||
| } | |||
| void GetInputShape(const CNodePtr &kernel_node, std::vector<size_t> *input_shape) { | |||
| auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_size")->cast<ValueTuplePtr>()->value(); | |||
| (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape), | |||
| [](const ValuePtr &e) -> size_t { return static_cast<int>(e->cast<Int64ImmPtr>()->value()); }); | |||
| } | |||
| void SetNDDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &input_shape, | |||
| const std::vector<size_t> &filter_shape) { | |||
| const int kDims = 5; | |||
| int dimA[kDims]; | |||
| int strideAin[kDims]; | |||
| int dimAdy[kDims]; | |||
| int strideAdy[kDims]; | |||
| int filterDimA[kDims]; | |||
| SetDimA(input_shape, dimA, kDims, data_format_); | |||
| SetStrideA(input_shape, strideAin, kDims, data_format_); | |||
| SetDimA(dy_shape, dimAdy, kDims, data_format_); | |||
| SetStrideA(dy_shape, strideAdy, kDims, data_format_); | |||
| SetDimA(filter_shape, filterDimA, kDims, data_format_); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, kDims, dimAdy, strideAdy), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, compute_format_, kDims, filterDimA), | |||
| "cudnnSetFilterNdDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnSetTensorNdDescriptor(dx_desc_, cudnn_data_type_, kDims, dimA, strideAin), | |||
| "cudnnSetTensorNdDescriptor failed"); | |||
| } | |||
| void SetStrideAndDilation(const CNodePtr &kernel_node) { | |||
| std::vector<int64_t> stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "strides"); | |||
| std::vector<int64_t> dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "dilations"); | |||
| (void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| if (stride_.size() != 5) { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel stride must be 5d, but got " << stride_.size(); | |||
| } | |||
| if (stride_[0] != 1 || stride_[1] != 1) { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel stride only support 1 in N axis and C axis!"; | |||
| } | |||
| if (dilation_.size() != 5) { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel dilation must be 5d!"; | |||
| } | |||
| if (dilation_[0] != 1 || dilation_[1] != 1) { | |||
| MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel dilation only support 1 in N axis and C axis!"; | |||
| } | |||
| } | |||
| cudnnHandle_t cudnn_handle_; | |||
| cudnnFilterDescriptor_t w_desc_; | |||
| cudnnConvolutionDescriptor_t conv_desc_; | |||
| cudnnTensorDescriptor_t dy_desc_; | |||
| cudnnTensorDescriptor_t dx_desc_; | |||
| cudnnTensorDescriptor_t padded_descriptor_; | |||
| cudnnConvolutionBwdDataAlgo_t algo_; | |||
| std::string pad_mode_; | |||
| std::string data_format_ = kOpFormat_NCDHW; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| cudnnTensorFormat_t compute_format_; | |||
| int old_depth_; | |||
| int old_height_; | |||
| int old_width_; | |||
| int pad_depth_; | |||
| int pad_height_; | |||
| int pad_width_; | |||
| int pad_head_; | |||
| int pad_top_; | |||
| int pad_left_; | |||
| int n_; | |||
| int c_; | |||
| std::vector<int> stride_; | |||
| std::vector<int> dilation_; | |||
| int group_; | |||
| bool is_null_input_; | |||
| size_t dy_size_; | |||
| size_t w_size_; | |||
| size_t output_size_; | |||
| size_t padded_size_; | |||
| size_t workspace_size_; | |||
| bool use_pad_; | |||
| float beta_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_INPUT_GPU_KERNEL_H_ | |||
| @@ -58,8 +58,8 @@ class _Conv(Cell): | |||
| self.format = Validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.cls_name) | |||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | |||
| raise ValueError("NHWC format only support in GPU target.") | |||
| if context.get_context("device_target") != "Ascend" and self.format == "NCDHW": | |||
| raise ValueError("NCDHW format only support in Ascend target.") | |||
| if context.get_context("device_target") == "CPU" and self.format == "NCDHW": | |||
| raise ValueError("NCDHW format only support in Ascend and GPU targets.") | |||
| if isinstance(padding, int): | |||
| Validator.check_non_negative_int(padding, 'padding', self.cls_name) | |||
| self.padding = padding | |||
| @@ -19,7 +19,9 @@ import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.parameter import ParameterTuple | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| class NetConv3d(nn.Cell): | |||
| @@ -71,3 +73,84 @@ def test_conv3d(): | |||
| net = NetConv3d() | |||
| output = net(x, w) | |||
| assert (output.asnumpy() == expect).all() | |||
| class MSConv3dNet(nn.Cell): | |||
| def __init__(self, in_channels, out_channels, kernel_size, pad_mode='pad', padding=0, stride=1, dilation=1, | |||
| has_bias=False, weight_init='normal'): | |||
| super(MSConv3dNet, self).__init__() | |||
| self.cv1 = nn.Conv3d(in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=kernel_size, | |||
| pad_mode=pad_mode, | |||
| padding=padding, | |||
| stride=stride, | |||
| dilation=dilation, | |||
| group=1, | |||
| has_bias=has_bias, | |||
| weight_init=weight_init, | |||
| data_format='NCDHW') | |||
| def construct(self, x): | |||
| x = self.cv1(x) | |||
| return x | |||
| class MSGradNet(nn.Cell): | |||
| def __init__(self, network): | |||
| super(MSGradNet, self).__init__() | |||
| self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True) | |||
| self.network = network | |||
| self.params = ParameterTuple(network.trainable_params()) | |||
| def construct(self, x, dy): | |||
| grad_op = self.grad(self.network, self.params) | |||
| output = grad_op(x, dy) | |||
| return output | |||
| def test_conv3d_grad(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| dtype = np.float32 | |||
| out_c = 2 | |||
| kernel_size = (2, 2, 2) | |||
| x = Tensor(np.array([[[[[1.6924546, 0.05080776, -0.6369957], | |||
| [0.19091548, 2.1002553, 0.12015896], | |||
| [0.6172031, 0.30017033, -0.35224986]], | |||
| [[-1.1425182, -0.34934273, -0.20889424], | |||
| [0.5866232, 0.8389834, 0.9311021], | |||
| [0.2855873, 0.8851412, -0.7543979]], | |||
| [[1.2528682, 0.5129298, -0.29809284], | |||
| [0.48851815, -0.07557172, 1.1316293], | |||
| [1.5198169, 2.1855755, -1.3964963]]]]]).astype(dtype)) | |||
| dy = Tensor(np.array([[[[[-1.4441139, -0.5044659], | |||
| [0.16003707, 0.8761689]], | |||
| [[0.31563494, -2.0222013], | |||
| [-0.30620402, 0.8279746]]], | |||
| [[[0.23009473, 0.7620112], | |||
| [-0.22232814, -0.20075807]], | |||
| [[0.18656139, 0.41005164], | |||
| [0.19829972, 0.11900865]]]]]).astype(dtype)) | |||
| w = Tensor(np.array([[[[[-0.9358, -0.2679], | |||
| [0.5304, -0.6917]], | |||
| [[-0.3968, -0.6872], | |||
| [-0.8452, -0.6712]]]], | |||
| [[[[-0.0127, -1.1173], | |||
| [0.2344, 1.6598]], | |||
| [[0.7420, -0.1918], | |||
| [-0.8876, -0.7472]]]]]).astype(dtype)) | |||
| w_exp = np.array([[[[[-0.9384, -0.2830], | |||
| [0.5487, -0.6330]], | |||
| [[-0.4148, -0.7200], | |||
| [-0.8572, -0.6079]]]], | |||
| [[[[-0.0109, -1.1089], | |||
| [0.2138, 1.6478]], | |||
| [[0.7450, -0.1866], | |||
| [-0.8992, -0.7629]]]]]).astype(dtype) | |||
| net = MSConv3dNet(x.shape[1], out_c, kernel_size, weight_init=w) | |||
| grad_net = MSGradNet(net) | |||
| optimizer = nn.SGD(net.trainable_params(), learning_rate=0.01, momentum=0.9) | |||
| grad_net.set_train(True) | |||
| output = grad_net(x, dy) | |||
| optimizer(output[1]) | |||
| assert np.allclose(net.cv1.weight.asnumpy(), w_exp, atol=1.0e-4) | |||