Browse Source

!7219 Adapt matmul's input tensor shape for onnx model

Merge pull request !7219 from zhanyuan/tmp
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f575d9e245
1 changed files with 6 additions and 0 deletions
  1. +6
    -0
      mindspore/lite/src/ops/matmul.cc

+ 6
- 0
mindspore/lite/src/ops/matmul.cc View File

@@ -105,6 +105,12 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp


std::vector<int> a_shape = input0->shape(); std::vector<int> a_shape = input0->shape();
std::vector<int> b_shape = input1->shape(); std::vector<int> b_shape = input1->shape();

if (a_shape.size() == 4 && a_shape[2] == 1 && a_shape[3] == 1) {
a_shape.resize(2);
input0->set_shape(a_shape);
}

if (a_shape.size() < 2 || b_shape.size() < 2) { if (a_shape.size() < 2 || b_shape.size() < 2) {
MS_LOG(ERROR) << "inputs shape is invalid"; MS_LOG(ERROR) << "inputs shape is invalid";
return RET_INPUT_TENSOR_ERROR; return RET_INPUT_TENSOR_ERROR;


Loading…
Cancel
Save