|
|
@@ -105,6 +105,12 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp |
|
|
|
|
|
|
|
|
std::vector<int> a_shape = input0->shape(); |
|
|
std::vector<int> a_shape = input0->shape(); |
|
|
std::vector<int> b_shape = input1->shape(); |
|
|
std::vector<int> b_shape = input1->shape(); |
|
|
|
|
|
|
|
|
|
|
|
if (a_shape.size() == 4 && a_shape[2] == 1 && a_shape[3] == 1) { |
|
|
|
|
|
a_shape.resize(2); |
|
|
|
|
|
input0->set_shape(a_shape); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if (a_shape.size() < 2 || b_shape.size() < 2) { |
|
|
if (a_shape.size() < 2 || b_shape.size() < 2) { |
|
|
MS_LOG(ERROR) << "inputs shape is invalid"; |
|
|
MS_LOG(ERROR) << "inputs shape is invalid"; |
|
|
return RET_INPUT_TENSOR_ERROR; |
|
|
return RET_INPUT_TENSOR_ERROR; |
|
|
|