Browse Source

Adapt matmul's input tensor shape of onnx model

tags/v1.1.0
zhanyuan 5 years ago
parent
commit
ec5026f3b2
1 changed files with 6 additions and 0 deletions
  1. +6
    -0
      mindspore/lite/src/ops/matmul.cc

+ 6
- 0
mindspore/lite/src/ops/matmul.cc View File

@@ -105,6 +105,12 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp

std::vector<int> a_shape = input0->shape();
std::vector<int> b_shape = input1->shape();

if (a_shape.size() == 4 && a_shape[2] == 1 && a_shape[3] == 1) {
a_shape.resize(2);
input0->set_shape(a_shape);
}

if (a_shape.size() < 2 || b_shape.size() < 2) {
MS_LOG(ERROR) << "inputs shape is invalid";
return RET_INPUT_TENSOR_ERROR;


Loading…
Cancel
Save