Browse Source

restore the bug of matmul-add fusion

feature/build-system-rewrite
wangyanling 4 years ago
parent
commit
9b887d27d5
2 changed files with 56 additions and 89 deletions
  1. +53
    -79
      mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc
  2. +3
    -10
      mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.h

+ 53
- 79
mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.cc View File

@@ -106,93 +106,67 @@ int CalNewCnodeBias(const AnfNodePtr &add_weight_node, const CNodePtr &matmul_cn
}
} // namespace

VectorRef MatMulAddFusion::DefineMatmulAddFusionPattern() const {
auto is_matmul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
MS_CHECK_TRUE_RET(is_add != nullptr, {});
auto is_seq_var = std::make_shared<SeqVar>();
MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
return VectorRef({is_add, is_matmul, is_seq_var});
}

VectorRef MatMulAddFusion::DefineMatmulBiasAddPattern() const {
auto is_matmul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
auto is_bias_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>);
MS_CHECK_TRUE_RET(is_bias_add != nullptr, {});
auto is_seq_var = std::make_shared<SeqVar>();
MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
return VectorRef({is_bias_add, is_matmul, is_seq_var});
}

std::unordered_map<std::string, VectorRef> MatMulAddFusion::DefinePatterns() const {
std::unordered_map<std::string, VectorRef> patterns;
patterns["MatmulAddFusionPatternName"] = DefineMatmulAddFusionPattern();
patterns["MatmulBiasAddPatternName"] = DefineMatmulBiasAddPattern();
return patterns;
}

AnfNodePtr MatMulAddFusion::Process(const std::string &pattern_name, const FuncGraphPtr &func_graph,
const AnfNodePtr &node, const EquivPtr &equiv) const {
if (func_graph == nullptr || node == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}

auto add_cnode = node->cast<CNodePtr>();
MS_CHECK_TRUE_RET(add_cnode != nullptr, nullptr);
if (IsMarkedTrainOp(add_cnode)) {
return nullptr;
}
if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
return nullptr;
}
bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
MS_CHECK_TRUE_RET(node != nullptr, false);
if (!utils::isa<CNode>(node)) {
continue;
}
auto add_cnode = node->cast<CNodePtr>();
if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
continue;
}
if (IsMarkedTrainOp(add_cnode)) {
continue;
}
size_t index = 0;

size_t index = 0;
if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimMatMulFusion)) {
return nullptr;
}
auto matmul_cnode = add_cnode->input(index)->cast<CNodePtr>();
MS_ASSERT(matmul_cnode != nullptr);
if (IsMarkedTrainOp(matmul_cnode)) {
return nullptr;
}
if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimMatMulFusion)) {
continue;
}

if (IsMultiOutputTensors(func_graph, matmul_cnode)) {
return nullptr;
}
auto matmul_cnode = add_cnode->input(index)->cast<CNodePtr>();
MS_CHECK_TRUE_RET(matmul_cnode != nullptr, false);
if (IsMarkedTrainOp(matmul_cnode)) {
continue;
}

if (!IsPrimitiveProper(add_cnode, matmul_cnode, index)) {
return nullptr;
}
if (IsMultiOutputTensors(func_graph, matmul_cnode)) {
continue;
}

auto manager = func_graph->manager();
auto add_param_node = add_cnode->input(kInputSizeThree - index);
MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
if (matmul_cnode->size() == kInputSizeThree) {
manager->AddEdge(matmul_cnode, add_param_node);
} else if (matmul_cnode->size() == kInputSizeFour) {
if (CalNewCnodeBias(add_param_node, matmul_cnode) != RET_OK) {
MS_LOG(INFO) << add_cnode->fullname_with_scope() << " failed to fusion with "
<< matmul_cnode->fullname_with_scope();
return nullptr;
if (!IsPrimitiveProper(add_cnode, matmul_cnode, index)) {
continue;
}
auto manager = func_graph->manager();
MS_CHECK_TRUE_RET(manager != nullptr, false);
auto add_param_node = add_cnode->input(kInputSizeThree - index);
if (matmul_cnode->size() == kInputSizeThree) {
manager->AddEdge(matmul_cnode, add_param_node);
} else if (matmul_cnode->size() == kInputSizeFour) {
if (CalNewCnodeBias(add_param_node, matmul_cnode) != RET_OK) {
MS_LOG(INFO) << add_cnode->fullname_with_scope() << " failed to fusion with "
<< matmul_cnode->fullname_with_scope();
return false;
}
}
}

if (CheckPrimitiveType(node, prim::kPrimAddFusion)) {
auto add_primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0));
MS_CHECK_TRUE_RET(add_primc != nullptr, nullptr);
if (add_primc->GetAttr(ops::kActivationType) != nullptr &&
add_primc->get_activation_type() != ActivationType::NO_ACTIVATION) {
auto matmul_primc = GetValueNode<std::shared_ptr<ops::MatMulFusion>>(matmul_cnode->input(0));
MS_CHECK_TRUE_RET(matmul_primc != nullptr, nullptr);
matmul_primc->set_activation_type(add_primc->get_activation_type());
if (CheckPrimitiveType(node, prim::kPrimAddFusion)) {
auto add_primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0));
MS_CHECK_TRUE_RET(add_primc != nullptr, false);
if (add_primc->GetAttr(ops::kActivationType) != nullptr &&
add_primc->get_activation_type() != ActivationType::NO_ACTIVATION) {
auto matmul_primc = GetValueNode<std::shared_ptr<ops::MatMulFusion>>(matmul_cnode->input(0));
MS_CHECK_TRUE_RET(matmul_primc != nullptr, false);
matmul_primc->set_activation_type(add_primc->get_activation_type());
}
}
matmul_cnode->set_fullname_with_scope(node->fullname_with_scope());
(void)manager->Replace(node, matmul_cnode);
}
matmul_cnode->set_fullname_with_scope(node->fullname_with_scope());
(void)manager->Replace(node, matmul_cnode);
return nullptr;
return false;
}
} // namespace opt
} // namespace mindspore

+ 3
- 10
mindspore/lite/tools/optimizer/fusion/matmul_add_fusion.h View File

@@ -24,18 +24,11 @@

namespace mindspore {
namespace opt {
class MatMulAddFusion : public MultiplePatternProcessPass {
class MatMulAddFusion : public Pass {
public:
explicit MatMulAddFusion(const std::string &name = "MatMulAddFusion", bool multigraph = true)
: MultiplePatternProcessPass(name, multigraph) {}
MatMulAddFusion() : Pass("MatMulAddFusion") {}
~MatMulAddFusion() override = default;

private:
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
VectorRef DefineMatmulAddFusionPattern() const;
VectorRef DefineMatmulBiasAddPattern() const;
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &,
const EquivPtr &) const override;
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace opt
} // namespace mindspore


Loading…
Cancel
Save