Browse Source

!18392 fix exception handling when an error occurs, and replace magic numbers with const values.

Merge pull request !18392 from wangshuide/wsd_master
tags/v1.3.0
i-robot Gitee 4 years ago
parent
commit
6c33e0b710
40 changed files with 173 additions and 102 deletions
  1. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h
  2. +4
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.h
  3. +0
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/extract_image_patches_gpu_kernel.h
  4. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/meshgrid_gpu_kernel.h
  5. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sort_gpu_kernel.h
  6. +4
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.h
  7. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
  8. +15
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h
  9. +10
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h
  10. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumprod_gpu_kernel.h
  11. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h
  12. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/linspace.h
  13. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/square_sum_all_gpu_kernel.h
  14. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.h
  15. +6
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h
  16. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
  17. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h
  18. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.h
  19. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h
  20. +8
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h
  21. +8
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h
  22. +17
    -15
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
  23. +19
    -17
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
  24. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h
  25. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h
  26. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.h
  27. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h
  28. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h
  29. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h
  30. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h
  31. +13
    -10
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
  32. +4
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_gpu_kernel.h
  33. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h
  34. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_gpu_kernel.h
  35. +4
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h
  36. +4
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_apply_proximal_adagrad_kernel.h
  37. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.h
  38. +4
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h
  39. +4
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h
  40. +4
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h

+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h View File

@@ -24,6 +24,7 @@

namespace mindspore {
namespace kernel {
constexpr size_t SHAPE_SIZE = 4;
template <typename T>
class BroadcastToGpuKernel : public GpuKernel {
public:
@@ -47,8 +48,8 @@ class BroadcastToGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (input_shapes.size() > 4 || output_shapes.size() > 4) {
MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than 4";
if (input_shapes.size() > SHAPE_SIZE || output_shapes.size() > SHAPE_SIZE) {
MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than " << SHAPE_SIZE;
}

size_t offset = output_shapes.size() - input_shapes.size();


+ 4
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.h View File

@@ -53,6 +53,10 @@ class DepthToSpaceFwdKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
block_size_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "block_size"));
if (block_size_ == 0) {
MS_LOG(ERROR) << "block_size_ can not be 0.";
return false;
}
// check input num and output num
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {


+ 0
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/extract_image_patches_gpu_kernel.h View File

@@ -86,12 +86,10 @@ class ExtractImagePatchesKernel : public GpuKernel {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but ExtractImagePatches needs 1 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ExtractImagePatches has 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_size_ = 1;


+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/meshgrid_gpu_kernel.h View File

@@ -73,6 +73,7 @@ class MeshgridGpuKernel : public GpuKernel {
swap_indexing_ = false;
} else {
MS_LOG(ERROR) << "invalid string for argument \"indexing\", must be \"xy\" or \"ij\" but got " << indexing;
return false;
}

input_size_ = 1;


+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sort_gpu_kernel.h View File

@@ -131,6 +131,7 @@ class SortGpuKernel : public GpuKernel {
input_rank_ = input_shape_.size();
if (input_rank_ > TRANSPOSE_MAX_DIMENSION) {
MS_LOG(ERROR) << "Sort cannot support input that has more than " << TRANSPOSE_MAX_DIMENSION << " dimensions.";
return false;
}

input_size_ = 1;


+ 4
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.h View File

@@ -53,6 +53,10 @@ class SpaceToDepthFwdKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
block_size_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "block_size"));
if (block_size_ == 0) {
MS_LOG(ERROR) << "block_size_ can not be 0.";
return false;
}
// check input num and output num
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h View File

@@ -71,8 +71,8 @@ class BroadcastOpGpuKernel : public GpuKernel {
auto shape2 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
auto shape3 = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
need_broadcast_ = AnfAlgo::IsTensorBroadcast(shape1, shape2);
if (need_broadcast_ && shape1.size() > 7) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
if (need_broadcast_ && shape1.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than " << MAX_DIMS;
}

lhs_shape_.resize(MAX_DIMS, 1);


+ 15
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_solve_gpu_kernel.h View File

@@ -106,28 +106,35 @@ class CholeskyGpuKernel : public GpuKernel {
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
split_dim = static_cast<int>(GetAttr<int64_t>(kernel_node, "split_dim"));
if (split_dim == 0) {
InitNoSpltDim(in_shape);
if (!InitNoSpltDim(in_shape)) {
return false;
}
} else {
InitSpltDim(in_shape);
if (!InitSpltDim(in_shape)) {
return false;
}
}
return true;
}

protected:
void InitNoSpltDim(const std::vector<size_t> &in_shape) {
bool InitNoSpltDim(const std::vector<size_t> &in_shape) {
use_split_matrix = false;
if (in_shape.size() == 2) {
batch_ = 1;
if (in_shape[0] != in_shape[1]) {
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
return false;
}
} else if (in_shape.size() == 3) {
batch_ = SizeToInt(in_shape[0]);
if (in_shape[1] != in_shape[2]) {
MS_LOG(ERROR) << "Cholesky need square matrix as input.";
return false;
}
} else {
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
return false;
}

m_ = SizeToInt(in_shape[1]);
@@ -146,16 +153,19 @@ class CholeskyGpuKernel : public GpuKernel {
}
}
InitSizeLists();
return true;
}

void InitSpltDim(const std::vector<size_t> &in_shape) {
bool InitSpltDim(const std::vector<size_t> &in_shape) {
if (in_shape.size() != 2) {
MS_LOG(ERROR) << "Cholesky Split Matrix Need Input Rank as 2.";
return false;
}
height = in_shape[0];
width = in_shape[1];
if (height != width) {
MS_LOG(ERROR) << "Cholesky Split Matrix Need Square Matrix as Input.";
return false;
}
if (SizeToInt(height) <= split_dim) {
use_split_matrix = false;
@@ -202,6 +212,7 @@ class CholeskyGpuKernel : public GpuKernel {
}
InitSizeLists();
}
return true;
}

void InitSizeLists() override {


+ 10
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/math/cholesky_trsm_solve_gpu_kernel.h View File

@@ -57,13 +57,17 @@ class CholeskyTrsmGpuKernel : public GpuKernel {
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
split_dim = static_cast<int>(GetAttr<int64_t>(kernel_node, "split_dim"));
if (split_dim == 0) {
InitDim0(kernel_node, in_shape);
if (!InitDim0(kernel_node, in_shape)) {
return false;
}
} else {
if (in_shape.size() != 2) {
MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Input Rank as 2.";
return false;
}
if (in_shape[0] != in_shape[1]) {
MS_LOG(ERROR) << "CholeskyTrsm Split Matrix Need Square Matrix as Input.";
return false;
}
InitDimOthers(kernel_node, in_shape);
}
@@ -170,20 +174,23 @@ class CholeskyTrsmGpuKernel : public GpuKernel {
d_array_addr, lda_, d_identity_addr, ldb_, batch_),
"cublas trsm batched Fail");
}
void InitDim0(const CNodePtr &kernel_node, const std::vector<size_t> &in_shape) {
bool InitDim0(const CNodePtr &kernel_node, const std::vector<size_t> &in_shape) {
use_split_matrix = false;
if (in_shape.size() == 2) {
batch_ = 1;
if (in_shape[0] != in_shape[1]) {
MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input.";
return false;
}
} else if (in_shape.size() == 3) {
batch_ = SizeToInt(in_shape[0]);
if (in_shape[1] != in_shape[2]) {
MS_LOG(ERROR) << "CholeskyTrsm need square matrix as input.";
return false;
}
} else {
MS_LOG(ERROR) << "Input Only support Rank 2 OR 3";
return false;
}

m_ = SizeToInt(in_shape[1]);
@@ -201,6 +208,7 @@ class CholeskyTrsmGpuKernel : public GpuKernel {
}
}
}
return true;
}
void InitDimOthers(const CNodePtr &kernel_node, const std::vector<size_t> &in_shape) {
height = in_shape[0];


+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumprod_gpu_kernel.h View File

@@ -47,7 +47,6 @@ class CumProdGpuKernel : public GpuKernel {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but CumProdGpuKernel needs 1.";
return false;
}
input_size_0_ = sizeof(T);
shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);


+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h View File

@@ -47,7 +47,6 @@ class CumSumGpuKernel : public GpuKernel {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but CumSumGpuKernel needs 1.";
return false;
}
input_size_0_ = sizeof(T);
shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);


+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/linspace.h View File

@@ -69,6 +69,7 @@ class LinSpaceGpuKernel : public GpuKernel {
if (value_count.size() != 1) {
MS_LOG(ERROR) << "For LinShape, output shape incorrect rank. Expect Rank: 1, got Rank: " << value_count.size()
<< ".";
return false;
}
value_count_ = value_count[0];
InitSizeLists();


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/math/square_sum_all_gpu_kernel.h View File

@@ -54,7 +54,8 @@ class SquareSumAllGpuFwdKernel : public GpuKernel {
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "SquareSumAllGpuFwdKernel input is null";
MS_LOG(ERROR) << "SquareSumAllGpuFwdKernel input is null";
return false;
}
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.h View File

@@ -58,7 +58,7 @@ class SquaredDifferenceOpGpuKernel : public GpuKernel {
auto output_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
need_broadcast_ = IsBroadcast(input_shape1, input_shape2);
if (need_broadcast_ && output_shape.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than " << MAX_DIMS;
}

lhs_shape_.resize(MAX_DIMS, 1);


+ 6
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/math/update_thor_gradient.h View File

@@ -142,7 +142,9 @@ class UpdateThorGradientGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
SetProperty(kernel_node);
if (!SetProperty(kernel_node)) {
return false;
}
InitSizeLists();
return true;
}
@@ -181,7 +183,7 @@ class UpdateThorGradientGpuKernel : public GpuKernel {
}

private:
void SetProperty(const CNodePtr &kernel_node) {
bool SetProperty(const CNodePtr &kernel_node) {
auto matrix_a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto matrix_g_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
@@ -189,6 +191,7 @@ class UpdateThorGradientGpuKernel : public GpuKernel {
split_dim = LongToSize(GetAttr<int64_t>(kernel_node, "split_dim"));
if (split_dim == 0) {
MS_LOG(ERROR) << "Divide by zero, split_dim can not be zero.";
return false;
}
gradient_size.batch_h = gradient_shape[0] / split_dim;
gradient_size.batch_w = gradient_shape[1] / split_dim;
@@ -229,6 +232,7 @@ class UpdateThorGradientGpuKernel : public GpuKernel {
gradient_size.ori_w = gradient_shape[1];
gradient_size.ori_h = gradient_shape[0];
gradient_size.dtype = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
return true;
}

size_t split_dim;


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h View File

@@ -26,6 +26,7 @@

namespace mindspore {
namespace kernel {
constexpr float ReLU6_UP_TURNING_POINT = 5.999999;
template <typename T>
class ActivationGradGpuKernel : public GpuKernel {
public:
@@ -86,7 +87,7 @@ class ActivationGradGpuKernel : public GpuKernel {
}
CheckTensorSize({input_shape});
std::vector<size_t> shape;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? ReLU6_UP_TURNING_POINT : 0.0;
if (mode_ == CUDNN_ACTIVATION_ELU) coef = 1.0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef),


+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 10;
template <typename T>
class AdamGpuKernel : public GpuKernel {
public:
@@ -63,8 +64,8 @@ class AdamGpuKernel : public GpuKernel {

bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 10) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but adam needs 10 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but adam needs " << INPUT_NUM << " inputs.";
return false;
}



+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.h View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 9;
template <typename T>
class AdamWeightDecayGpuKernel : public GpuKernel {
public:
@@ -61,8 +62,8 @@ class AdamWeightDecayGpuKernel : public GpuKernel {

bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 9) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but adam needs 9 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but adam needs " << INPUT_NUM << " inputs.";
return false;
}



+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h View File

@@ -65,6 +65,7 @@ class AdaptiveAvgPool2DKernel : public GpuKernel {
output_width = static_cast<uint>(shape_addr[0]);
} else {
MS_LOG(ERROR) << "Input Error.";
return false;
}

size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);


+ 8
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h View File

@@ -26,6 +26,8 @@

namespace mindspore {
namespace kernel {
constexpr size_t CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION_INPUT_NUM = 6;
constexpr size_t NO_CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION_INPUT_NUM = 5;
template <typename T>
class BatchNormGpuKernel : public GpuKernel {
public:
@@ -103,12 +105,14 @@ class BatchNormGpuKernel : public GpuKernel {
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
if (input_num != 6) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be 6";
if (input_num != CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION_INPUT_NUM) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be "
<< CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION_INPUT_NUM;
}
} else {
if (input_num != 5) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be 5";
if (input_num != NO_CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION_INPUT_NUM) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be "
<< NO_CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION_INPUT_NUM;
}
}



+ 8
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h View File

@@ -28,6 +28,8 @@

namespace mindspore {
namespace kernel {
constexpr size_t CUDNN_BATCHNORM_OPS_BN_INPUT_NUM = 6;
constexpr size_t NO_CUDNN_BATCHNORM_OPS_BN_INPUT_NUM = 8;
template <typename T>
class BatchNormGradGpuKernel : public GpuKernel {
public:
@@ -129,12 +131,14 @@ class BatchNormGradGpuKernel : public GpuKernel {
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN) {
if (input_num != 6) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be 6";
if (input_num != CUDNN_BATCHNORM_OPS_BN_INPUT_NUM) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be "
<< CUDNN_BATCHNORM_OPS_BN_INPUT_NUM;
}
} else {
if (input_num != 8) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be 8";
if (input_num != NO_CUDNN_BATCHNORM_OPS_BN_INPUT_NUM) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be "
<< NO_CUDNN_BATCHNORM_OPS_BN_INPUT_NUM;
}
}



+ 17
- 15
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h View File

@@ -28,6 +28,7 @@

namespace mindspore {
namespace kernel {
#define NBDIMS 4
template <typename T>
class Conv2dGpuFwdKernel : public GpuKernel {
public:
@@ -126,18 +127,19 @@ class Conv2dGpuFwdKernel : public GpuKernel {
pad_width_ = pad_list[2] + pad_list[3];
pad_top_ = pad_list[0];
pad_left_ = pad_list[2];
int dimA[4];
int strideApadded[4];
const int nbDims = 4;
int dimA[NBDIMS];
int strideApadded[NBDIMS];
if (data_format_ == kOpFormat_NCHW || data_format_ == kOpFormat_DEFAULT) {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
SetDimA(padded_shape, dimA, nbDims, data_format_);
SetStrideA(padded_shape, strideApadded, nbDims, data_format_);
} else if (data_format_ == kOpFormat_NHWC) {
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
IntToSize(c_)};
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
SetDimA(padded_shape, dimA, nbDims, data_format_);
SetStrideA(padded_shape, strideApadded, nbDims, data_format_);
}
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(padded_desc_, cudnn_data_type_, 4, dimA, strideApadded),
@@ -288,17 +290,17 @@ class Conv2dGpuFwdKernel : public GpuKernel {
void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<size_t> &filter_shape,
const std::vector<size_t> &output_shape) {
const int nbDims = 4;
int dimA[4];
int strideAin[4];
int dimAout[4];
int strideAout[4];
SetDimA(in_shape, dimA, 4, data_format_);
SetStrideA(in_shape, strideAin, 4, data_format_);
SetDimA(output_shape, dimAout, 4, data_format_);
SetStrideA(output_shape, strideAout, 4, data_format_);
int dimA[NBDIMS];
int strideAin[NBDIMS];
int dimAout[NBDIMS];
int strideAout[NBDIMS];
SetDimA(in_shape, dimA, nbDims, data_format_);
SetStrideA(in_shape, strideAin, nbDims, data_format_);
SetDimA(output_shape, dimAout, nbDims, data_format_);
SetStrideA(output_shape, strideAout, nbDims, data_format_);
int filterDimA[4];
// OHWI for NHWC; OIHW for NCHW
SetDimA(filter_shape, filterDimA, 4, data_format_);
SetDimA(filter_shape, filterDimA, nbDims, data_format_);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(input_desc_, cudnn_data_type_, nbDims, dimA, strideAin),
"cudnnSetTensor4dDescriptor failed");


+ 19
- 17
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h View File

@@ -28,6 +28,7 @@

namespace mindspore {
namespace kernel {
#define NBDIMS 4
template <typename T>
class ConvGradFilterGpuBkwKernel : public GpuKernel {
public:
@@ -156,21 +157,22 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
use_pad_ = false;
}
int dimA[4];
int strideApadded[4];
const int nbDims = 4;
int dimA[NBDIMS];
int strideApadded[NBDIMS];
if (data_format_ == kOpFormat_NCHW || data_format_ == kOpFormat_DEFAULT) {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
SetDimA(padded_shape, dimA, nbDims, data_format_);
SetStrideA(padded_shape, strideApadded, nbDims, data_format_);
} else if (data_format_ == kOpFormat_NHWC) {
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
IntToSize(c_)};
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
SetDimA(padded_shape, dimA, nbDims, data_format_);
SetStrideA(padded_shape, strideApadded, nbDims, data_format_);
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded),
kernel_node_, cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, nbDims, dimA, strideApadded),
"cudnnSetTensor4dDescriptor failed");
padA[0] = 0;
padA[1] = 0;
@@ -311,17 +313,17 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &filter_shape,
const std::vector<size_t> &in_shape) {
const int nbDims = 4;
int dimA[4];
int strideAin[4];
int dimAdy[4];
int strideAdy[4];
SetDimA(in_shape, dimA, 4, data_format_);
SetStrideA(in_shape, strideAin, 4, data_format_);
SetDimA(dy_shape, dimAdy, 4, data_format_);
SetStrideA(dy_shape, strideAdy, 4, data_format_);
int dimA[NBDIMS];
int strideAin[NBDIMS];
int dimAdy[NBDIMS];
int strideAdy[NBDIMS];
SetDimA(in_shape, dimA, nbDims, data_format_);
SetStrideA(in_shape, strideAin, nbDims, data_format_);
SetDimA(dy_shape, dimAdy, nbDims, data_format_);
SetStrideA(dy_shape, strideAdy, nbDims, data_format_);
// filter shape relued by format_attr_. In native mode it's OHWI. In transpose mode it's OIHW.
int filterDimA[4];
SetDimA(filter_shape, filterDimA, 4, format_attr_);
int filterDimA[NBDIMS];
SetDimA(filter_shape, filterDimA, nbDims, format_attr_);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
"cudnnSetTensorNdDescriptor failed");


+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 8;
template <typename T>
class FtrlGpuKernel : public GpuKernel {
public:
@@ -59,8 +60,8 @@ class FtrlGpuKernel : public GpuKernel {

bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 8) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 8 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs " << INPUT_NUM << " inputs.";
return false;
}



+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 6;
template <typename T, typename S>
class FusedScaleMomentumGpuKernel : public GpuKernel {
public:
@@ -47,8 +48,8 @@ class FusedScaleMomentumGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 6) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 6 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs " << INPUT_NUM << " inputs.";
return false;
}



+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_momentum_gpu_kernel.h View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 6;
template <typename T, typename S>
class FusedWeightDecayMomentumGpuKernel : public GpuKernel {
public:
@@ -47,8 +48,8 @@ class FusedWeightDecayMomentumGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 6) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 6 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs " << INPUT_NUM << " inputs.";
return false;
}



+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 7;
template <typename T, typename S>
class FusedWeightDecayScaleMomentumGpuKernel : public GpuKernel {
public:
@@ -48,8 +49,8 @@ class FusedWeightDecayScaleMomentumGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 7) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 7 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs " << INPUT_NUM << " inputs.";
return false;
}



+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h View File

@@ -112,7 +112,7 @@ class L2NormalizeGpuKernel : public GpuKernel {
}
CheckTensorSize({inputA_shape, output_shape});
if (inputA_shape.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than " << MAX_DIMS;
}

std::vector<size_t> outputC_shape = output_shape;


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h View File

@@ -122,7 +122,7 @@ class L2NormalizeGradGpuKernel : public GpuKernel {
return false;
}
if (input_shape_list_[0].size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than " << MAX_DIMS;
}
return true;
}
@@ -180,7 +180,7 @@ class L2NormalizeGradGpuKernel : public GpuKernel {
bool CheckIONumber(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != INPUT_SIZE) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but l2normalize op needs 3 inputs.";
MS_LOG(ERROR) << "Input number is " << input_num << ", but l2normalize op needs " << INPUT_SIZE << " inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);


+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 5;
template <typename T, typename S, typename G>
class MomentumGpuKernel : public GpuKernel {
public:
@@ -51,8 +52,8 @@ class MomentumGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but momentum needs 5 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but momentum needs " << INPUT_NUM << " inputs.";
return false;
}
use_nesterov_ = GetAttr<bool>(kernel_node, "use_nesterov");


+ 13
- 10
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h View File

@@ -27,6 +27,8 @@

namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 3;
#define NBDIMS 4
template <typename T>
class PoolingGradGpuKernel : public GpuKernel {
public:
@@ -116,14 +118,14 @@ class PoolingGradGpuKernel : public GpuKernel {
return false;
}
const int nbDims = 4;
int dimA[4];
int strideAin[4];
int dimAy[4];
int strideAiny[4];
int dimAdy[4];
int strideAdy[4];
int dimAout[4];
int strideAout[4];
int dimA[NBDIMS];
int strideAin[NBDIMS];
int dimAy[NBDIMS];
int strideAiny[NBDIMS];
int dimAdy[NBDIMS];
int strideAdy[NBDIMS];
int dimAout[NBDIMS];
int strideAout[NBDIMS];
if (!InitShape(kernel_node, dimA, strideAin, dimAy, strideAiny, dimAdy, strideAdy, dimAout, strideAout, nbDims)) {
return true;
}
@@ -198,8 +200,9 @@ class PoolingGradGpuKernel : public GpuKernel {
private:
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but PoolingGradGpuKernel needs 3 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but PoolingGradGpuKernel needs " << INPUT_NUM
<< " inputs.";
return false;
}
return true;


+ 4
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_gpu_kernel.h View File

@@ -57,7 +57,8 @@ class PReLUGpuKernel : public GpuKernel {
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "PReLUGpuFwdKernel input is null.";
MS_LOG(ERROR) << "PReLUGpuFwdKernel input is null.";
return false;
}
size_t size = 1;
for (size_t i = 0; i < input_shape.size(); i++) {
@@ -68,7 +69,8 @@ class PReLUGpuKernel : public GpuKernel {
auto weight_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
is_null_input_ = CHECK_NULL_INPUT(weight_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "PReLUGpuFwdKernel weight is null.";
MS_LOG(ERROR) << "PReLUGpuFwdKernel weight is null.";
return false;
}
size = 1;
for (size_t i = 0; i < weight_shape.size(); i++) {


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h View File

@@ -56,7 +56,8 @@ class ReLUGpuFwdKernel : public GpuKernel {
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "ReLUGpuFwdKernel input is null.";
MS_LOG(ERROR) << "ReLUGpuFwdKernel input is null.";
return false;
}
size_t size = 1;
for (size_t i = 0; i < input_shape.size(); i++) {


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_gpu_kernel.h View File

@@ -60,7 +60,8 @@ class ReluGradGpuFwdKernel : public GpuKernel {
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "ActivationGradGpuKernel input is null.";
MS_LOG(ERROR) << "ActivationGradGpuKernel input is null.";
return false;
}
size_t size = 1;
for (size_t i = 0; i < input_shape.size(); i++) {


+ 4
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h View File

@@ -24,6 +24,7 @@

namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 2;
template <typename T, typename S>
class SigmoidCrossEntropyWithLogitsGpuKernel : public GpuKernel {
public:
@@ -48,8 +49,9 @@ class SigmoidCrossEntropyWithLogitsGpuKernel : public GpuKernel {

bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogits needs 2 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogits needs " << INPUT_NUM
<< " inputs.";
return false;
}
logits_size_ = sizeof(T);


+ 4
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_apply_proximal_adagrad_kernel.h View File

@@ -29,6 +29,7 @@

namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 7;
template <typename T>
class SparseApplyProximalAdagradKernel : public GpuKernel {
public:
@@ -59,8 +60,9 @@ class SparseApplyProximalAdagradKernel : public GpuKernel {

bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 7) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but SparseApplyProximalAdagrad needs 7 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but SparseApplyProximalAdagrad needs " << INPUT_NUM
<< " inputs.";
return false;
}



+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.h View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 5;
template <typename T, typename S>
class SparseFtrlGpuKernel : public GpuKernel {
public:
@@ -63,8 +64,8 @@ class SparseFtrlGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but sparse ftrl needs 5 inputs.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but sparse ftrl needs " << INPUT_NUM << " inputs.";
return false;
}



+ 4
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h View File

@@ -24,6 +24,7 @@

namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 8;
template <typename T>
class BatchNormFold2GpuKernel : public GpuKernel {
public:
@@ -68,8 +69,9 @@ class BatchNormFold2GpuKernel : public GpuKernel {
InitResource();

size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 8) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GpuKernel needs 8.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GpuKernel needs " << INPUT_NUM
<< " inputs.";
return false;
}



+ 4
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h View File

@@ -24,6 +24,7 @@

namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 8;
template <typename T>
class BatchNormFold2GradGpuKernel : public GpuKernel {
public:
@@ -96,8 +97,9 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
kernel_node_ = kernel_node;
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 8) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GradGpuKernel needs 8.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GradGpuKernel needs " << INPUT_NUM
<< " inputs.";
return false;
}



+ 4
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h View File

@@ -24,6 +24,7 @@

namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 6;
template <typename T>
class BatchNormFoldGradGpuKernel : public GpuKernel {
public:
@@ -99,8 +100,9 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 6) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 input.";
if (input_num != INPUT_NUM) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs " << INPUT_NUM
<< " inputs.";
return false;
}



Loading…
Cancel
Save