Browse Source

fix_matmul_infer

r1.7
sunsuodong 4 years ago
parent
commit
dc3c0e0b86
1 changed files with 12 additions and 1 deletions
  1. +12
    -1
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/matmul_infer.c

+ 12
- 1
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/matmul_infer.c View File

@@ -50,6 +50,17 @@ int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_
return NNACL_OK;
}

int CheckMatMulBias(int *shape, size_t dim_size) {
if (dim_size > 1) {
for (size_t i = 0; i < dim_size - 1; i++) {
if (shape[i] != DIMENSION_1D) {
return NNACL_ERR;
}
}
}
return NNACL_OK;
}

int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
TensorC *input0 = (TensorC *)inputs[0];
@@ -67,7 +78,7 @@ int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs
if (inputs_size == kInputSize2) {
TensorC *bias = (TensorC *)inputs[2];
ShapeSet(bias_shape, &bias_shape_size, bias->shape_, bias->shape_size_);
MS_CHECK_TRUE_RET(bias_shape_size <= DIMENSION_1D, NNACL_ERR);
MS_CHECK_TRUE_RET(CheckMatMulBias(bias_shape, bias_shape_size) == NNACL_OK, NNACL_ERR);
}

if (a_shape_size == COMM_SHAPE_SIZE && a_shape[THIRD_INPUT] == 1 && a_shape[FOURTH_INPUT] == 1) {


Loading…
Cancel
Save