|
|
|
@@ -23,6 +23,49 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
constexpr auto kAttrTransposeX1 = "transpose_x1"; |
|
|
|
constexpr auto kAttrTransposeX2 = "transpose_x2"; |
|
|
|
|
|
|
|
struct WrongCase { |
|
|
|
std::vector<size_t> matmul_input0_shape; |
|
|
|
std::vector<size_t> matmul_input1_shape; |
|
|
|
std::vector<size_t> transpose_output_shape; |
|
|
|
bool transpose_x1; |
|
|
|
bool transpose_x2; |
|
|
|
}; |
|
|
|
|
|
|
|
bool CheckWrongShape(const AnfNodePtr &matmul, const AnfNodePtr &confusion_transpose) { |
|
|
|
std::vector<WrongCase> wrong_cases; |
|
|
|
|
|
|
|
// add wrong cases |
|
|
|
WrongCase wrong_case1; |
|
|
|
wrong_case1.matmul_input0_shape = {128, 1024}; |
|
|
|
wrong_case1.matmul_input1_shape = {1024, 1024}; |
|
|
|
wrong_case1.transpose_output_shape = {1, 16, 128, 64}; |
|
|
|
wrong_case1.transpose_x1 = false; |
|
|
|
wrong_case1.transpose_x2 = true; |
|
|
|
wrong_cases.push_back(std::move(wrong_case1)); |
|
|
|
|
|
|
|
// get node shape |
|
|
|
auto matmul_input0_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, 0); |
|
|
|
auto matmul_input1_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, 1); |
|
|
|
auto transpose_output_shape = common::AnfAlgo::GetOutputInferShape(confusion_transpose, 0); |
|
|
|
auto transpose_x1 = common::AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX1); |
|
|
|
auto transpose_x2 = common::AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX2); |
|
|
|
|
|
|
|
// check |
|
|
|
return std::any_of(wrong_cases.begin(), wrong_cases.end(), |
|
|
|
[matmul_input0_shape, matmul_input1_shape, transpose_output_shape, transpose_x1, |
|
|
|
transpose_x2](WrongCase wrong_case) { |
|
|
|
return wrong_case.matmul_input0_shape == matmul_input0_shape && |
|
|
|
wrong_case.matmul_input1_shape == matmul_input1_shape && |
|
|
|
wrong_case.transpose_output_shape == transpose_output_shape && |
|
|
|
wrong_case.transpose_x1 == transpose_x1 && wrong_case.transpose_x2 == transpose_x2; |
|
|
|
}); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode, |
|
|
|
const session::KernelGraph & /* kernel_graph */, |
|
|
|
FusedNodeRecord *candidate_fusion) { |
|
|
|
@@ -32,6 +75,9 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode |
|
|
|
MS_EXCEPTION_IF_NULL(matmul); |
|
|
|
if (matmul->isa<CNode>() && (common::AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) || |
|
|
|
common::AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) { |
|
|
|
if (CheckWrongShape(matmul, cnode)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
mindspore::HashSet<AnfNodePtr> record{cnode, matmul}; |
|
|
|
candidate_fusion->push_back(record); |
|
|
|
SetRecordFusionId(record); |
|
|
|
|