diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu index a9dbde7c42..d947adf33b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu @@ -67,8 +67,8 @@ __global__ void PadGeneral(const size_t size, const T *input, const int num, con const int pad_left, const T pad_value, T *output) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { int block_num = (pos / padded_width) / padded_height; // total blocks = (batch * channels) - const int padded_w = pos % padded_width; // x coordinate refered to by cur 'pos' - const int padded_h = (pos / padded_width) % padded_height; // y coordinate refered to by cur 'pos' + const int padded_w = pos % padded_width; // x coordinate referred to by cur 'pos' + const int padded_h = (pos / padded_width) % padded_height; // y coordinate referred to by cur 'pos' int channels_new = channels_orig + pad_channel_after + pad_channel_before; // new number of channels from padding int channel_num = block_num % channels_new; // current channel @@ -80,7 +80,7 @@ __global__ void PadGeneral(const size_t size, const T *input, const int num, con channel_num > channels_orig + pad_channel_before - 1) { output[pos] = pad_value; } else { - // on a block/x,y positon that isn't padding, copy data from the correct block/x,y pos the input + // on a block/x,y position that isn't padding, copy data from the correct block/x,y pos the input // calculate from number of blocks of padding (due to channel padding) inserted prior equiv_block_num = block_num - (batch_item * (pad_channel_before + pad_channel_after)) - pad_channel_before; output[pos] = input[(equiv_block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left]; @@ -115,6 +115,31 @@ __global__ void PadGrad(const size_t size, const T* dy, const int num, 
const int return; } +// For internal OP use, not user facing +template <typename T> +__global__ void Pad3d(const size_t size, const T* input, const int num, const int channels, const int old_depth, + const int old_height, const int old_width, const int old_dhw, const int old_hw, + const int padded_depth, const int padded_height, const int padded_width, const int padded_dhw, + const int padded_hw, const int pad_head, const int pad_top, const int pad_left, + const float pad_value, T* output) { + T pad_value_ = static_cast<T>(pad_value); + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + const int pos_d = pos / padded_hw % padded_depth; + const int pos_h = pos / padded_width % padded_height; + const int pos_w = pos % padded_width; + const int block_num = pos / padded_dhw; + + if (pos_d - pad_head < 0 || pos_h - pad_top < 0 || pos_w - pad_left < 0 || pos_d - pad_head >= old_depth || + pos_h - pad_top >= old_height || pos_w - pad_left >= old_width) { + output[pos] = pad_value_; + } else { + int index = block_num * old_dhw + old_hw * (pos_d - pad_head) + old_width * (pos_h - pad_top) + pos_w - pad_left; + output[pos] = input[index]; + } + } + return; +} + template <typename T> void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, @@ -163,6 +188,22 @@ void CalPadGrad(const size_t size, const T* dy, const int num, const int channel return; } +template <typename T> +void CalPad3d(const size_t size, const T* input, const int num, const int channels, const int old_depth, + const int old_height, const int old_width, const int padded_depth, const int padded_height, + const int padded_width, const int pad_head, const int pad_top, const int pad_left, const float pad_value, + T* output, cudaStream_t cuda_stream) { + const int old_hw = old_height * old_width; + const int old_dhw = old_depth * old_hw; + const
int padded_hw = padded_height * padded_width; + const int padded_dhw = padded_depth * padded_hw; + Pad3d<<>>(size, input, num, channels, old_depth, old_height, + old_width, old_dhw, old_hw, padded_depth, padded_height, + padded_width, padded_dhw, padded_hw, pad_head, pad_top, + pad_left, pad_value, output); + return; +} + template void CalPad(const size_t size, const float* input, const int num, const int channels, const int old_height, const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, float pad_value, float* output, @@ -210,3 +251,11 @@ template void CalPadGeneral(const size_t size, const int *input, const int const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, const int pad_value, int *output, cudaStream_t cuda_stream); +template void CalPad3d(const size_t size, const float* input, const int num, const int channels, + const int old_depth, const int old_height, const int old_width, const int padded_depth, + const int padded_height, const int padded_width, const int pad_head, const int pad_top, + const int pad_left, const float pad_value, float* output, cudaStream_t cuda_stream); +template void CalPad3d(const size_t size, const half* input, const int num, const int channels, + const int old_depth, const int old_height, const int old_width, const int padded_depth, + const int padded_height, const int padded_width, const int pad_head, const int pad_top, + const int pad_left, const float pad_value, half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh index f387779b5c..2a6dcdd867 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh @@ -40,4 +40,9 @@ void CalPadGeneral(const size_t size, const T *input, const int num, const 
int c const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, const T pad_value, T *output, cudaStream_t cuda_stream); +template +void CalPad3d(const size_t size, const T* input, const int num, const int channels, const int old_depth, + const int old_height, const int old_width, const int padded_depth, const int padded_height, + const int padded_width, const int pad_head, const int pad_top, const int pad_left, const float pad_value, + T* output, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h index e195003035..68fa80dd29 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h @@ -150,6 +150,12 @@ class GpuKernel : public KernelMod { dimA[1] = SizeToInt(shape[3]); dimA[2] = SizeToInt(shape[1]); dimA[3] = SizeToInt(shape[2]); + } else if (format == "NCDHW") { + dimA[0] = SizeToInt(shape[0]); + dimA[1] = SizeToInt(shape[1]); + dimA[2] = SizeToInt(shape[2]); + dimA[3] = SizeToInt(shape[3]); + dimA[4] = SizeToInt(shape[4]); } else { MS_LOG(ERROR) << "Unsupported data format " << format; } @@ -168,6 +174,12 @@ class GpuKernel : public KernelMod { strideA[1] = 1; strideA[2] = SizeToInt(shape[2] * shape[3]); strideA[3] = SizeToInt(shape[3]); + } else if (format == "NCDHW") { + strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3] * shape[4]); + strideA[1] = SizeToInt(shape[2] * shape[3] * shape[4]); + strideA[2] = SizeToInt(shape[3] * shape[4]); + strideA[3] = SizeToInt(shape[4]); + strideA[4] = 1; } else { MS_LOG(ERROR) << "Unsupported data format " << format; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.cc 
b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.cc new file mode 100644 index 0000000000..735fe7725a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Conv3D, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Conv3dGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + Conv3D, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + Conv3dGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h new file mode 100644 index 0000000000..5520d08b3f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h @@ -0,0 +1,389 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GPU_KERNEL_H_ + +#include +#include +#include + +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class Conv3dGpuKernel : public GpuKernel { + public: + Conv3dGpuKernel() { ResetResource(); } + ~Conv3dGpuKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + T *filter_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + T *workspace_addr = nullptr; + if (workspace_size_ != 0) { + workspace_addr = GetDeviceAddress(workspace, 0); + } + + const float alpha = 1; + const float beta = 0; + if (use_pad_) { + T *padded_addr = GetDeviceAddress(workspace, 1); + CalPad3d(padded_size_ / sizeof(T), input_addr, n_, c_, old_depth_, old_height_, old_width_, + old_depth_ + pad_depth_, 
old_height_ + pad_height_, old_width_ + pad_width_, pad_head_, pad_top_, + pad_left_, pad_value_, padded_addr, reinterpret_cast<cudaStream_t>(stream_ptr)); + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnConvolutionForward(cudnn_handle_, &alpha, padded_desc_, padded_addr, filter_desc_, filter_addr, conv_desc_, + conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), + "cudnnConvolutionForward failed"); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnConvolutionForward(cudnn_handle_, &alpha, input_desc_, input_addr, filter_desc_, filter_addr, conv_desc_, + conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), + "cudnnConvolutionForward failed"); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + data_format_ = kOpFormat_NCDHW; + auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(in_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "Conv3dGpuKernel input is null."; + InitSizeLists(); + return true; + } + CHECK_TENSOR_SIZE(in_shape); + n_ = SizeToInt(in_shape[0]); + c_ = SizeToInt(in_shape[1]); + old_depth_ = SizeToInt(in_shape[2]); + old_height_ = SizeToInt(in_shape[3]); + old_width_ = SizeToInt(in_shape[4]); + compute_format_ = CUDNN_TENSOR_NCHW; + SetNDDesc(in_shape, filter_shape, output_shape); + group_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "group")); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, group_), + "cudnnSetConvGroupCount failed"); + std::vector<int> pad_list; + std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list"); +
(void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list), + [](const int64_t &value) { return static_cast(value); }); + pad_depth_ = pad_list[0]; + pad_height_ = pad_list[2]; + pad_width_ = pad_list[4]; + use_pad_ = !((pad_depth_ == pad_list[1]) && (pad_height_ == pad_list[3]) && (pad_width_ == pad_list[5])); + pad_mode_ = GetAttr(kernel_node, "pad_mode"); + SetStrideAndDilation(kernel_node); + cudnnTensorDescriptor_t input_descriptor_real = nullptr; + const int kNumDims = 5; + const int kConvDims = 3; + int padA[kConvDims]; + int strideA[kConvDims] = {stride_[2], stride_[3], stride_[4]}; + int dilaA[kConvDims] = {dilation_[2], dilation_[3], dilation_[4]}; + if (use_pad_) { + pad_depth_ = pad_list[0] + pad_list[1]; + pad_height_ = pad_list[2] + pad_list[3]; + pad_width_ = pad_list[4] + pad_list[5]; + pad_head_ = pad_list[0]; + pad_top_ = pad_list[2]; + pad_left_ = pad_list[4]; + int dimA[kNumDims]; + int strideApadded[kNumDims]; + if (data_format_ == kOpFormat_NCDHW) { + auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_depth_ + pad_depth_), + IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)}; + SetDimA(padded_shape, dimA, kNumDims, data_format_); + SetStrideA(padded_shape, strideApadded, kNumDims, data_format_); + } else { + MS_LOG(EXCEPTION) << "Conv3d only support NCDHW format right now."; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, cudnnSetTensorNdDescriptor(padded_desc_, cudnn_data_type_, kNumDims, dimA, strideApadded), + "cudnnSetTensor4dDescriptor failed"); + padA[0] = 0; + padA[1] = 0; + padA[2] = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, + cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA, + CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolutionNdDescriptor failed"); + input_descriptor_real = padded_desc_; + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_depth_ = 0; + pad_height_ = 0; 
+ pad_width_ = 0; + } + padA[0] = pad_depth_; + padA[1] = pad_height_; + padA[2] = pad_width_; + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, + cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA, + CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + input_descriptor_real = input_desc_; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), + "cudnnSetConvolutionMathType failed.") + } + SelectAlgorithm(input_descriptor_real); + InitSizeLists(); + return true; + } + + void ResetResource() noexcept override { + cudnn_handle_ = nullptr; + input_desc_ = nullptr; + output_desc_ = nullptr; + filter_desc_ = nullptr; + conv_desc_ = nullptr; + padded_desc_ = nullptr; + cudnn_data_type_ = CUDNN_DATA_FLOAT; + compute_format_ = CUDNN_TENSOR_NCHW; + old_depth_ = 0; + old_height_ = 0; + old_width_ = 0; + pad_depth_ = 0; + pad_height_ = 0; + pad_width_ = 0; + pad_head_ = 0; + pad_top_ = 0; + pad_left_ = 0; + n_ = 0; + c_ = 0; + stride_.clear(); + dilation_.clear(); + group_ = 1; + is_null_input_ = false; + input_size_ = 0; + filter_size_ = 0; + output_size_ = 0; + padded_size_ = 0; + workspace_size_ = 0; + use_pad_ = true; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + void DestroyResource() noexcept override { + CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_), + "cudnnDestroyConvolutionDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyFilterDescriptor(filter_desc_), + "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(padded_desc_), + "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(output_desc_), + "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, 
cudnnDestroyTensorDescriptor(input_desc_), + "cudnnDestroyTensorDescriptor failed"); + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&input_desc_), + "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&output_desc_), + "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&padded_desc_), + "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateFilterDescriptor(&filter_desc_), + "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateConvolutionDescriptor(&conv_desc_), + "cudnnCreateConvolutionDescriptor failed"); + } + + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, + cudnnGetTensorSizeInBytes(input_desc_, reinterpret_cast(&input_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, + cudnnGetFilterSizeInBytes(filter_desc_, reinterpret_cast(&filter_size_)), + "cudnnGetFilterSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, + cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast(&output_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, + cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast(&padded_size_)), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + input_size_list_.push_back(filter_size_); + output_size_list_.push_back(output_size_); + if (use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, padded_desc_, filter_desc_, conv_desc_, output_desc_, + conv_algorithm_, &workspace_size_), + "cudnnGetConvolutionForwardWorkspaceSize failed"); + 
workspace_size_list_.push_back(padded_size_); + } else { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, input_desc_, filter_desc_, conv_desc_, output_desc_, + conv_algorithm_, &workspace_size_), + "cudnnGetConvolutionForwardWorkspaceSize failed"); + } + } + (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); + + return; + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but conv3d needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but conv3d needs 1 output."; + return false; + } + return true; + } + + void SetNDDesc(const std::vector &in_shape, const std::vector &filter_shape, + const std::vector &output_shape) { + const int kDims = 5; + int dimA[kDims]; + int strideAin[kDims]; + int dimAout[kDims]; + int strideAout[kDims]; + int filterDimA[kDims]; + SetDimA(in_shape, dimA, kDims, data_format_); + SetStrideA(in_shape, strideAin, kDims, data_format_); + SetDimA(output_shape, dimAout, kDims, data_format_); + SetStrideA(output_shape, strideAout, kDims, data_format_); + SetDimA(filter_shape, filterDimA, kDims, data_format_); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, + cudnnSetTensorNdDescriptor(input_desc_, cudnn_data_type_, kDims, dimA, strideAin), + "cudnnSetTensor4dDescriptor failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, cudnnSetFilterNdDescriptor(filter_desc_, cudnn_data_type_, compute_format_, kDims, filterDimA), + "cudnnSetFilter4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, + cudnnSetTensorNdDescriptor(output_desc_, cudnn_data_type_, kDims, dimAout, strideAout), + "cudnnSetTensor4dDescriptor failed"); + } + + void 
SelectAlgorithm(cudnnTensorDescriptor_t input_descriptor_real) { + const int requested_algo_count = 1; + int returned_algo_count = 0; + cudnnConvolutionFwdAlgoPerf_t perf_results; + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnGetConvolutionForwardAlgorithm_v7(cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, + output_desc_, requested_algo_count, &returned_algo_count, &perf_results), + "cudnnGetConvolutionForwardAlgorithm_v7 failed"); + conv_algorithm_ = perf_results.algo; + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + } + } + + void SetStrideAndDilation(const CNodePtr &kernel_node) { + std::vector stride_me = AnfAlgo::GetNodeAttr>(kernel_node, "strides"); + std::vector dilation_me = AnfAlgo::GetNodeAttr>(kernel_node, "dilations"); + (void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_), + [](const int64_t &value) { return static_cast(value); }); + (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_), + [](const int64_t &value) { return static_cast(value); }); + if (stride_.size() != 5) { + MS_LOG(EXCEPTION) << "Conv3d's' stride must be 5d, but got " << stride_.size(); + } + if (stride_[0] != 1 || stride_[1] != 1) { + MS_LOG(EXCEPTION) << "Conv3d stride only support 1 in N axis and C axis!"; + } + if (dilation_.size() != 5) { + MS_LOG(EXCEPTION) << "Conv3d's dilation must be 5d!"; + } + if (dilation_[0] != 1 || dilation_[1] != 1) { + MS_LOG(EXCEPTION) << "Conv3d dilation only support 1 in N axis and C axis!"; + } + } + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t input_desc_; + cudnnTensorDescriptor_t output_desc_; + cudnnFilterDescriptor_t filter_desc_; + cudnnConvolutionFwdAlgo_t conv_algorithm_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnTensorDescriptor_t padded_desc_; + std::string pad_mode_; + std::string data_format_ = kOpFormat_NCDHW; + std::vector input_size_list_; + std::vector 
output_size_list_; + std::vector workspace_size_list_; + const float pad_value_ = 0.0; + cudnnDataType_t cudnn_data_type_; + cudnnTensorFormat_t compute_format_; + int old_depth_; + int old_height_; + int old_width_; + int pad_depth_; + int pad_height_; + int pad_width_; + int pad_head_; + int pad_top_; + int pad_left_; + int n_; + int c_; + std::vector stride_; + std::vector dilation_; + int group_; + bool is_null_input_; + size_t input_size_; + size_t filter_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GPU_KERNEL_H_ diff --git a/tests/st/ops/gpu/test_conv3d_op.py b/tests/st/ops/gpu/test_conv3d_op.py new file mode 100644 index 0000000000..7db6802b67 --- /dev/null +++ b/tests/st/ops/gpu/test_conv3d_op.py @@ -0,0 +1,73 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class NetConv3d(nn.Cell): + def __init__(self): + super(NetConv3d, self).__init__() + out_channel = 4 + kernel_size = 2 + self.conv = P.Conv3D(out_channel, + kernel_size, + mode=1, + pad_mode="valid", + pad=0, + stride=1, + dilation=1, + group=1) + + def construct(self, x, w): + return self.conv(x, w) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_conv3d(): + x = Tensor(np.arange(1 * 3 * 3 * 3 * 3).reshape(1, 3, 3, 3, 3).astype(np.float32)) + w = Tensor(np.arange(4 * 3 * 2 * 2 * 2).reshape(4, 3, 2, 2, 2).astype(np.float32)) + expect = np.array([[[[[12960., 13236.], + [13788., 14064.]], + [[15444., 15720.], + [16272., 16548.]]], + [[[32256., 33108.], + [34812., 35664.]], + [[39924., 40776.], + [42480., 43332.]]], + [[[51552., 52980.], + [55836., 57264.]], + [[64404., 65832.], + [68688., 70116.]]], + [[[70848., 72852.], + [76860., 78864.]], + [[88884., 90888.], + [94896., 96900.]]]]]).astype(np.float32) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + net = NetConv3d() + output = net(x, w) + assert (output.asnumpy() == expect).all() + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + net = NetConv3d() + output = net(x, w) + assert (output.asnumpy() == expect).all()