|
|
|
@@ -106,93 +106,67 @@ int CalNewCnodeBias(const AnfNodePtr &add_weight_node, const CNodePtr &matmul_cn |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
VectorRef MatMulAddFusion::DefineMatmulAddFusionPattern() const { |
|
|
|
auto is_matmul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>); |
|
|
|
MS_CHECK_TRUE_RET(is_matmul != nullptr, {}); |
|
|
|
auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>); |
|
|
|
MS_CHECK_TRUE_RET(is_add != nullptr, {}); |
|
|
|
auto is_seq_var = std::make_shared<SeqVar>(); |
|
|
|
MS_CHECK_TRUE_RET(is_seq_var != nullptr, {}); |
|
|
|
return VectorRef({is_add, is_matmul, is_seq_var}); |
|
|
|
} |
|
|
|
|
|
|
|
VectorRef MatMulAddFusion::DefineMatmulBiasAddPattern() const { |
|
|
|
auto is_matmul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>); |
|
|
|
MS_CHECK_TRUE_RET(is_matmul != nullptr, {}); |
|
|
|
auto is_bias_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBiasAdd>); |
|
|
|
MS_CHECK_TRUE_RET(is_bias_add != nullptr, {}); |
|
|
|
auto is_seq_var = std::make_shared<SeqVar>(); |
|
|
|
MS_CHECK_TRUE_RET(is_seq_var != nullptr, {}); |
|
|
|
return VectorRef({is_bias_add, is_matmul, is_seq_var}); |
|
|
|
} |
|
|
|
|
|
|
|
std::unordered_map<std::string, VectorRef> MatMulAddFusion::DefinePatterns() const { |
|
|
|
std::unordered_map<std::string, VectorRef> patterns; |
|
|
|
patterns["MatmulAddFusionPatternName"] = DefineMatmulAddFusionPattern(); |
|
|
|
patterns["MatmulBiasAddPatternName"] = DefineMatmulBiasAddPattern(); |
|
|
|
return patterns; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr MatMulAddFusion::Process(const std::string &pattern_name, const FuncGraphPtr &func_graph, |
|
|
|
const AnfNodePtr &node, const EquivPtr &equiv) const { |
|
|
|
if (func_graph == nullptr || node == nullptr) { |
|
|
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto add_cnode = node->cast<CNodePtr>(); |
|
|
|
MS_CHECK_TRUE_RET(add_cnode != nullptr, nullptr); |
|
|
|
if (IsMarkedTrainOp(add_cnode)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) { |
|
|
|
MS_ASSERT(func_graph != nullptr); |
|
|
|
auto node_list = TopoSort(func_graph->get_return()); |
|
|
|
for (auto &node : node_list) { |
|
|
|
MS_CHECK_TRUE_RET(node != nullptr, false); |
|
|
|
if (!utils::isa<CNode>(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto add_cnode = node->cast<CNodePtr>(); |
|
|
|
if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsMarkedTrainOp(add_cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
size_t index = 0; |
|
|
|
|
|
|
|
size_t index = 0; |
|
|
|
if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimMatMulFusion)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto matmul_cnode = add_cnode->input(index)->cast<CNodePtr>(); |
|
|
|
MS_ASSERT(matmul_cnode != nullptr); |
|
|
|
if (IsMarkedTrainOp(matmul_cnode)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (!CheckAndGetCnodeIndex(add_cnode, &index, prim::kPrimMatMulFusion)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsMultiOutputTensors(func_graph, matmul_cnode)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto matmul_cnode = add_cnode->input(index)->cast<CNodePtr>(); |
|
|
|
MS_CHECK_TRUE_RET(matmul_cnode != nullptr, false); |
|
|
|
if (IsMarkedTrainOp(matmul_cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (!IsPrimitiveProper(add_cnode, matmul_cnode, index)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (IsMultiOutputTensors(func_graph, matmul_cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto manager = func_graph->manager(); |
|
|
|
auto add_param_node = add_cnode->input(kInputSizeThree - index); |
|
|
|
MS_CHECK_TRUE_RET(manager != nullptr, nullptr); |
|
|
|
if (matmul_cnode->size() == kInputSizeThree) { |
|
|
|
manager->AddEdge(matmul_cnode, add_param_node); |
|
|
|
} else if (matmul_cnode->size() == kInputSizeFour) { |
|
|
|
if (CalNewCnodeBias(add_param_node, matmul_cnode) != RET_OK) { |
|
|
|
MS_LOG(INFO) << add_cnode->fullname_with_scope() << " failed to fusion with " |
|
|
|
<< matmul_cnode->fullname_with_scope(); |
|
|
|
return nullptr; |
|
|
|
if (!IsPrimitiveProper(add_cnode, matmul_cnode, index)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_CHECK_TRUE_RET(manager != nullptr, false); |
|
|
|
auto add_param_node = add_cnode->input(kInputSizeThree - index); |
|
|
|
if (matmul_cnode->size() == kInputSizeThree) { |
|
|
|
manager->AddEdge(matmul_cnode, add_param_node); |
|
|
|
} else if (matmul_cnode->size() == kInputSizeFour) { |
|
|
|
if (CalNewCnodeBias(add_param_node, matmul_cnode) != RET_OK) { |
|
|
|
MS_LOG(INFO) << add_cnode->fullname_with_scope() << " failed to fusion with " |
|
|
|
<< matmul_cnode->fullname_with_scope(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (CheckPrimitiveType(node, prim::kPrimAddFusion)) { |
|
|
|
auto add_primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0)); |
|
|
|
MS_CHECK_TRUE_RET(add_primc != nullptr, nullptr); |
|
|
|
if (add_primc->GetAttr(ops::kActivationType) != nullptr && |
|
|
|
add_primc->get_activation_type() != ActivationType::NO_ACTIVATION) { |
|
|
|
auto matmul_primc = GetValueNode<std::shared_ptr<ops::MatMulFusion>>(matmul_cnode->input(0)); |
|
|
|
MS_CHECK_TRUE_RET(matmul_primc != nullptr, nullptr); |
|
|
|
matmul_primc->set_activation_type(add_primc->get_activation_type()); |
|
|
|
if (CheckPrimitiveType(node, prim::kPrimAddFusion)) { |
|
|
|
auto add_primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0)); |
|
|
|
MS_CHECK_TRUE_RET(add_primc != nullptr, false); |
|
|
|
if (add_primc->GetAttr(ops::kActivationType) != nullptr && |
|
|
|
add_primc->get_activation_type() != ActivationType::NO_ACTIVATION) { |
|
|
|
auto matmul_primc = GetValueNode<std::shared_ptr<ops::MatMulFusion>>(matmul_cnode->input(0)); |
|
|
|
MS_CHECK_TRUE_RET(matmul_primc != nullptr, false); |
|
|
|
matmul_primc->set_activation_type(add_primc->get_activation_type()); |
|
|
|
} |
|
|
|
} |
|
|
|
matmul_cnode->set_fullname_with_scope(node->fullname_with_scope()); |
|
|
|
(void)manager->Replace(node, matmul_cnode); |
|
|
|
} |
|
|
|
matmul_cnode->set_fullname_with_scope(node->fullname_with_scope()); |
|
|
|
(void)manager->Replace(node, matmul_cnode); |
|
|
|
return nullptr; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |