Browse Source

!5071 GPU fix codex for SetDimA and SetStrideA

Merge pull request !5071 from VectorSL/codex
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
de2ce9fde1
6 changed files with 51 additions and 45 deletions
  1. +8
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
  2. +10
    -10
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
  3. +9
    -9
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
  4. +10
    -10
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
  5. +5
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
  6. +9
    -9
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h

+ 8
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h View File

@@ -80,7 +80,10 @@ class GpuKernel : public KernelMod {
std::swap((*shape)[2], (*shape)[1]); std::swap((*shape)[2], (*shape)[1]);
} }
void SetDimA(const std::vector<size_t> &shape, int *dimA, const std::string &format) {
void SetDimA(const std::vector<size_t> &shape, int *dimA, size_t len, const std::string &format) {
if (shape.size() != len) {
MS_EXCEPTION(ValueError) << "Invalid size of input shape " << shape.size() << "-D with dimA " << len << "-D.";
}
if (format == "NCHW" || format == "DefaultFormat") { if (format == "NCHW" || format == "DefaultFormat") {
dimA[0] = SizeToInt(shape[0]); dimA[0] = SizeToInt(shape[0]);
dimA[1] = SizeToInt(shape[1]); dimA[1] = SizeToInt(shape[1]);
@@ -95,7 +98,10 @@ class GpuKernel : public KernelMod {
MS_LOG(ERROR) << "Unsupported data format " << format; MS_LOG(ERROR) << "Unsupported data format " << format;
} }
} }
void SetStrideA(const std::vector<size_t> &shape, int *strideA, const std::string &format) {
void SetStrideA(const std::vector<size_t> &shape, int *strideA, size_t len, const std::string &format) {
if (shape.size() != len) {
MS_EXCEPTION(ValueError) << "Invalid size of input shape " << shape.size() << "-D with strideA " << len << "-D.";
}
if (format == "NCHW" || format == "DefaultFormat") { if (format == "NCHW" || format == "DefaultFormat") {
strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]); strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]);
strideA[1] = SizeToInt(shape[2] * shape[3]); strideA[1] = SizeToInt(shape[2] * shape[3]);


+ 10
- 10
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h View File

@@ -147,13 +147,13 @@ class Conv2dGpuFwdKernel : public GpuKernel {
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") { if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_), auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
IntToSize(old_width_ + pad_width_)}; IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, data_format_);
SetStrideA(padded_shape, strideApadded, data_format_);
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
} else if (data_format_ == "NHWC") { } else if (data_format_ == "NHWC") {
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_), auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
IntToSize(c_)}; IntToSize(c_)};
SetDimA(padded_shape, dimA, data_format_);
SetStrideA(padded_shape, strideApadded, data_format_);
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
} }
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(padded_desc_, cudnn_data_type_, 4, dimA, strideApadded), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(padded_desc_, cudnn_data_type_, 4, dimA, strideApadded),
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
@@ -259,18 +259,18 @@ class Conv2dGpuFwdKernel : public GpuKernel {


void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<size_t> &filter_shape, void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<size_t> &filter_shape,
const std::vector<size_t> &output_shape) { const std::vector<size_t> &output_shape) {
int nbDims = 4;
const int nbDims = 4;
int dimA[4]; int dimA[4];
int strideAin[4]; int strideAin[4];
int dimAout[4]; int dimAout[4];
int strideAout[4]; int strideAout[4];
SetDimA(in_shape, dimA, data_format_);
SetStrideA(in_shape, strideAin, data_format_);
SetDimA(output_shape, dimAout, data_format_);
SetStrideA(output_shape, strideAout, data_format_);
SetDimA(in_shape, dimA, 4, data_format_);
SetStrideA(in_shape, strideAin, 4, data_format_);
SetDimA(output_shape, dimAout, 4, data_format_);
SetStrideA(output_shape, strideAout, 4, data_format_);
int filterDimA[4]; int filterDimA[4];
// OHWI for NHWC; OIHW for NCHW // OHWI for NHWC; OIHW for NCHW
SetDimA(filter_shape, filterDimA, data_format_);
SetDimA(filter_shape, filterDimA, 4, data_format_);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(input_desc_, cudnn_data_type_, nbDims, dimA, strideAin), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(input_desc_, cudnn_data_type_, nbDims, dimA, strideAin),
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");




+ 9
- 9
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h View File

@@ -148,13 +148,13 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") { if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_), auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
IntToSize(old_width_ + pad_width_)}; IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, data_format_);
SetStrideA(padded_shape, strideApadded, data_format_);
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
} else if (data_format_ == "NHWC") { } else if (data_format_ == "NHWC") {
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_), auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
IntToSize(c_)}; IntToSize(c_)};
SetDimA(padded_shape, dimA, data_format_);
SetStrideA(padded_shape, strideApadded, data_format_);
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
} }
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded), cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded),
@@ -283,15 +283,15 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
} }
void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &filter_shape, void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &filter_shape,
const std::vector<size_t> &in_shape) { const std::vector<size_t> &in_shape) {
int nbDims = 4;
const int nbDims = 4;
int dimA[4]; int dimA[4];
int strideAin[4]; int strideAin[4];
int dimAdy[4]; int dimAdy[4];
int strideAdy[4]; int strideAdy[4];
SetDimA(in_shape, dimA, data_format_);
SetStrideA(in_shape, strideAin, data_format_);
SetDimA(dy_shape, dimAdy, data_format_);
SetStrideA(dy_shape, strideAdy, data_format_);
SetDimA(in_shape, dimA, 4, data_format_);
SetStrideA(in_shape, strideAin, 4, data_format_);
SetDimA(dy_shape, dimAdy, 4, data_format_);
SetStrideA(dy_shape, strideAdy, 4, data_format_);
// filter shape always keep OIHW. // filter shape always keep OIHW.
int filterDimA[4] = {SizeToInt(filter_shape[0]), SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), int filterDimA[4] = {SizeToInt(filter_shape[0]), SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]),
SizeToInt(filter_shape[3])}; SizeToInt(filter_shape[3])};


+ 10
- 10
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h View File

@@ -149,13 +149,13 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") { if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_), auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
IntToSize(old_width_ + pad_width_)}; IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, data_format_);
SetStrideA(padded_shape, strideApadded, data_format_);
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
} else if (data_format_ == "NHWC") { } else if (data_format_ == "NHWC") {
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_), auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
IntToSize(c_)}; IntToSize(c_)};
SetDimA(padded_shape, dimA, data_format_);
SetStrideA(padded_shape, strideApadded, data_format_);
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
} }
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded), cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded),
@@ -284,17 +284,17 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
} }
void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &input_shape, void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &input_shape,
const std::vector<size_t> &filter_shape) { const std::vector<size_t> &filter_shape) {
int nbDims = 4;
const int nbDims = 4;
int dimA[4]; int dimA[4];
int strideAin[4]; int strideAin[4];
int dimAdy[4]; int dimAdy[4];
int strideAdy[4]; int strideAdy[4];
int filterDimA[4]; int filterDimA[4];
SetDimA(input_shape, dimA, data_format_);
SetStrideA(input_shape, strideAin, data_format_);
SetDimA(dy_shape, dimAdy, data_format_);
SetStrideA(dy_shape, strideAdy, data_format_);
SetDimA(filter_shape, filterDimA, data_format_);
SetDimA(input_shape, dimA, 4, data_format_);
SetStrideA(input_shape, strideAin, 4, data_format_);
SetDimA(dy_shape, dimAdy, 4, data_format_);
SetStrideA(dy_shape, strideAdy, 4, data_format_);
SetDimA(filter_shape, filterDimA, 4, data_format_);


CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
"cudnnSetTensorNdDescriptor failed"); "cudnnSetTensorNdDescriptor failed");


+ 5
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h View File

@@ -88,15 +88,15 @@ class PoolingGpuFwdKernel : public GpuKernel {
return true; return true;
} }
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_); SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
int nbDims = 4;
const int nbDims = 4;
int dimA[4]; int dimA[4];
int strideAin[4]; int strideAin[4];
int dimAout[4]; int dimAout[4];
int strideAout[4]; int strideAout[4];
SetDimA(input_shape, dimA, data_format_);
SetStrideA(input_shape, strideAin, data_format_);
SetDimA(output_shape, dimAout, data_format_);
SetStrideA(output_shape, strideAout, data_format_);
SetDimA(input_shape, dimA, 4, data_format_);
SetStrideA(input_shape, strideAin, 4, data_format_);
SetDimA(output_shape, dimAout, 4, data_format_);
SetStrideA(output_shape, strideAout, 4, data_format_);
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensorNdDescriptor(input_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin), cudnnSetTensorNdDescriptor(input_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin),
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");


+ 9
- 9
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h View File

@@ -100,7 +100,7 @@ class PoolingGradGpuKernel : public GpuKernel {
int windowDimA[2] = {window_height, window_width}; int windowDimA[2] = {window_height, window_width};
int paddingA[2] = {0, 0}; int paddingA[2] = {0, 0};
int strideA[2] = {stride_[2], stride_[3]}; int strideA[2] = {stride_[2], stride_[3]};
int nbDims = 4;
const int nbDims = 4;
int dimA[4]; int dimA[4];
int strideAin[4]; int strideAin[4];
int dimAy[4]; int dimAy[4];
@@ -109,14 +109,14 @@ class PoolingGradGpuKernel : public GpuKernel {
int strideAdy[4]; int strideAdy[4];
int dimAout[4]; int dimAout[4];
int strideAout[4]; int strideAout[4];
SetDimA(input_shape, dimA, data_format_);
SetStrideA(input_shape, strideAin, data_format_);
SetDimA(input_mask, dimAy, data_format_);
SetStrideA(input_mask, strideAiny, data_format_);
SetDimA(dout_shape, dimAdy, data_format_);
SetStrideA(dout_shape, strideAdy, data_format_);
SetDimA(output_shape, dimAout, data_format_);
SetStrideA(output_shape, strideAout, data_format_);
SetDimA(input_shape, dimA, 4, data_format_);
SetStrideA(input_shape, strideAin, 4, data_format_);
SetDimA(input_mask, dimAy, 4, data_format_);
SetStrideA(input_mask, strideAiny, 4, data_format_);
SetDimA(dout_shape, dimAdy, 4, data_format_);
SetStrideA(dout_shape, strideAdy, 4, data_format_);
SetDimA(output_shape, dimAout, 4, data_format_);
SetStrideA(output_shape, strideAout, 4, data_format_);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny),
"cudnnSetTensor4dDescriptor failed"); "cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy), CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy),


Loading…
Cancel
Save