| @@ -338,32 +338,50 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p | |||||
| auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | ||||
| MS_EXCEPTION_IF_NULL(y); | MS_EXCEPTION_IF_NULL(y); | ||||
| MS_EXCEPTION_IF_NULL(y->shape()); | MS_EXCEPTION_IF_NULL(y->shape()); | ||||
| if (x->shape()->shape().size() != 2 || y->shape()->shape().size() != 2) { | |||||
| auto x_shp = x->shape()->shape(); | |||||
| auto y_shp = y->shape()->shape(); | |||||
| if (x_shp.size() != 2 || y_shp.size() != 2) { | |||||
| MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2."; | MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2."; | ||||
| } | } | ||||
| ValuePtr TAptr = primitive->GetAttr("transpose_a"); | |||||
| ValuePtr TBptr = primitive->GetAttr("transpose_b"); | |||||
| bool TA = GetValue<bool>(TAptr); | |||||
| bool TB = GetValue<bool>(TBptr); | |||||
| ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a"); | |||||
| ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b"); | |||||
| bool transpose_a = GetValue<bool>(transpose_a_ptr); | |||||
| bool transpose_b = GetValue<bool>(transpose_b_ptr); | |||||
| ShapeVector x_min_shape = x->shape()->min_shape(); | ShapeVector x_min_shape = x->shape()->min_shape(); | ||||
| ShapeVector x_max_shape = x->shape()->max_shape(); | ShapeVector x_max_shape = x->shape()->max_shape(); | ||||
| ShapeVector y_min_shape = y->shape()->min_shape(); | ShapeVector y_min_shape = y->shape()->min_shape(); | ||||
| ShapeVector y_max_shape = y->shape()->max_shape(); | ShapeVector y_max_shape = y->shape()->max_shape(); | ||||
| (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape); | |||||
| (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape); | |||||
| (void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape); | |||||
| (void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape); | |||||
| // Additional check for dynamic shape | |||||
| // Last infer will be real shape values | |||||
| bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); | |||||
| bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); | |||||
| if (x_not_dyn && y_not_dyn) { | |||||
| auto x_col = x_shp[(transpose_a ? 0 : 1)]; | |||||
| auto y_row = y_shp[(transpose_b ? 1 : 0)]; | |||||
| if (x_col != y_row) { | |||||
| MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row | |||||
| << ". In MatMul x_col and y_row should be equal."; | |||||
| } | |||||
| } | |||||
| ShapeVector ret_shape; | ShapeVector ret_shape; | ||||
| ShapeVector ret_min_shape; | ShapeVector ret_min_shape; | ||||
| ShapeVector ret_max_shape; | ShapeVector ret_max_shape; | ||||
| auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void { | |||||
| output.push_back(xshp[(TA ? 1 : 0)]); | |||||
| output.push_back(yshp[(TB ? 0 : 1)]); | |||||
| auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp, | |||||
| const ShapeVector yshp) -> void { | |||||
| output.push_back(xshp[(transpose_a ? 1 : 0)]); | |||||
| output.push_back(yshp[(transpose_b ? 0 : 1)]); | |||||
| return; | return; | ||||
| }; | }; | ||||
| make_shape(ret_shape, x->shape()->shape(), y->shape()->shape()); | |||||
| make_shape(ret_shape, x_shp, y_shp); | |||||
| make_shape(ret_min_shape, x_min_shape, y_min_shape); | make_shape(ret_min_shape, x_min_shape, y_min_shape); | ||||
| make_shape(ret_max_shape, x_max_shape, y_max_shape); | make_shape(ret_max_shape, x_max_shape, y_max_shape); | ||||
| return std::make_shared<AbstractTensor>(x->element(), | |||||
| std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape)); | |||||
| TypePtr x_type = x->element()->GetTypeTrack(); | |||||
| if (x_type->type_id() == TypeId::kNumberTypeInt8) { | |||||
| x_type = kInt32; | |||||
| } | |||||
| return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape)); | |||||
| } | } | ||||
| AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -376,24 +394,40 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP | |||||
| auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | ||||
| MS_EXCEPTION_IF_NULL(y); | MS_EXCEPTION_IF_NULL(y); | ||||
| MS_EXCEPTION_IF_NULL(y->shape()); | MS_EXCEPTION_IF_NULL(y->shape()); | ||||
| if (x->shape()->shape().size() != y->shape()->shape().size() || x->shape()->shape().size() < 3) { | |||||
| auto x_shp = x->shape()->shape(); | |||||
| auto y_shp = y->shape()->shape(); | |||||
| if (x_shp.size() != y_shp.size() || x_shp.size() < 3) { | |||||
| MS_LOG(EXCEPTION) | MS_LOG(EXCEPTION) | ||||
| << "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3."; | << "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3."; | ||||
| } | } | ||||
| ValuePtr TAptr = primitive->GetAttr("transpose_a"); | |||||
| ValuePtr TBptr = primitive->GetAttr("transpose_b"); | |||||
| bool TA = GetValue<bool>(TAptr); | |||||
| bool TB = GetValue<bool>(TBptr); | |||||
| ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a"); | |||||
| ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b"); | |||||
| bool transpose_a = GetValue<bool>(transpose_a_ptr); | |||||
| bool transpose_b = GetValue<bool>(transpose_b_ptr); | |||||
| ShapeVector x_min_shape = x->shape()->min_shape(); | ShapeVector x_min_shape = x->shape()->min_shape(); | ||||
| ShapeVector x_max_shape = x->shape()->max_shape(); | ShapeVector x_max_shape = x->shape()->max_shape(); | ||||
| ShapeVector y_min_shape = y->shape()->min_shape(); | ShapeVector y_min_shape = y->shape()->min_shape(); | ||||
| ShapeVector y_max_shape = y->shape()->max_shape(); | ShapeVector y_max_shape = y->shape()->max_shape(); | ||||
| (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape); | |||||
| (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape); | |||||
| (void)CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape); | |||||
| (void)CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape); | |||||
| // Additional check for dynamic shape | |||||
| // Last infer will be real shape values | |||||
| bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); | |||||
| bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); | |||||
| if (x_not_dyn && y_not_dyn) { | |||||
| size_t offset = x_shp.size() - 2; | |||||
| auto x_col = x_shp[offset + (transpose_a ? 0 : 1)]; | |||||
| auto y_row = y_shp[offset + (transpose_b ? 1 : 0)]; | |||||
| if (x_col != y_row) { | |||||
| MS_LOG(EXCEPTION) << "BatchMatMul shape error, got x_col: " << x_col << ", y_row: " << y_row | |||||
| << ". In BatchMatMul x_col and y_row should be equal."; | |||||
| } | |||||
| } | |||||
| ShapeVector ret_shape; | ShapeVector ret_shape; | ||||
| ShapeVector ret_min_shape; | ShapeVector ret_min_shape; | ||||
| ShapeVector ret_max_shape; | ShapeVector ret_max_shape; | ||||
| auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void { | |||||
| auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp, | |||||
| const ShapeVector yshp) -> void { | |||||
| for (size_t i = 0; i < xshp.size() - 2; i++) { | for (size_t i = 0; i < xshp.size() - 2; i++) { | ||||
| if (xshp[i] != yshp[i]) { | if (xshp[i] != yshp[i]) { | ||||
| if (xshp[i] > 0 && yshp[i] > 0) { | if (xshp[i] > 0 && yshp[i] > 0) { | ||||
| @@ -405,15 +439,18 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP | |||||
| } | } | ||||
| } | } | ||||
| size_t offset = xshp.size() - 2; | size_t offset = xshp.size() - 2; | ||||
| output.push_back(xshp[offset + (TA ? 1 : 0)]); | |||||
| output.push_back(yshp[offset + (TB ? 0 : 1)]); | |||||
| output.push_back(xshp[offset + (transpose_a ? 1 : 0)]); | |||||
| output.push_back(yshp[offset + (transpose_b ? 0 : 1)]); | |||||
| return; | return; | ||||
| }; | }; | ||||
| make_shape(ret_shape, x->shape()->shape(), y->shape()->shape()); | |||||
| make_shape(ret_shape, x_shp, y_shp); | |||||
| make_shape(ret_min_shape, x_min_shape, y_min_shape); | make_shape(ret_min_shape, x_min_shape, y_min_shape); | ||||
| make_shape(ret_max_shape, x_max_shape, y_max_shape); | make_shape(ret_max_shape, x_max_shape, y_max_shape); | ||||
| return std::make_shared<AbstractTensor>(x->element(), | |||||
| std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape)); | |||||
| TypePtr x_type = x->element()->GetTypeTrack(); | |||||
| if (x_type->type_id() == TypeId::kNumberTypeInt8) { | |||||
| x_type = kInt32; | |||||
| } | |||||
| return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape)); | |||||
| } | } | ||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||