diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index 012e9ff932..a8d946c848 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -338,32 +338,50 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p auto y = CheckArg(op_name, args_spec_list, 1); MS_EXCEPTION_IF_NULL(y); MS_EXCEPTION_IF_NULL(y->shape()); - if (x->shape()->shape().size() != 2 || y->shape()->shape().size() != 2) { + auto x_shp = x->shape()->shape(); + auto y_shp = y->shape()->shape(); + if (x_shp.size() != 2 || y_shp.size() != 2) { MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2."; } - ValuePtr TAptr = primitive->GetAttr("transpose_a"); - ValuePtr TBptr = primitive->GetAttr("transpose_b"); - bool TA = GetValue(TAptr); - bool TB = GetValue(TBptr); + ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a"); + ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b"); + bool transpose_a = GetValue(transpose_a_ptr); + bool transpose_b = GetValue(transpose_b_ptr); ShapeVector x_min_shape = x->shape()->min_shape(); ShapeVector x_max_shape = x->shape()->max_shape(); ShapeVector y_min_shape = y->shape()->min_shape(); ShapeVector y_max_shape = y->shape()->max_shape(); - (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape); - (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape); + (void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape); + (void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape); + // Additional check for dynamic shape + // Last infer will be real shape values + bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); + bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); + if (x_not_dyn && y_not_dyn) { + auto x_col = x_shp[(transpose_a ? 0 : 1)]; + auto y_row = y_shp[(transpose_b ? 1 : 0)]; + if (x_col != y_row) { + MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row + << ". In MatMul x_col and y_row should be equal."; + } + } ShapeVector ret_shape; ShapeVector ret_min_shape; ShapeVector ret_max_shape; - auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void { - output.push_back(xshp[(TA ? 1 : 0)]); - output.push_back(yshp[(TB ? 0 : 1)]); + auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp, + const ShapeVector yshp) -> void { + output.push_back(xshp[(transpose_a ? 1 : 0)]); + output.push_back(yshp[(transpose_b ? 0 : 1)]); return; }; - make_shape(ret_shape, x->shape()->shape(), y->shape()->shape()); + make_shape(ret_shape, x_shp, y_shp); make_shape(ret_min_shape, x_min_shape, y_min_shape); make_shape(ret_max_shape, x_max_shape, y_max_shape); - return std::make_shared(x->element(), - std::make_shared(ret_shape, ret_min_shape, ret_max_shape)); + TypePtr x_type = x->element()->GetTypeTrack(); + if (x_type->type_id() == TypeId::kNumberTypeInt8) { + x_type = kInt32; + } + return std::make_shared(x_type, std::make_shared(ret_shape, ret_min_shape, ret_max_shape)); } AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -376,24 +394,40 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP auto y = CheckArg(op_name, args_spec_list, 1); MS_EXCEPTION_IF_NULL(y); MS_EXCEPTION_IF_NULL(y->shape()); - if (x->shape()->shape().size() != y->shape()->shape().size() || x->shape()->shape().size() < 3) { + auto x_shp = x->shape()->shape(); + auto y_shp = y->shape()->shape(); + if (x_shp.size() != y_shp.size() || x_shp.size() < 3) { MS_LOG(EXCEPTION) << "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3."; } - ValuePtr TAptr = primitive->GetAttr("transpose_a"); - ValuePtr TBptr = primitive->GetAttr("transpose_b"); - bool TA = GetValue(TAptr); - bool TB = GetValue(TBptr); + ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a"); + ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b"); + bool transpose_a = GetValue(transpose_a_ptr); + bool transpose_b = GetValue(transpose_b_ptr); ShapeVector x_min_shape = x->shape()->min_shape(); ShapeVector x_max_shape = x->shape()->max_shape(); ShapeVector y_min_shape = y->shape()->min_shape(); ShapeVector y_max_shape = y->shape()->max_shape(); - (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape); - (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape); + (void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape); + (void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape); + // Additional check for dynamic shape + // Last infer will be real shape values + bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); + bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); + if (x_not_dyn && y_not_dyn) { + size_t offset = x_shp.size() - 2; + auto x_col = x_shp[offset + (transpose_a ? 0 : 1)]; + auto y_row = y_shp[offset + (transpose_b ? 1 : 0)]; + if (x_col != y_row) { + MS_LOG(EXCEPTION) << "BatchMatMul shape error, got x_col: " << x_col << ", y_row: " << y_row + << ". In BatchMatMul x_col and y_row should be equal."; + } + } ShapeVector ret_shape; ShapeVector ret_min_shape; ShapeVector ret_max_shape; - auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void { + auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp, + const ShapeVector yshp) -> void { for (size_t i = 0; i < xshp.size() - 2; i++) { if (xshp[i] != yshp[i]) { if (xshp[i] > 0 && yshp[i] > 0) { @@ -405,15 +439,18 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP } } size_t offset = xshp.size() - 2; - output.push_back(xshp[offset + (TA ? 1 : 0)]); - output.push_back(yshp[offset + (TB ? 0 : 1)]); + output.push_back(xshp[offset + (transpose_a ? 1 : 0)]); + output.push_back(yshp[offset + (transpose_b ? 0 : 1)]); return; }; - make_shape(ret_shape, x->shape()->shape(), y->shape()->shape()); + make_shape(ret_shape, x_shp, y_shp); make_shape(ret_min_shape, x_min_shape, y_min_shape); make_shape(ret_max_shape, x_max_shape, y_max_shape); - return std::make_shared(x->element(), - std::make_shared(ret_shape, ret_min_shape, ret_max_shape)); + TypePtr x_type = x->element()->GetTypeTrack(); + if (x_type->type_id() == TypeId::kNumberTypeInt8) { + x_type = kInt32; + } + return std::make_shared(x_type, std::make_shared(ret_shape, ret_min_shape, ret_max_shape)); } } // namespace abstract } // namespace mindspore