Browse Source

Update the tensor size check function: centralize CheckTensorSize in gpu_kernel.h (accepting multiple shapes) and remove the old single-shape CHECK_TENSOR_SIZE macro from gpu_common.h

tags/v1.3.0
VectorSL 5 years ago
parent
commit
2dbf0e694e
21 changed files with 37 additions and 33 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
  2. +18
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
  3. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
  4. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
  5. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h
  6. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h
  7. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
  8. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
  9. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
  10. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h
  11. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.h
  12. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.h
  13. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h
  14. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h
  15. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
  16. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h
  17. +1
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h
  18. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
  19. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
  20. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h
  21. +0
    -13
      mindspore/ccsrc/runtime/device/gpu/gpu_common.h

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h View File

@@ -216,7 +216,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
std::vector<size_t> inputA;
std::vector<size_t> outputC_shape = output_shape;
const int split_dim = 4;
CHECK_TENSOR_SIZE(input_shape);
CheckTensorSize({input_shape, output_shape});
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &inputA);
CHECK_CUDNN_RET_WITH_EXCEPT(


+ 18
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h View File

@@ -21,6 +21,7 @@
#include <cudnn.h>
#include <string>
#include <vector>
#include <initializer_list>
#include <utility>
#include <map>
#include <memory>
@@ -34,6 +35,9 @@
#include "runtime/device/executor/dynamic_kernel.h"
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
// The max_limit of tensor shape size: 2 Giga-elements (2^31, one past the largest signed 32-bit value).
#define SHAPE_SIZE_LIMIT 2147483648
namespace mindspore {
namespace kernel {
static std::map<int, int> kNCHWToNHWCAxisMap = {
@@ -218,6 +222,20 @@ class GpuKernel : public KernelMod {
}
}
// The tensor size is limited to 2G elements by cuDNN.
// Validates every shape in `shapes`: computes its total element count and raises a
// ValueError exception if that count reaches SHAPE_SIZE_LIMIT (2^31 elements).
// An empty shape yields a count of 1 (scalar) and always passes.
inline void CheckTensorSize(const std::initializer_list<std::vector<size_t>> &shapes) {
  // Iterate by const reference: iterating by value would copy each shape vector.
  for (const auto &shape : shapes) {
    size_t total_size = 1;
    for (auto dim : shape) {
      total_size *= dim;
    }
    if (total_size >= SHAPE_SIZE_LIMIT) {
      MS_EXCEPTION(ValueError) << "The total size of the tensor exceeds the max_limit of 2 Giga-elements, which is "
                               << total_size << " elements (" << shape << ").";
    }
  }
}
// set the tensor descriptor for cudnn/cublas
void CudnnSetTensorNdDescriptor(const std::vector<size_t> &shape, cudnnTensorDescriptor_t descriptor,
cudnnDataType_t data_type, const std::weak_ptr<CNode> &node) {


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h View File

@@ -76,7 +76,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(input_shape);
CheckTensorSize({input_shape});
std::vector<size_t> shape;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0;
if (mode_ == CUDNN_ACTIVATION_ELU) {


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h View File

@@ -84,7 +84,7 @@ class ActivationGradGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(input_shape);
CheckTensorSize({input_shape});
std::vector<size_t> shape;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0;
if (mode_ == CUDNN_ACTIVATION_ELU) coef = 1.0;


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h View File

@@ -136,7 +136,7 @@ class BatchNormGpuKernel : public GpuKernel {
if (format_attr == kOpFormat_NHWC) {
format = kOpFormat_NHWC;
}
CHECK_TENSOR_SIZE(shape);
CheckTensorSize({shape});
SetTensorDescriptor(format, shape);
InitSizeLists();
return true;


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h View File

@@ -160,7 +160,7 @@ class BatchNormGradGpuKernel : public GpuKernel {
format = kOpFormat_NHWC;
}
beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
CHECK_TENSOR_SIZE(shape);
CheckTensorSize({shape});
SetTensorDescriptor(format, shape);
InitSizeLists();
is_train_ = GetAttr<bool>(kernel_node, "is_training");


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h View File

@@ -99,7 +99,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(in_shape);
CheckTensorSize({in_shape, filter_shape, output_shape});
SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
if (data_format_ == kOpFormat_NHWC) {
compute_format_ = CUDNN_TENSOR_NHWC;


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h View File

@@ -118,7 +118,6 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(in_shape);
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
format_attr_ = GetAttr<std::string>(kernel_node, "format");
if (format_attr_ == kOpFormat_NHWC) {
@@ -126,6 +125,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
}
std::vector<size_t> filter_shape;
GetFilterShape(kernel_node, &filter_shape);
CheckTensorSize({in_shape, dy_shape, filter_shape});
if (data_format_ == kOpFormat_NHWC) {
compute_format_ = CUDNN_TENSOR_NHWC;
}


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h View File

@@ -133,7 +133,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
ShapeNCHW2NHWC(&input_shape);
}
}
CHECK_TENSOR_SIZE(input_shape);
CheckTensorSize({input_shape, dy_shape, filter_shape});
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
Set4DDesc(dy_shape, input_shape, filter_shape);



+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_gpu_kernel.h View File

@@ -89,7 +89,7 @@ class Conv3dGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(in_shape);
CheckTensorSize({in_shape});
n_ = SizeToInt(in_shape[0]);
c_ = SizeToInt(in_shape[1]);
old_depth_ = SizeToInt(in_shape[2]);


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_filter_gpu_kernel.h View File

@@ -102,7 +102,7 @@ class Conv3dGradFilterGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(in_shape);
CheckTensorSize({in_shape});
data_format_ = kOpFormat_NCDHW;

std::vector<size_t> filter_shape;


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv3d_grad_input_gpu_kernel.h View File

@@ -92,7 +92,7 @@ class Conv3dGradInputGpuKernel : public GpuKernel {
std::vector<size_t> input_shape;
GetInputShape(kernel_node, &input_shape);
compute_format_ = CUDNN_TENSOR_NCHW;
CHECK_TENSOR_SIZE(input_shape);
CheckTensorSize({input_shape});
n_ = SizeToInt(input_shape[0]);
c_ = SizeToInt(input_shape[1]);
old_depth_ = SizeToInt(input_shape[2]);


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h View File

@@ -98,7 +98,7 @@ class Im2ColGpuFwdKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(in_shape);
CheckTensorSize({in_shape, output_shape});
Set4DDesc(in_shape, filter_shape, output_shape);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, 1),
"cudnnSetConvGroupCount failed");


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h View File

@@ -134,7 +134,7 @@ class InstanceNormGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(input_shape_);
CheckTensorSize({input_shape_});
SetTensorDescriptor();
InitSizeLists();
return true;


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h View File

@@ -131,7 +131,7 @@ class InstanceNormGradGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(input_shape_);
CheckTensorSize({input_shape_});
beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
SetTensorDescriptor();
InitSizeLists();


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h View File

@@ -110,7 +110,7 @@ class L2NormalizeGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(inputA_shape);
CheckTensorSize({inputA_shape, output_shape});
if (inputA_shape.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
}


+ 1
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h View File

@@ -262,8 +262,7 @@ class L2NormalizeGradGpuKernel : public GpuKernel {
std::vector<size_t> inputA;
std::vector<size_t> outputC_shape = output_shape;
constexpr int split_dim = 4;
CHECK_TENSOR_SIZE(input_shape);
CHECK_TENSOR_SIZE(output_shape);
CheckTensorSize({input_shape, output_shape});
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &inputA);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h View File

@@ -98,7 +98,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
InitSizeLists();
return true;
}
CHECK_TENSOR_SIZE(input_shape);
CheckTensorSize({input_shape, output_shape});
auto dim = input_shape.size();
if (dim == 4) {
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h View File

@@ -96,7 +96,7 @@ class PoolingGradGpuKernel : public GpuKernel {
InitSizeLists();
return false;
}
CHECK_TENSOR_SIZE(input_shape);
CheckTensorSize({input_shape, input_mask, dout_shape, output_shape});
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format);
SetDimA(input_shape, dimA, nbDims, data_format);
SetStrideA(input_shape, strideAin, nbDims, data_format);


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h View File

@@ -137,7 +137,7 @@ class BatchNormFoldGpuKernel : public GpuKernel {
<< ", but BatchNormFold GpuKernel OP needs 4DTensor input.";
return false;
}
CHECK_TENSOR_SIZE(input_shape);
CheckTensorSize({input_shape});
batch_ = input_shape[0];
channel_ = input_shape[1];
height_ = input_shape[2];


+ 0
- 13
mindspore/ccsrc/runtime/device/gpu/gpu_common.h View File

@@ -201,19 +201,6 @@ inline bool CheckNullInput(const std::vector<size_t> &input_shape) {
}
#define CHECK_NULL_INPUT(input_shape) mindspore::device::gpu::CheckNullInput(input_shape)

// The tensor size is limited to 2G by cudnn.
inline void CheckTensorSize(const std::vector<size_t> &shape) {
size_t total_size = 1;
for (auto i : shape) {
total_size *= i;
}
if (total_size >= 2147483648) {
MS_EXCEPTION(ValueError) << "The total size of the tensor exceeds the max_limit of 2 Giga-elements, which is "
<< total_size << "elements (" << shape << ").";
}
}
#define CHECK_TENSOR_SIZE(shape) mindspore::device::gpu::CheckTensorSize(shape)

#define CHECK_CURAND_RET_WITH_EXCEPT(expression, message) \
{ \
curandStatus_t status = (expression); \


Loading…
Cancel
Save