Browse Source

!30799 add case check for MatmulConfusionTranposeFusionPass

Merge pull request !30799 from yuchaojie/ub_fusion2
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
40c32fef1b
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 46 additions and 0 deletions
  1. +46
    -0
      mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc

+ 46
- 0
mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_confusiontranspose_fusion_pass.cc View File

@@ -23,6 +23,49 @@

namespace mindspore {
namespace opt {
namespace {
constexpr auto kAttrTransposeX1 = "transpose_x1";
constexpr auto kAttrTransposeX2 = "transpose_x2";

struct WrongCase {
std::vector<size_t> matmul_input0_shape;
std::vector<size_t> matmul_input1_shape;
std::vector<size_t> transpose_output_shape;
bool transpose_x1;
bool transpose_x2;
};

bool CheckWrongShape(const AnfNodePtr &matmul, const AnfNodePtr &confusion_transpose) {
std::vector<WrongCase> wrong_cases;

// add wrong cases
WrongCase wrong_case1;
wrong_case1.matmul_input0_shape = {128, 1024};
wrong_case1.matmul_input1_shape = {1024, 1024};
wrong_case1.transpose_output_shape = {1, 16, 128, 64};
wrong_case1.transpose_x1 = false;
wrong_case1.transpose_x2 = true;
wrong_cases.push_back(std::move(wrong_case1));

// get node shape
auto matmul_input0_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, 0);
auto matmul_input1_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, 1);
auto transpose_output_shape = common::AnfAlgo::GetOutputInferShape(confusion_transpose, 0);
auto transpose_x1 = common::AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX1);
auto transpose_x2 = common::AnfAlgo::GetBooleanAttr(matmul, kAttrTransposeX2);

// check
return std::any_of(wrong_cases.begin(), wrong_cases.end(),
[matmul_input0_shape, matmul_input1_shape, transpose_output_shape, transpose_x1,
transpose_x2](WrongCase wrong_case) {
return wrong_case.matmul_input0_shape == matmul_input0_shape &&
wrong_case.matmul_input1_shape == matmul_input1_shape &&
wrong_case.transpose_output_shape == transpose_output_shape &&
wrong_case.transpose_x1 == transpose_x1 && wrong_case.transpose_x2 == transpose_x2;
});
}
} // namespace

void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode,
const session::KernelGraph & /* kernel_graph */,
FusedNodeRecord *candidate_fusion) {
@@ -32,6 +75,9 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
MS_EXCEPTION_IF_NULL(matmul);
if (matmul->isa<CNode>() && (common::AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
common::AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {
if (CheckWrongShape(matmul, cnode)) {
return;
}
mindspore::HashSet<AnfNodePtr> record{cnode, matmul};
candidate_fusion->push_back(record);
SetRecordFusionId(record);


Loading…
Cancel
Save