|
|
|
@@ -339,32 +339,50 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
|
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); |
|
|
|
MS_EXCEPTION_IF_NULL(y); |
|
|
|
MS_EXCEPTION_IF_NULL(y->shape()); |
|
|
|
if (x->shape()->shape().size() != 2 || y->shape()->shape().size() != 2) { |
|
|
|
auto x_shp = x->shape()->shape(); |
|
|
|
auto y_shp = y->shape()->shape(); |
|
|
|
if (x_shp.size() != 2 || y_shp.size() != 2) { |
|
|
|
MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2."; |
|
|
|
} |
|
|
|
ValuePtr TAptr = primitive->GetAttr("transpose_a"); |
|
|
|
ValuePtr TBptr = primitive->GetAttr("transpose_b"); |
|
|
|
bool TA = GetValue<bool>(TAptr); |
|
|
|
bool TB = GetValue<bool>(TBptr); |
|
|
|
ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a"); |
|
|
|
ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b"); |
|
|
|
bool transpose_a = GetValue<bool>(transpose_a_ptr); |
|
|
|
bool transpose_b = GetValue<bool>(transpose_b_ptr); |
|
|
|
ShapeVector x_min_shape = x->shape()->min_shape(); |
|
|
|
ShapeVector x_max_shape = x->shape()->max_shape(); |
|
|
|
ShapeVector y_min_shape = y->shape()->min_shape(); |
|
|
|
ShapeVector y_max_shape = y->shape()->max_shape(); |
|
|
|
(void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape); |
|
|
|
(void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape); |
|
|
|
(void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape); |
|
|
|
(void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape); |
|
|
|
// Additional check for dynamic shape |
|
|
|
// Last infer will be real shape values |
|
|
|
bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); |
|
|
|
bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); |
|
|
|
if (x_not_dyn && y_not_dyn) { |
|
|
|
auto x_col = x_shp[(transpose_a ? 0 : 1)]; |
|
|
|
auto y_row = y_shp[(transpose_b ? 1 : 0)]; |
|
|
|
if (x_col != y_row) { |
|
|
|
MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row |
|
|
|
<< ". In MatMul x_col and y_row should be equal."; |
|
|
|
} |
|
|
|
} |
|
|
|
ShapeVector ret_shape; |
|
|
|
ShapeVector ret_min_shape; |
|
|
|
ShapeVector ret_max_shape; |
|
|
|
auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void { |
|
|
|
output.push_back(xshp[(TA ? 1 : 0)]); |
|
|
|
output.push_back(yshp[(TB ? 0 : 1)]); |
|
|
|
auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp, |
|
|
|
const ShapeVector yshp) -> void { |
|
|
|
output.push_back(xshp[(transpose_a ? 1 : 0)]); |
|
|
|
output.push_back(yshp[(transpose_b ? 0 : 1)]); |
|
|
|
return; |
|
|
|
}; |
|
|
|
make_shape(ret_shape, x->shape()->shape(), y->shape()->shape()); |
|
|
|
make_shape(ret_shape, x_shp, y_shp); |
|
|
|
make_shape(ret_min_shape, x_min_shape, y_min_shape); |
|
|
|
make_shape(ret_max_shape, x_max_shape, y_max_shape); |
|
|
|
return std::make_shared<AbstractTensor>(x->element(), |
|
|
|
std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape)); |
|
|
|
TypePtr x_type = x->element()->GetTypeTrack(); |
|
|
|
if (x_type->type_id() == TypeId::kNumberTypeInt8) { |
|
|
|
x_type = kInt32; |
|
|
|
} |
|
|
|
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape)); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
@@ -377,24 +395,40 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP |
|
|
|
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); |
|
|
|
MS_EXCEPTION_IF_NULL(y); |
|
|
|
MS_EXCEPTION_IF_NULL(y->shape()); |
|
|
|
if (x->shape()->shape().size() != y->shape()->shape().size() || x->shape()->shape().size() < 3) { |
|
|
|
auto x_shp = x->shape()->shape(); |
|
|
|
auto y_shp = y->shape()->shape(); |
|
|
|
if (x_shp.size() != y_shp.size() || x_shp.size() < 3) { |
|
|
|
MS_LOG(EXCEPTION) |
|
|
|
<< "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3."; |
|
|
|
} |
|
|
|
ValuePtr TAptr = primitive->GetAttr("transpose_a"); |
|
|
|
ValuePtr TBptr = primitive->GetAttr("transpose_b"); |
|
|
|
bool TA = GetValue<bool>(TAptr); |
|
|
|
bool TB = GetValue<bool>(TBptr); |
|
|
|
ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a"); |
|
|
|
ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b"); |
|
|
|
bool transpose_a = GetValue<bool>(transpose_a_ptr); |
|
|
|
bool transpose_b = GetValue<bool>(transpose_b_ptr); |
|
|
|
ShapeVector x_min_shape = x->shape()->min_shape(); |
|
|
|
ShapeVector x_max_shape = x->shape()->max_shape(); |
|
|
|
ShapeVector y_min_shape = y->shape()->min_shape(); |
|
|
|
ShapeVector y_max_shape = y->shape()->max_shape(); |
|
|
|
(void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape); |
|
|
|
(void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape); |
|
|
|
(void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape); |
|
|
|
(void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape); |
|
|
|
// Additional check for dynamic shape |
|
|
|
// Last infer will be real shape values |
|
|
|
bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); |
|
|
|
bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); |
|
|
|
if (x_not_dyn && y_not_dyn) { |
|
|
|
size_t offset = x_shp.size() - 2; |
|
|
|
auto x_col = x_shp[offset + (transpose_a ? 0 : 1)]; |
|
|
|
auto y_row = y_shp[offset + (transpose_b ? 1 : 0)]; |
|
|
|
if (x_col != y_row) { |
|
|
|
MS_LOG(EXCEPTION) << "BatchMatMul shape error, got x_col: " << x_col << ", y_row: " << y_row |
|
|
|
<< ". In BatchMatMul x_col and y_row should be equal."; |
|
|
|
} |
|
|
|
} |
|
|
|
ShapeVector ret_shape; |
|
|
|
ShapeVector ret_min_shape; |
|
|
|
ShapeVector ret_max_shape; |
|
|
|
auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void { |
|
|
|
auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp, |
|
|
|
const ShapeVector yshp) -> void { |
|
|
|
for (size_t i = 0; i < xshp.size() - 2; i++) { |
|
|
|
if (xshp[i] != yshp[i]) { |
|
|
|
if (xshp[i] > 0 && yshp[i] > 0) { |
|
|
|
@@ -406,15 +440,18 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP |
|
|
|
} |
|
|
|
} |
|
|
|
size_t offset = xshp.size() - 2; |
|
|
|
output.push_back(xshp[offset + (TA ? 1 : 0)]); |
|
|
|
output.push_back(yshp[offset + (TB ? 0 : 1)]); |
|
|
|
output.push_back(xshp[offset + (transpose_a ? 1 : 0)]); |
|
|
|
output.push_back(yshp[offset + (transpose_b ? 0 : 1)]); |
|
|
|
return; |
|
|
|
}; |
|
|
|
make_shape(ret_shape, x->shape()->shape(), y->shape()->shape()); |
|
|
|
make_shape(ret_shape, x_shp, y_shp); |
|
|
|
make_shape(ret_min_shape, x_min_shape, y_min_shape); |
|
|
|
make_shape(ret_max_shape, x_max_shape, y_max_shape); |
|
|
|
return std::make_shared<AbstractTensor>(x->element(), |
|
|
|
std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape)); |
|
|
|
TypePtr x_type = x->element()->GetTypeTrack(); |
|
|
|
if (x_type->type_id() == TypeId::kNumberTypeInt8) { |
|
|
|
x_type = kInt32; |
|
|
|
} |
|
|
|
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape)); |
|
|
|
} |
|
|
|
} // namespace abstract |
|
|
|
} // namespace mindspore |