@@ -338,32 +338,50 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
   auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
   MS_EXCEPTION_IF_NULL(y);
   MS_EXCEPTION_IF_NULL(y->shape());
-  if (x->shape()->shape().size() != 2 || y->shape()->shape().size() != 2) {
+  auto x_shp = x->shape()->shape();
+  auto y_shp = y->shape()->shape();
+  if (x_shp.size() != 2 || y_shp.size() != 2) {
     MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
   }
-  ValuePtr TAptr = primitive->GetAttr("transpose_a");
-  ValuePtr TBptr = primitive->GetAttr("transpose_b");
-  bool TA = GetValue<bool>(TAptr);
-  bool TB = GetValue<bool>(TBptr);
+  ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
+  ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
+  bool transpose_a = GetValue<bool>(transpose_a_ptr);
+  bool transpose_b = GetValue<bool>(transpose_b_ptr);
   ShapeVector x_min_shape = x->shape()->min_shape();
   ShapeVector x_max_shape = x->shape()->max_shape();
   ShapeVector y_min_shape = y->shape()->min_shape();
   ShapeVector y_max_shape = y->shape()->max_shape();
-  (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
-  (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
+  (void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape);
+  (void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape);
+  // Additional check for dynamic shape
+  // Last infer will be real shape values
+  bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
+  bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
+  if (x_not_dyn && y_not_dyn) {
+    auto x_col = x_shp[(transpose_a ? 0 : 1)];
+    auto y_row = y_shp[(transpose_b ? 1 : 0)];
+    if (x_col != y_row) {
+      MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
+                        << ". In MatMul x_col and y_row should be equal.";
+    }
+  }
   ShapeVector ret_shape;
   ShapeVector ret_min_shape;
   ShapeVector ret_max_shape;
-  auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
-    output.push_back(xshp[(TA ? 1 : 0)]);
-    output.push_back(yshp[(TB ? 0 : 1)]);
+  auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp,
+                                                 const ShapeVector yshp) -> void {
+    output.push_back(xshp[(transpose_a ? 1 : 0)]);
+    output.push_back(yshp[(transpose_b ? 0 : 1)]);
     return;
   };
-  make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
+  make_shape(ret_shape, x_shp, y_shp);
   make_shape(ret_min_shape, x_min_shape, y_min_shape);
   make_shape(ret_max_shape, x_max_shape, y_max_shape);
-  return std::make_shared<AbstractTensor>(x->element(),
-                                          std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
+  TypePtr x_type = x->element()->GetTypeTrack();
+  if (x_type->type_id() == TypeId::kNumberTypeInt8) {
+    x_type = kInt32;
+  }
+  return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
 }
 
 AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
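For readers following along, the rule the MatMul hunk enforces boils down to: both inputs must be rank-2, the inner dimensions must agree whenever both shapes are fully known (a dynamic dimension defers the check to a later, real-shape infer), and the output is [x_row, y_col] after the transpose flags are applied. Below is a minimal standalone sketch of that rule, not MindSpore code; `InferMatMulShape` and `kDynAny` are illustrative names, with `kDynAny` standing in for `Shape::SHP_ANY`.

```cpp
// Illustrative sketch only: the shape rule the MatMul hunk implements.
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

constexpr int64_t kDynAny = -1;  // placeholder for an unknown (dynamic) dimension

std::vector<int64_t> InferMatMulShape(const std::vector<int64_t> &x, const std::vector<int64_t> &y,
                                      bool transpose_a, bool transpose_b) {
  if (x.size() != 2 || y.size() != 2) {
    throw std::invalid_argument("MatMul inputs must both be rank-2");
  }
  bool x_not_dyn = std::all_of(x.begin(), x.end(), [](int64_t v) { return v != kDynAny; });
  bool y_not_dyn = std::all_of(y.begin(), y.end(), [](int64_t v) { return v != kDynAny; });
  // The inner dimensions are only comparable when both shapes are fully known.
  if (x_not_dyn && y_not_dyn) {
    int64_t x_col = x[transpose_a ? 0 : 1];
    int64_t y_row = y[transpose_b ? 1 : 0];
    if (x_col != y_row) {
      throw std::invalid_argument("MatMul inner dimensions do not match");
    }
  }
  // Output is [x_row, y_col], taking the transpose flags into account.
  return {x[transpose_a ? 1 : 0], y[transpose_b ? 0 : 1]};
}
```

For instance, an input shape of {4, kDynAny} skips the inner-dimension check entirely but still yields {4, y_col}, which is why the hunk guards the check with `x_not_dyn && y_not_dyn`.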
@@ -376,24 +394,40 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP
   auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
   MS_EXCEPTION_IF_NULL(y);
   MS_EXCEPTION_IF_NULL(y->shape());
-  if (x->shape()->shape().size() != y->shape()->shape().size() || x->shape()->shape().size() < 3) {
+  auto x_shp = x->shape()->shape();
+  auto y_shp = y->shape()->shape();
+  if (x_shp.size() != y_shp.size() || x_shp.size() < 3) {
     MS_LOG(EXCEPTION)
       << "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3.";
   }
-  ValuePtr TAptr = primitive->GetAttr("transpose_a");
-  ValuePtr TBptr = primitive->GetAttr("transpose_b");
-  bool TA = GetValue<bool>(TAptr);
-  bool TB = GetValue<bool>(TBptr);
+  ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
+  ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
+  bool transpose_a = GetValue<bool>(transpose_a_ptr);
+  bool transpose_b = GetValue<bool>(transpose_b_ptr);
   ShapeVector x_min_shape = x->shape()->min_shape();
   ShapeVector x_max_shape = x->shape()->max_shape();
   ShapeVector y_min_shape = y->shape()->min_shape();
   ShapeVector y_max_shape = y->shape()->max_shape();
-  (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
-  (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
+  (void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape);
+  (void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape);
+  // Additional check for dynamic shape
+  // Last infer will be real shape values
+  bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
+  bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
+  if (x_not_dyn && y_not_dyn) {
+    size_t offset = x_shp.size() - 2;
+    auto x_col = x_shp[offset + (transpose_a ? 0 : 1)];
+    auto y_row = y_shp[offset + (transpose_b ? 1 : 0)];
+    if (x_col != y_row) {
+      MS_LOG(EXCEPTION) << "BatchMatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
+                        << ". In BatchMatMul x_col and y_row should be equal.";
+    }
+  }
   ShapeVector ret_shape;
   ShapeVector ret_min_shape;
   ShapeVector ret_max_shape;
-  auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
+  auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp,
+                                                 const ShapeVector yshp) -> void {
     for (size_t i = 0; i < xshp.size() - 2; i++) {
       if (xshp[i] != yshp[i]) {
         if (xshp[i] > 0 && yshp[i] > 0) {
@@ -405,15 +439,18 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP
       }
     }
     size_t offset = xshp.size() - 2;
-    output.push_back(xshp[offset + (TA ? 1 : 0)]);
-    output.push_back(yshp[offset + (TB ? 0 : 1)]);
+    output.push_back(xshp[offset + (transpose_a ? 1 : 0)]);
+    output.push_back(yshp[offset + (transpose_b ? 0 : 1)]);
     return;
   };
-  make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
+  make_shape(ret_shape, x_shp, y_shp);
   make_shape(ret_min_shape, x_min_shape, y_min_shape);
   make_shape(ret_max_shape, x_max_shape, y_max_shape);
-  return std::make_shared<AbstractTensor>(x->element(),
-                                          std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
+  TypePtr x_type = x->element()->GetTypeTrack();
+  if (x_type->type_id() == TypeId::kNumberTypeInt8) {
+    x_type = kInt32;
+  }
+  return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
 }
 } // namespace abstract
 } // namespace mindspore
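The BatchMatMul hunks follow the same pattern with two extra wrinkles: the ranks must match and be at least 3, and the leading (batch) dimensions are compared element-wise, which is only treated as an error when both sides are known (greater than 0). The lambda body that handles a mismatched-but-dynamic batch dimension falls between the second and third hunks and is not shown, so the sketch below simply marks that output dimension as dynamic; that choice, like the names `InferBatchMatMulShape` and `kDynAny`, is an assumption for illustration only. The inner-dimension check mirrors the MatMul sketch above and is omitted here.

```cpp
// Illustrative sketch only: the batch-dimension rule applied by the BatchMatMul hunks.
#include <cstdint>
#include <stdexcept>
#include <vector>

constexpr int64_t kDynAny = -1;  // placeholder for an unknown (dynamic) dimension

std::vector<int64_t> InferBatchMatMulShape(const std::vector<int64_t> &x, const std::vector<int64_t> &y,
                                           bool transpose_a, bool transpose_b) {
  if (x.size() != y.size() || x.size() < 3) {
    throw std::invalid_argument("BatchMatMul inputs must have the same rank, and the rank must be >= 3");
  }
  std::vector<int64_t> out;
  for (size_t i = 0; i + 2 < x.size(); ++i) {
    if (x[i] != y[i]) {
      // A mismatch is only fatal when both batch dimensions are known.
      if (x[i] > 0 && y[i] > 0) {
        throw std::invalid_argument("BatchMatMul batch dimensions differ");
      }
      // Assumption: with one side dynamic, the output batch dimension stays dynamic.
      out.push_back(kDynAny);
    } else {
      out.push_back(x[i]);
    }
  }
  // The last two dimensions follow the usual row/column rule, honouring the transpose flags.
  size_t offset = x.size() - 2;
  out.push_back(x[offset + (transpose_a ? 1 : 0)]);
  out.push_back(y[offset + (transpose_b ? 0 : 1)]);
  return out;
}
```

As in the MatMul case, the new code also promotes an int8 element type to int32 for the output tensor, which only affects the returned `AbstractTensor`'s type, not the shape logic sketched here.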