|
|
|
@@ -116,8 +116,8 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp |
|
|
|
MS_LOG(ERROR) << "inputs shape is invalid"; |
|
|
|
return RET_INPUT_TENSOR_ERROR; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < a_shape.size() - 2; ++i) { |
|
|
|
if (a_shape[i] != b_shape[i]) { |
|
|
|
for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) { |
|
|
|
if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) { |
|
|
|
MS_LOG(ERROR) << "Op MatMul's dimensions must be equal"; |
|
|
|
return RET_INPUT_TENSOR_ERROR; |
|
|
|
} |
|
|
|
|