Browse Source

!29851 fix matmul infer div zero

Merge pull request !29851 from zhaodezan/master
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
6b982680a6
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 20 additions and 3 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/addn_infer.c
  2. +14
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c
  3. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.c
  4. +3
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/addn_infer.c View File

@@ -52,7 +52,7 @@ int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
return NNACL_ERR;
}
if (inputs[i]->shape_size_ == max_dims) {
for (int j = 0; j < max_dims; j++) {
for (size_t j = 0; j < max_dims; j++) {
if (inputs[i]->shape_[j] != inputs[max_dims_idx]->shape_[j] && inputs[i]->shape_[j] != 1 &&
inputs[max_dims_idx]->shape_[j] != 1) {
return NNACL_ERR;


+ 14
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c View File

@@ -388,6 +388,20 @@ int CommonInferShapeWithOneInput(const TensorC *const *inputs, size_t inputs_siz
return NNACL_OK;
}

int CommonInferShapeWithTwoInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
size_t outputs_size, OpParameter *parameter) {
int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2);
if (ret != NNACL_OK) {
return ret;
}
SetDataTypeFormat(outputs[0], inputs[0]);
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
SetShapeTensor(outputs[0], inputs[0]);
return NNACL_OK;
}

int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) {


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.c View File

@@ -369,7 +369,7 @@ InferShape GetInferFunc(int prim_type) {
RegAllInferFunc3();
}
#endif
if (prim_type < PrimType_MAX) {
if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) {
return g_infer_func[prim_type];
} else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) {
return g_inner_op_infer_func[prim_type - PrimType_InnerOpMin];
@@ -378,7 +378,7 @@ InferShape GetInferFunc(int prim_type) {
}

void RegInfer(int prim_type, InferShape func) {
if (prim_type < PrimType_MAX) {
if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) {
g_infer_func[prim_type] = func;
} else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) {
g_inner_op_infer_func[prim_type - PrimType_InnerOpMin] = func;


+ 3
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c View File

@@ -31,6 +31,9 @@ int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_
for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) {
int min_value = MSMIN(a_shape[i], b_shape[i]);
int max_value = MSMAX(a_shape[i], b_shape[i]);
if (min_value == 0) {
return NNACL_ERR;
}
if (max_value % min_value != 0) {
return NNACL_INPUT_TENSOR_ERROR;
}


Loading…
Cancel
Save