Browse Source

Add fix for MatMul ops shape inference

tags/v1.2.0-rc1
TFBunny 4 years ago
parent
commit
42aa743cfd
1 changed file with 63 additions and 26 deletions
  1. +63
    -26
      mindspore/core/abstract/prim_maths.cc

+ 63
- 26
mindspore/core/abstract/prim_maths.cc View File

@@ -338,32 +338,50 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(y);
MS_EXCEPTION_IF_NULL(y->shape());
if (x->shape()->shape().size() != 2 || y->shape()->shape().size() != 2) {
auto x_shp = x->shape()->shape();
auto y_shp = y->shape()->shape();
if (x_shp.size() != 2 || y_shp.size() != 2) {
MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
}
ValuePtr TAptr = primitive->GetAttr("transpose_a");
ValuePtr TBptr = primitive->GetAttr("transpose_b");
bool TA = GetValue<bool>(TAptr);
bool TB = GetValue<bool>(TBptr);
ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
bool transpose_a = GetValue<bool>(transpose_a_ptr);
bool transpose_b = GetValue<bool>(transpose_b_ptr);
ShapeVector x_min_shape = x->shape()->min_shape();
ShapeVector x_max_shape = x->shape()->max_shape();
ShapeVector y_min_shape = y->shape()->min_shape();
ShapeVector y_max_shape = y->shape()->max_shape();
(void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
(void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
(void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape);
(void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape);
// Additional check for dynamic shape
// Last infer will be real shape values
bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
if (x_not_dyn && y_not_dyn) {
auto x_col = x_shp[(transpose_a ? 0 : 1)];
auto y_row = y_shp[(transpose_b ? 1 : 0)];
if (x_col != y_row) {
MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
<< ". In MatMul x_col and y_row should be equal.";
}
}
ShapeVector ret_shape;
ShapeVector ret_min_shape;
ShapeVector ret_max_shape;
auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
output.push_back(xshp[(TA ? 1 : 0)]);
output.push_back(yshp[(TB ? 0 : 1)]);
auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp,
const ShapeVector yshp) -> void {
output.push_back(xshp[(transpose_a ? 1 : 0)]);
output.push_back(yshp[(transpose_b ? 0 : 1)]);
return;
};
make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
make_shape(ret_shape, x_shp, y_shp);
make_shape(ret_min_shape, x_min_shape, y_min_shape);
make_shape(ret_max_shape, x_max_shape, y_max_shape);
return std::make_shared<AbstractTensor>(x->element(),
std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
TypePtr x_type = x->element()->GetTypeTrack();
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
x_type = kInt32;
}
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
}

AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -376,24 +394,40 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP
auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(y);
MS_EXCEPTION_IF_NULL(y->shape());
if (x->shape()->shape().size() != y->shape()->shape().size() || x->shape()->shape().size() < 3) {
auto x_shp = x->shape()->shape();
auto y_shp = y->shape()->shape();
if (x_shp.size() != y_shp.size() || x_shp.size() < 3) {
MS_LOG(EXCEPTION)
<< "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3.";
}
ValuePtr TAptr = primitive->GetAttr("transpose_a");
ValuePtr TBptr = primitive->GetAttr("transpose_b");
bool TA = GetValue<bool>(TAptr);
bool TB = GetValue<bool>(TBptr);
ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
bool transpose_a = GetValue<bool>(transpose_a_ptr);
bool transpose_b = GetValue<bool>(transpose_b_ptr);
ShapeVector x_min_shape = x->shape()->min_shape();
ShapeVector x_max_shape = x->shape()->max_shape();
ShapeVector y_min_shape = y->shape()->min_shape();
ShapeVector y_max_shape = y->shape()->max_shape();
(void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
(void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
(void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape);
(void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape);
// Additional check for dynamic shape
// Last infer will be real shape values
bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
if (x_not_dyn && y_not_dyn) {
size_t offset = x_shp.size() - 2;
auto x_col = x_shp[offset + (transpose_a ? 0 : 1)];
auto y_row = y_shp[offset + (transpose_b ? 1 : 0)];
if (x_col != y_row) {
MS_LOG(EXCEPTION) << "BatchMatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
<< ". In BatchMatMul x_col and y_row should be equal.";
}
}
ShapeVector ret_shape;
ShapeVector ret_min_shape;
ShapeVector ret_max_shape;
auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp,
const ShapeVector yshp) -> void {
for (size_t i = 0; i < xshp.size() - 2; i++) {
if (xshp[i] != yshp[i]) {
if (xshp[i] > 0 && yshp[i] > 0) {
@@ -405,15 +439,18 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP
}
}
size_t offset = xshp.size() - 2;
output.push_back(xshp[offset + (TA ? 1 : 0)]);
output.push_back(yshp[offset + (TB ? 0 : 1)]);
output.push_back(xshp[offset + (transpose_a ? 1 : 0)]);
output.push_back(yshp[offset + (transpose_b ? 0 : 1)]);
return;
};
make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
make_shape(ret_shape, x_shp, y_shp);
make_shape(ret_min_shape, x_min_shape, y_min_shape);
make_shape(ret_max_shape, x_max_shape, y_max_shape);
return std::make_shared<AbstractTensor>(x->element(),
std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
TypePtr x_type = x->element()->GetTypeTrack();
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
x_type = kInt32;
}
return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
}
} // namespace abstract
} // namespace mindspore

Loading…
Cancel
Save