
!16022 GPU Conv3d grad op support

From: @tom__chen
Reviewed-by: @robingrosman,@mikef
Signed-off-by:
pull/16022/MERGE
mindspore-ci-bot committed 4 years ago
commit 1dc0efbab5
8 changed files with 1011 additions and 5 deletions
  1. +45  -2   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu
  2. +6   -1   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh
  3. +30  -0   mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.cc
  4. +418 -0   mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.h
  5. +30  -0   mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.cc
  6. +397 -0   mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.h
  7. +2   -2   mindspore/nn/layer/conv.py
  8. +83  -0   tests/st/ops/gpu/test_conv3d_op.py

+45 -2  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu

@@ -1,5 +1,5 @@
/**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -129,7 +129,7 @@ __global__ void Pad3d(const size_t size, const T* input, const int num, const in
    const int pos_w = pos % padded_width;
    const int block_num = pos / padded_dhw;

-   if (pos_d - pad_head < 0 || pos_h - pad_top < 0 || pos_w - pad_left < 0 || pos_h - pad_head >= old_depth ||
+   if (pos_d - pad_head < 0 || pos_h - pad_top < 0 || pos_w - pad_left < 0 || pos_d - pad_head >= old_depth ||
        pos_h - pad_top >= old_height || pos_w - pad_left >= old_width) {
      output[pos] = pad_value_;
    } else {
@@ -140,6 +140,23 @@ __global__ void Pad3d(const size_t size, const T* input, const int num, const in
  return;
}


template <typename T>
__global__ void PadGrad3d(const size_t size, const T* dy, const int num, const int channels, const int old_depth,
const int old_height, const int old_width, const int old_dhw, const int old_hw,
const int padded_depth, const int padded_height, const int padded_width,
const int padded_dhw, const int padded_hw, const int pad_head, const int pad_top,
const int pad_left, T* dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
const int block_num = pos / old_dhw;
const int pos_d = pos / old_hw % old_depth + pad_head;
const int pos_h = pos / old_width % old_height + pad_top;
const int pos_w = pos % old_width + pad_left;
const int index = block_num * padded_dhw + pos_d * padded_hw + pos_h * padded_width + pos_w;
dx[pos] = dy[index];
}
return;
}

template <typename T>
void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height,
            const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left,
@@ -204,6 +221,22 @@ void CalPad3d(const size_t size, const T* input, const int num, const int channe
  return;
}


template <typename T>
void CalPadGrad3d(const size_t size, const T* dy, const int num, const int channels, const int old_depth,
const int old_height, const int old_width, const int padded_depth, const int padded_height,
const int padded_width, const int pad_head, const int pad_top, const int pad_left, T* dx,
cudaStream_t cuda_stream) {
const int old_hw = old_height * old_width;
const int old_dhw = old_depth * old_hw;
const int padded_hw = padded_height * padded_width;
const int padded_dhw = padded_depth * padded_hw;
PadGrad3d<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, num, channels, old_depth, old_height,
old_width, old_dhw, old_hw, padded_depth, padded_height,
padded_width, padded_dhw, padded_hw, pad_head, pad_top,
pad_left, dx);
return;
}

template void CalPad<float>(const size_t size, const float* input, const int num, const int channels,
                            const int old_height, const int old_width, const int padded_height, const int padded_width,
                            const int pad_top, const int pad_left, float pad_value, float* output,
@@ -259,3 +292,13 @@ template void CalPad3d<half>(const size_t size, const half* input, const int num
                             const int old_depth, const int old_height, const int old_width, const int padded_depth,
                             const int padded_height, const int padded_width, const int pad_head, const int pad_top,
                             const int pad_left, const float pad_value, half* output, cudaStream_t cuda_stream);
template void CalPadGrad3d<float>(const size_t size, const float* dy, const int num, const int channels,
const int old_depth, const int old_height, const int old_width,
const int padded_depth, const int padded_height, const int padded_width,
const int pad_head, const int pad_top, const int pad_left, float* dx,
cudaStream_t cuda_stream);
template void CalPadGrad3d<half>(const size_t size, const half* dy, const int num, const int channels,
const int old_depth, const int old_height, const int old_width,
const int padded_depth, const int padded_height, const int padded_width,
const int pad_head, const int pad_top, const int pad_left, half* dx,
cudaStream_t cuda_stream);
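
Note: PadGrad3d and CalPadGrad3d above implement the gradient of a constant 3-D pad, which reduces to cropping the incoming gradient back to the un-padded NCDHW extent. A minimal NumPy sketch of the same index mapping, with illustrative names that are not part of this patch:

import numpy as np

def pad_grad_3d_reference(dy_padded, old_d, old_h, old_w, pad_head, pad_top, pad_left):
    # dx is the window of the padded gradient that maps back onto the original
    # (un-padded) N, C, D, H, W volume; everything else fell on constant padding.
    return dy_padded[:, :,
                     pad_head:pad_head + old_d,
                     pad_top:pad_top + old_h,
                     pad_left:pad_left + old_w]

# Example: a (1, 1, 5, 5, 5) padded gradient cropped back to (1, 1, 3, 3, 3).
dx = pad_grad_3d_reference(np.ones((1, 1, 5, 5, 5)), 3, 3, 3, 1, 1, 1)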

+6 -1  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh

@@ -1,5 +1,5 @@
/**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -45,4 +45,9 @@ void CalPad3d(const size_t size, const T* input, const int num, const int channe
              const int old_height, const int old_width, const int padded_depth, const int padded_height,
              const int padded_width, const int pad_head, const int pad_top, const int pad_left, const float pad_value,
              T* output, cudaStream_t cuda_stream);
template <typename T>
void CalPadGrad3d(const size_t size, const T* dy, const int num, const int channels, const int old_depth,
const int old_height, const int old_width, const int padded_depth, const int padded_height,
const int padded_width, const int pad_head, const int pad_top, const int pad_left, T* dx,
cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_

+30 -0  mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.cc

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropFilter,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Conv3dGradFilterGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropFilter,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
Conv3dGradFilterGpuKernel, half)
} // namespace kernel
} // namespace mindspore

+418 -0  mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.h

@@ -0,0 +1,418 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_FILTER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_FILTER_GPU_KERNEL_H_

#include <algorithm>
#include <string>
#include <vector>

#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class Conv3dGradFilterGpuKernel : public GpuKernel {
public:
Conv3dGradFilterGpuKernel() { ResetResource(); }
~Conv3dGradFilterGpuKernel() override { DestroyResource(); }

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *x = GetDeviceAddress<T>(inputs, 0);
T *dy = GetDeviceAddress<T>(inputs, 1);

T *work_space = nullptr;
if (workspace_size_ != 0) {
work_space = GetDeviceAddress<T>(workspace, 0);
}

T *dw = nullptr;
float *dw_float32 = nullptr;
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
dw = GetDeviceAddress<T>(workspace, 1);
dw_float32 = GetDeviceAddress<float>(outputs, 0);
} else {
dw = GetDeviceAddress<T>(outputs, 0);
}

const float alpha = 1;
const float beta = 0;
if (use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);
CalPad3d(padded_size_ / sizeof(T), x, n_, c_, old_depth_, old_height_, old_width_, old_depth_ + pad_depth_,
old_height_ + pad_height_, old_width_ + pad_width_, pad_head_, pad_top_, pad_left_, pad_value_, padded,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_,
algo_, work_space, workspace_size_, &beta, dw_desc_, dw),
"ConvolutionBackwardFilter failed");
return true;
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, x_desc_, x, dy_desc_, dy, conv_desc_, algo_, work_space,
workspace_size_, &beta, dw_desc_, dw),
"ConvolutionBackwardFilter failed");

if (cudnn_data_type_ == CUDNN_DATA_HALF) {
Cast(num_output_elements_, dw, dw_float32, reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}

bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
InitResource();
if (!CheckParam(kernel_node)) {
return false;
}
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "Conv3dGradFilterGpuKernel input is null.";
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(in_shape);
data_format_ = kOpFormat_NCDHW;

std::vector<size_t> filter_shape;
GetFilterShape(kernel_node, &filter_shape);
num_output_elements_ = 1;
for (auto x : filter_shape) {
num_output_elements_ *= x;
}

compute_format_ = CUDNN_TENSOR_NCHW;
n_ = SizeToInt(in_shape[0]);
c_ = SizeToInt(in_shape[1]);
old_depth_ = SizeToInt(in_shape[2]);
old_height_ = SizeToInt(in_shape[3]);
old_width_ = SizeToInt(in_shape[4]);
SetNDDesc(dy_shape, filter_shape, in_shape);
group_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "group"));
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, group_),
"cudnnSetConvGroupCount failed");
std::vector<int> pad_list;
std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list");
(void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list),
[](const int64_t &value) { return static_cast<int>(value); });
pad_depth_ = pad_list[0];
pad_height_ = pad_list[2];
pad_width_ = pad_list[4];
use_pad_ = !((pad_depth_ == pad_list[1]) && (pad_height_ == pad_list[3]) && (pad_width_ == pad_list[5]));
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
SetStrideAndDilation(kernel_node);
cudnnTensorDescriptor_t x_desc_real = nullptr;
const int kNumDims = 5;
const int kConvDims = 3;
int padA[kConvDims];
int strideA[kConvDims] = {stride_[2], stride_[3], stride_[4]};
int dilaA[kConvDims] = {dilation_[2], dilation_[3], dilation_[4]};
if (use_pad_) {
pad_depth_ = pad_list[0] + pad_list[1];
pad_height_ = pad_list[2] + pad_list[3];
pad_width_ = pad_list[4] + pad_list[5];
pad_head_ = pad_list[0];
pad_top_ = pad_list[2];
pad_left_ = pad_list[4];
int dimA[kNumDims];
int strideApadded[kNumDims];
if (data_format_ == kOpFormat_NCDHW) {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_depth_ + pad_depth_),
IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, kNumDims, data_format_);
SetStrideA(padded_shape, strideApadded, kNumDims, data_format_);
} else {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel only support NCDHW format right now.";
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, kNumDims, dimA, strideApadded),
"cudnnSetTensor4dDescriptor failed");
padA[0] = 0;
padA[1] = 0;
padA[2] = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA,
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolutionNdDescriptor failed");
x_desc_real = padded_descriptor_;
} else {
if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
pad_depth_ = 0;
pad_height_ = 0;
pad_width_ = 0;
}
padA[0] = pad_depth_;
padA[1] = pad_height_;
padA[2] = pad_width_;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA,
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolutionNdDescriptor failed");
x_desc_real = x_desc_;
}
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
"cudnnSetConvolutionMathType failed.")
}
SelectAlgorithm(x_desc_real);
InitSizeLists();
return true;
}

void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
dw_desc_ = nullptr;
conv_desc_ = nullptr;
dy_desc_ = nullptr;
x_desc_ = nullptr;
padded_descriptor_ = nullptr;
cudnn_data_type_ = CUDNN_DATA_FLOAT;
compute_format_ = CUDNN_TENSOR_NCHW;
old_depth_ = 0;
old_height_ = 0;
old_width_ = 0;
pad_depth_ = 0;
pad_height_ = 0;
pad_width_ = 0;
pad_head_ = 0;
pad_top_ = 0;
pad_left_ = 0;
n_ = 0;
c_ = 0;
group_ = 1;
is_null_input_ = false;
input_size_ = 0;
dy_size_ = 0;
output_size_ = 0;
padded_size_ = 0;
workspace_size_ = 0;
use_pad_ = true;
num_output_elements_ = 1;
}

void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyFilterDescriptor(dw_desc_),
"cudnnDestroyFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(padded_descriptor_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_),
"cudnnDestroyTensorDescriptor failed");
}

protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&padded_descriptor_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateFilterDescriptor(&dw_desc_),
"cudnnCreateFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateConvolutionDescriptor(&conv_desc_),
"cudnnCreateConvolutionDescriptor failed");
}

void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetTensorSizeInBytes(dy_desc_, reinterpret_cast<size_t *>(&dy_size_)),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetTensorSizeInBytes(x_desc_, reinterpret_cast<size_t *>(&input_size_)),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetFilterSizeInBytes(dw_desc_, reinterpret_cast<size_t *>(&output_size_)),
"cudnnGetFilterSizeInBytes failed");
}
input_size_list_.push_back(dy_size_);
input_size_list_.push_back(input_size_);

if (use_pad_ && !is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast<size_t *>(&padded_size_)),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, padded_descriptor_, dy_desc_, conv_desc_,
dw_desc_, algo_, reinterpret_cast<size_t *>(&workspace_size_)),
"cudnnGetConvolutionBackwardFilterWorkspaceSize failed");
workspace_size_list_.push_back(padded_size_);
} else {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, x_desc_, dy_desc_, conv_desc_, dw_desc_, algo_,
reinterpret_cast<size_t *>(&workspace_size_)),
"cudnnGetConvolutionBackwardFilterWorkspaceSize failed");
}
}
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_);

if (cudnn_data_type_ == CUDNN_DATA_HALF) {
workspace_size_list_.push_back(output_size_);
output_size_list_.push_back(num_output_elements_ * sizeof(float));
} else {
output_size_list_.push_back(output_size_);
}
}

private:
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but Conv3dGradFilterGpuKernel needs 2 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but Conv3dGradFilterGpuKernel needs 1 output.";
return false;
}
return true;
}

void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) {
const int requested_algo_count = 1;
int returned_algo_count = 0;
cudnnConvolutionBwdFilterAlgoPerf_t perf_results;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_,
requested_algo_count, &returned_algo_count, &perf_results),
"GetConvolutionBackwardFilterAlgorithm failed");
algo_ = perf_results.algo;
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
}
}

void GetFilterShape(const CNodePtr &kernel_node, std::vector<size_t> *filter_shape) {
auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_size")->cast<ValueTuplePtr>()->value();
(void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape),
[](const ValuePtr &e) -> size_t { return static_cast<int>(e->cast<Int64ImmPtr>()->value()); });
}

void SetNDDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &filter_shape,
const std::vector<size_t> &in_shape) {
const int kDims = 5;
int dimA[kDims];
int strideAin[kDims];
int dimAdy[kDims];
int strideAdy[kDims];
int filterDimA[kDims];
SetDimA(in_shape, dimA, kDims, data_format_);
SetStrideA(in_shape, strideAin, kDims, data_format_);
SetDimA(dy_shape, dimAdy, kDims, data_format_);
SetStrideA(dy_shape, strideAdy, kDims, data_format_);
SetDimA(filter_shape, filterDimA, kDims, data_format_);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, kDims, dimAdy, strideAdy),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, compute_format_, kDims, filterDimA),
"cudnnSetFilterNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(x_desc_, cudnn_data_type_, kDims, dimA, strideAin),
"cudnnSetTensorNdDescriptor failed");
}

void SetStrideAndDilation(const CNodePtr &kernel_node) {
std::vector<int64_t> stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "strides");
std::vector<int64_t> dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "dilations");
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
[](const int64_t &value) { return static_cast<int>(value); });
if (stride_.size() != 5) {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel stride must be 5d, but got " << stride_.size();
}
if (stride_[0] != 1 || stride_[1] != 1) {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel stride only support 1 in N axis and C axis!";
}
if (dilation_.size() != 5) {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel dilation must be 5d!";
}
if (dilation_[0] != 1 || dilation_[1] != 1) {
MS_LOG(EXCEPTION) << "Conv3dGradFilterGpuKernel dilation only support 1 in N axis and C axis!";
}
}

cudnnHandle_t cudnn_handle_;
cudnnFilterDescriptor_t dw_desc_;
cudnnConvolutionDescriptor_t conv_desc_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t padded_descriptor_;
cudnnConvolutionBwdFilterAlgo_t algo_;
std::string pad_mode_;
std::string data_format_ = kOpFormat_NCDHW;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
const float pad_value_ = 0.0;
cudnnDataType_t cudnn_data_type_;
cudnnTensorFormat_t compute_format_;
int old_depth_;
int old_height_;
int old_width_;
int pad_depth_;
int pad_height_;
int pad_width_;
int pad_head_;
int pad_top_;
int pad_left_;
int n_;
int c_;
std::vector<int> stride_;
std::vector<int> dilation_;
int group_;
bool is_null_input_;
size_t input_size_;
size_t dy_size_;
size_t output_size_;
size_t padded_size_;
size_t workspace_size_;
bool use_pad_;
size_t num_output_elements_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_FILTER_GPU_KERNEL_H_
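
Note on the use_pad_ branch in this kernel: cudnnSetConvolutionNdDescriptor accepts a single pad value per spatial axis, so cuDNN can only express symmetric padding. When pad_list is asymmetric (head != tail on any of D/H/W), the kernel pads x explicitly with CalPad3d and then convolves with zero pad; the filter gradient is unchanged because the explicitly padded input reproduces the same forward computation. A rough NumPy sketch of that explicit-pad step, assuming NCDHW layout (names are illustrative, not this kernel's API):

import numpy as np

def pad_ncdhw(x, d_pads, h_pads, w_pads):
    # d_pads / h_pads / w_pads are (front, back) pairs for the D, H, W axes;
    # the N and C axes are never padded.
    return np.pad(x, ((0, 0), (0, 0), d_pads, h_pads, w_pads), mode="constant")

# Example of an asymmetric pad: 1 voxel in front of D, 2 behind it.
padded = pad_ncdhw(np.zeros((1, 2, 4, 4, 4)), (1, 2), (0, 0), (0, 0))  # -> (1, 2, 7, 4, 4)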

+30 -0  mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.cc

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropInput,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Conv3dGradInputGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Conv3DBackpropInput,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
Conv3dGradInputGpuKernel, half)
} // namespace kernel
} // namespace mindspore

+397 -0  mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.h

@@ -0,0 +1,397 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_INPUT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_INPUT_GPU_KERNEL_H_

#include <algorithm>
#include <string>
#include <vector>

#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"

namespace mindspore {
namespace kernel {
template <typename T>
class Conv3dGradInputGpuKernel : public GpuKernel {
public:
Conv3dGradInputGpuKernel() { ResetResource(); }
~Conv3dGradInputGpuKernel() override { DestroyResource(); }

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *w = GetDeviceAddress<T>(inputs, 0);
T *dy = GetDeviceAddress<T>(inputs, 1);
T *dx = GetDeviceAddress<T>(outputs, 0);
T *work_space = nullptr;
if (workspace_size_ != 0) {
work_space = GetDeviceAddress<T>(workspace, 0);
}

const float alpha = 1;
if (use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);

CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
workspace_size_, &beta_, padded_descriptor_, padded),
"ConvolutionBackwardData failed");
CalPadGrad3d(output_size_ / sizeof(T), padded, n_, c_, old_depth_, old_height_, old_width_,
old_depth_ + pad_depth_, old_height_ + pad_height_, old_width_ + pad_width_, pad_head_, pad_top_,
pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
workspace_size_, &beta_, dx_desc_, dx),
"ConvolutionBackwardData failed");
}
return true;
}

bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
InitResource();
if (!CheckParam(kernel_node)) {
return false;
}
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
data_format_ = kOpFormat_NCDHW;
auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
is_null_input_ = CHECK_NULL_INPUT(dy_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "Conv3dGradInputGpuKernel input is null.";
InitSizeLists();
return true;
}
std::vector<size_t> input_shape;
GetInputShape(kernel_node, &input_shape);
compute_format_ = CUDNN_TENSOR_NCHW;
CHECK_TENSOR_SIZE(input_shape);
n_ = SizeToInt(input_shape[0]);
c_ = SizeToInt(input_shape[1]);
old_depth_ = SizeToInt(input_shape[2]);
old_height_ = SizeToInt(input_shape[3]);
old_width_ = SizeToInt(input_shape[4]);
SetNDDesc(dy_shape, input_shape, filter_shape);
group_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "group"));
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, group_),
"cudnnSetConvGroupCount failed");
std::vector<int> pad_list;
std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list");
(void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list),
[](const int64_t &value) { return static_cast<int>(value); });
pad_depth_ = pad_list[0];
pad_height_ = pad_list[2];
pad_width_ = pad_list[4];
use_pad_ = !((pad_depth_ == pad_list[1]) && (pad_height_ == pad_list[3]) && (pad_width_ == pad_list[5]));
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
SetStrideAndDilation(kernel_node);
cudnnTensorDescriptor_t dx_desc_real = nullptr;
const int kNumDims = 5;
const int kConvDims = 3;
int padA[kConvDims];
int strideA[kConvDims] = {stride_[2], stride_[3], stride_[4]};
int dilaA[kConvDims] = {dilation_[2], dilation_[3], dilation_[4]};
if (use_pad_) {
pad_depth_ = pad_list[0] + pad_list[1];
pad_height_ = pad_list[2] + pad_list[3];
pad_width_ = pad_list[4] + pad_list[5];
pad_head_ = pad_list[0];
pad_top_ = pad_list[2];
pad_left_ = pad_list[4];
int dimA[kNumDims];
int strideApadded[kNumDims];
if (data_format_ == kOpFormat_NCDHW) {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_depth_ + pad_depth_),
IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, kNumDims, data_format_);
SetStrideA(padded_shape, strideApadded, kNumDims, data_format_);
} else {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel only support NCDHW format right now.";
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, kNumDims, dimA, strideApadded),
"cudnnSetTensorNdDescriptor failed");
padA[0] = 0;
padA[1] = 0;
padA[2] = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA,
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolutionNdDescriptor failed");
dx_desc_real = padded_descriptor_;
} else {
if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
pad_depth_ = 0;
pad_height_ = 0;
pad_width_ = 0;
}
padA[0] = pad_depth_;
padA[1] = pad_height_;
padA[2] = pad_width_;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetConvolutionNdDescriptor(conv_desc_, kConvDims, padA, strideA, dilaA,
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
"cudnnSetConvolution2dDescriptor failed");
dx_desc_real = dx_desc_;
}
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
"cudnnSetConvolutionMathType failed.")
}
SelectAlgorithm(dx_desc_real);
beta_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
InitSizeLists();
return true;
}

void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
w_desc_ = nullptr;
conv_desc_ = nullptr;
dy_desc_ = nullptr;
dx_desc_ = nullptr;
padded_descriptor_ = nullptr;
cudnn_data_type_ = CUDNN_DATA_FLOAT;
compute_format_ = CUDNN_TENSOR_NCHW;
old_depth_ = 0;
old_height_ = 0;
old_width_ = 0;
pad_depth_ = 0;
pad_height_ = 0;
pad_width_ = 0;
pad_head_ = 0;
pad_top_ = 0;
pad_left_ = 0;
n_ = 0;
c_ = 0;
group_ = 1;
is_null_input_ = false;
dy_size_ = 0;
w_size_ = 0;
output_size_ = 0;
padded_size_ = 0;
workspace_size_ = 0;
use_pad_ = true;
beta_ = 0;
}

void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_),
"cudnnDestroyConvolutionDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyFilterDescriptor(w_desc_),
"cudnnDestroyFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(padded_descriptor_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_),
"cudnnDestroyTensorDescriptor failed");
}

protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&padded_descriptor_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateFilterDescriptor(&w_desc_),
"cudnnCreateFilterDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateConvolutionDescriptor(&conv_desc_),
"cudnnCreateConvolutionDescriptor failed");
}

void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dy_desc_, &dy_size_),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetFilterSizeInBytes(w_desc_, &w_size_),
"cudnnGetTensorSizeInBytes failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dx_desc_, &output_size_),
"cudnnGetTensorSizeInBytes failed");
}

input_size_list_.push_back(dy_size_);
input_size_list_.push_back(w_size_);
output_size_list_.push_back(output_size_);

if (use_pad_ && !is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_),
"cudnnGetTensorSizeInBytes failed");

CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, padded_descriptor_,
algo_, &workspace_size_),
"cudnnGetConvolutionBackwardDataWorkspaceSize failed");
workspace_size_list_.push_back(padded_size_);
} else {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetConvolutionBackwardDataWorkspaceSize(
cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_, algo_, &workspace_size_),
"cudnnGetConvolutionBackwardDataWorkspaceSize failed");
}
}
(void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_);
}

private:
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but Conv3dGradInputGpuKernel needs 2 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but Conv3dGradInputGpuKernel needs 1 output.";
return false;
}
return true;
}

void SetPad(const std::vector<int> &input_shape, const CNodePtr &kernel_node) {
std::vector<int> pad_list;
std::vector<int64_t> pad_list_me = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list");
(void)std::transform(pad_list_me.begin(), pad_list_me.end(), std::back_inserter(pad_list),
[](const int64_t &value) { return static_cast<int>(value); });
}

void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) {
const int requested_algo_count = 1;
int returned_algo_count = 0;
cudnnConvolutionBwdDataAlgoPerf_t perf_results;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real,
requested_algo_count, &returned_algo_count, &perf_results),
"cudnnGetConvolutionBackwardDataAlgorithm_v7 failed");
algo_ = perf_results.algo;
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
}

void GetInputShape(const CNodePtr &kernel_node, std::vector<size_t> *input_shape) {
auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_size")->cast<ValueTuplePtr>()->value();
(void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape),
[](const ValuePtr &e) -> size_t { return static_cast<int>(e->cast<Int64ImmPtr>()->value()); });
}

void SetNDDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &input_shape,
const std::vector<size_t> &filter_shape) {
const int kDims = 5;
int dimA[kDims];
int strideAin[kDims];
int dimAdy[kDims];
int strideAdy[kDims];
int filterDimA[kDims];
SetDimA(input_shape, dimA, kDims, data_format_);
SetStrideA(input_shape, strideAin, kDims, data_format_);
SetDimA(dy_shape, dimAdy, kDims, data_format_);
SetStrideA(dy_shape, strideAdy, kDims, data_format_);
SetDimA(filter_shape, filterDimA, kDims, data_format_);

CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, kDims, dimAdy, strideAdy),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, compute_format_, kDims, filterDimA),
"cudnnSetFilterNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(dx_desc_, cudnn_data_type_, kDims, dimA, strideAin),
"cudnnSetTensorNdDescriptor failed");
}

void SetStrideAndDilation(const CNodePtr &kernel_node) {
std::vector<int64_t> stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "strides");
std::vector<int64_t> dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "dilations");
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
[](const int64_t &value) { return static_cast<int>(value); });
if (stride_.size() != 5) {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel stride must be 5d, but got " << stride_.size();
}
if (stride_[0] != 1 || stride_[1] != 1) {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel stride only support 1 in N axis and C axis!";
}
if (dilation_.size() != 5) {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel dilation must be 5d!";
}
if (dilation_[0] != 1 || dilation_[1] != 1) {
MS_LOG(EXCEPTION) << "Conv3dGradInputGpuKernel dilation only support 1 in N axis and C axis!";
}
}

cudnnHandle_t cudnn_handle_;
cudnnFilterDescriptor_t w_desc_;
cudnnConvolutionDescriptor_t conv_desc_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t dx_desc_;
cudnnTensorDescriptor_t padded_descriptor_;
cudnnConvolutionBwdDataAlgo_t algo_;
std::string pad_mode_;
std::string data_format_ = kOpFormat_NCDHW;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
cudnnDataType_t cudnn_data_type_;
cudnnTensorFormat_t compute_format_;
int old_depth_;
int old_height_;
int old_width_;
int pad_depth_;
int pad_height_;
int pad_width_;
int pad_head_;
int pad_top_;
int pad_left_;
int n_;
int c_;
std::vector<int> stride_;
std::vector<int> dilation_;
int group_;
bool is_null_input_;
size_t dy_size_;
size_t w_size_;
size_t output_size_;
size_t padded_size_;
size_t workspace_size_;
bool use_pad_;
float beta_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV3D_GRAD_INPUT_GPU_KERNEL_H_
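
Note on beta_ above: it maps the optional inplace_algo attribute onto cuDNN's output blending, where the destination buffer is written as alpha * result + beta * existing; "cover" (beta = 0) overwrites dx, any other value accumulates into it. A one-line sketch of that blend, assuming those attribute semantics:

def blend(conv_result, existing, alpha=1.0, beta=0.0):
    # cuDNN-style output blending: beta = 0 overwrites the destination buffer
    # ("cover"), beta = 1 accumulates the new gradient into it.
    return alpha * conv_result + beta * existing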

+2 -2  mindspore/nn/layer/conv.py

@@ -58,8 +58,8 @@ class _Conv(Cell):
        self.format = Validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.cls_name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format only support in GPU target.")
-        if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
-            raise ValueError("NCDHW format only support in Ascend target.")
+        if context.get_context("device_target") == "CPU" and self.format == "NCDHW":
+            raise ValueError("NCDHW format only support in Ascend and GPU targets.")
        if isinstance(padding, int):
            Validator.check_non_negative_int(padding, 'padding', self.cls_name)
            self.padding = padding


+83 -0  tests/st/ops/gpu/test_conv3d_op.py

@@ -19,7 +19,9 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
+from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
+from mindspore.ops import composite as C


class NetConv3d(nn.Cell):
@@ -71,3 +73,84 @@ def test_conv3d():
    net = NetConv3d()
    output = net(x, w)
    assert (output.asnumpy() == expect).all()


class MSConv3dNet(nn.Cell):
    def __init__(self, in_channels, out_channels, kernel_size, pad_mode='pad', padding=0, stride=1, dilation=1,
                 has_bias=False, weight_init='normal'):
        super(MSConv3dNet, self).__init__()
        self.cv1 = nn.Conv3d(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=kernel_size,
                             pad_mode=pad_mode,
                             padding=padding,
                             stride=stride,
                             dilation=dilation,
                             group=1,
                             has_bias=has_bias,
                             weight_init=weight_init,
                             data_format='NCDHW')

    def construct(self, x):
        x = self.cv1(x)
        return x


class MSGradNet(nn.Cell):
    def __init__(self, network):
        super(MSGradNet, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)
        self.network = network
        self.params = ParameterTuple(network.trainable_params())

    def construct(self, x, dy):
        grad_op = self.grad(self.network, self.params)
        output = grad_op(x, dy)
        return output


def test_conv3d_grad():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dtype = np.float32
    out_c = 2
    kernel_size = (2, 2, 2)
    x = Tensor(np.array([[[[[1.6924546, 0.05080776, -0.6369957],
                            [0.19091548, 2.1002553, 0.12015896],
                            [0.6172031, 0.30017033, -0.35224986]],
                           [[-1.1425182, -0.34934273, -0.20889424],
                            [0.5866232, 0.8389834, 0.9311021],
                            [0.2855873, 0.8851412, -0.7543979]],
                           [[1.2528682, 0.5129298, -0.29809284],
                            [0.48851815, -0.07557172, 1.1316293],
                            [1.5198169, 2.1855755, -1.3964963]]]]]).astype(dtype))
    dy = Tensor(np.array([[[[[-1.4441139, -0.5044659],
                             [0.16003707, 0.8761689]],
                            [[0.31563494, -2.0222013],
                             [-0.30620402, 0.8279746]]],
                           [[[0.23009473, 0.7620112],
                             [-0.22232814, -0.20075807]],
                            [[0.18656139, 0.41005164],
                             [0.19829972, 0.11900865]]]]]).astype(dtype))
    w = Tensor(np.array([[[[[-0.9358, -0.2679],
                            [0.5304, -0.6917]],
                           [[-0.3968, -0.6872],
                            [-0.8452, -0.6712]]]],
                         [[[[-0.0127, -1.1173],
                            [0.2344, 1.6598]],
                           [[0.7420, -0.1918],
                            [-0.8876, -0.7472]]]]]).astype(dtype))
    w_exp = np.array([[[[[-0.9384, -0.2830],
                         [0.5487, -0.6330]],
                        [[-0.4148, -0.7200],
                         [-0.8572, -0.6079]]]],
                      [[[[-0.0109, -1.1089],
                         [0.2138, 1.6478]],
                        [[0.7450, -0.1866],
                         [-0.8992, -0.7629]]]]]).astype(dtype)
    net = MSConv3dNet(x.shape[1], out_c, kernel_size, weight_init=w)
    grad_net = MSGradNet(net)
    optimizer = nn.SGD(net.trainable_params(), learning_rate=0.01, momentum=0.9)
    grad_net.set_train(True)
    output = grad_net(x, dy)
    optimizer(output[1])
    assert np.allclose(net.cv1.weight.asnumpy(), w_exp, atol=1.0e-4)
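
Note on the expected values: optimizer(output[1]) applies one momentum-SGD step to the filter, and with the momentum accumulation starting at zero the first update reduces to w_new = w - lr * dw, so w_exp encodes the Conv3D filter gradient scaled by learning_rate=0.01. A hedged sanity-check sketch (not part of the test):

def implied_filter_grad(w_before, w_after, lr=0.01):
    # First-step momentum SGD with zero initial accumulation reduces to
    # w_after = w_before - lr * dw, so the gradient can be read back as:
    return (w_before - w_after) / lr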
