/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "parser_graph_optimizer.h"

#include <algorithm>
#include <map>
#include <memory>
#include <new>
#include <string>
#include <unordered_map>
#include <vector>

#include "graph/op_types.h"
#include "common/types_map.h"
#include "common/util.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/type_utils.h"
#include "graph_to_function_def.h"
#include "parser/common/acl_graph_parser_util.h"
#include "register/op_registry.h"

namespace ge {
REGISTER_OPTYPE_DEFINE(TF_MAXIMUM_GRAD, "MaximumGrad");
REGISTER_OPTYPE_DEFINE(TF_MATMUL, "Matmul");
REGISTER_OPTYPE_DEFINE(TFRELU6, "Relu6");
REGISTER_OPTYPE_DEFINE(TF_BATCH_MATMUL, "BatchMatmul");
}  // namespace ge

namespace ge {
namespace {
const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal";
const char *const kShapeNodeType = "Shape";
const char *const kShapeNodeNamePrefix = "getnext_shape_";
const char *const kIteratorType = "Iterator";
const char *const kIteratorV2Type = "IteratorV2";
const char *const kGetNextType = "IteratorGetNext";
const char *const kDynGetNextType = "DynamicGetNext";
}  // namespace

// Entry point of the pass.
// 1. Collects the clusters of nodes to fuse (MarkForFusion).
// 2. Builds one fused node per cluster and rewires the outer edges to it (UpdateGraph).
// 3. Isolates and removes every original cluster node from graph_.
// Returns SUCCESS when there is nothing to fuse; otherwise the status of the first failing step.
Status ParserGraphOptimizer::FusionFmkop() {
  GELOGI("graph_optimizer.cpp && FustionFmkop()");
  GE_CHECK_NOTNULL(graph_);

  std::unordered_map<std::string, std::vector<NodePtr>> node_cluser_Map;
  GE_CHK_STATUS_RET(MarkForFusion(node_cluser_Map), "find framework node to be fused fail.");
  if (node_cluser_Map.empty()) {
    return SUCCESS;
  }

  for (auto &cluster : node_cluser_Map) {
    GE_CHK_STATUS_RET(UpdateGraph(cluster.second), "fusion framework nodes failed. node:%s", cluster.first.c_str());
  }

  // fuse all fmkop and then delete node
  for (const auto &cluster : node_cluser_Map) {
    for (const auto &node : cluster.second) {
      GE_CHECK_NOTNULL(node);
      GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {}), "Isolate removed node: %s, type: %s failed",
                        node->GetName().c_str(), node->GetType().c_str());
      GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node),
                        "Remove node: %s, type: %s without relink failed", node->GetName().c_str(),
                        node->GetType().c_str());
    }
  }
  return SUCCESS;
}

// Scans the direct nodes of graph_ until it meets either a DynamicGetNext node or a framework op whose
// original type is IteratorGetNext, then delegates the actual clustering to GetFusionCluster.
// node_cluster_map: output, cluster name -> nodes of that cluster.
Status ParserGraphOptimizer::MarkForFusion(std::unordered_map<std::string, std::vector<NodePtr>> &node_cluster_map) {
  GE_CHECK_NOTNULL(graph_);
  bool has_get_next = false;
  bool has_dyn_get_next = false;
  for (const auto &node : graph_->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    // The op desc is dereferenced below, so validate it the same way GetFusionCluster does.
    GE_CHECK_NOTNULL(node->GetOpDesc());
    if (node->GetType() == kDynGetNextType) {
      has_dyn_get_next = true;
      break;
    }
    if (node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE) {
      continue;
    }
    std::string type;
    GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
    if (type == kGetNextType) {
      has_get_next = true;
      break;
    }
  }
  return GetFusionCluster(has_get_next, has_dyn_get_next, node_cluster_map);
}

// Builds the clusters to be fused.
// - For each framework op whose original type is IteratorGetNext: the cluster is made of the producers of its
//   data inputs, the node itself and every consumer that is a "Shape" node whose name contains "getnext_shape_".
//   The cluster is keyed by the name of its first member and only recorded when it has more than one node.
// - For each framework op whose original type is Iterator/IteratorV2, when the graph has neither kind of
//   get-next node: clusters are computed by FindFmkNodeCluser.
Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, const bool has_dyn_get_next,
                                              std::unordered_map<std::string, std::vector<NodePtr>> &node_cluster_map) {
  GE_CHECK_NOTNULL(graph_);
  for (const auto &node : graph_->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    GE_CHECK_NOTNULL(node->GetOpDesc());
    if (node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE) {
      continue;
    }
    std::string type;
    GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));

    if (type == kGetNextType) {
      std::vector<NodePtr> cluster;
      // Producers feeding the get-next node come first, so the first one names the cluster.
      for (const auto &in_anchor : node->GetAllInDataAnchors()) {
        GE_CHECK_NOTNULL(in_anchor);
        const OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor();
        GE_CHECK_NOTNULL(peer_out_anchor);
        const NodePtr src_node = peer_out_anchor->GetOwnerNode();
        GE_CHECK_NOTNULL(src_node);
        cluster.push_back(src_node);
      }
      cluster.push_back(node);
      // Only the dedicated "getnext_shape_*" Shape consumers are pulled into the cluster.
      for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
        GE_CHECK_NOTNULL(out_anchor);
        for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
          GE_CHECK_NOTNULL(in_anchor);
          const NodePtr dst_node = in_anchor->GetOwnerNode();
          GE_CHECK_NOTNULL(dst_node);
          GE_CHECK_NOTNULL(dst_node->GetOpDesc());
          const bool is_getnext_shape = (dst_node->GetName().find(kShapeNodeNamePrefix) != std::string::npos) &&
                                        (dst_node->GetOpDesc()->GetType() == kShapeNodeType);
          if (is_getnext_shape) {
            cluster.emplace_back(dst_node);
          }
        }
      }
      if (cluster.size() > 1) {
        node_cluster_map[cluster[0]->GetName()] = cluster;
      }
      GELOGI("MarkForFusion, IteratorGetNext graph mark success.");
    }

    const bool dataset_init =
        (!has_get_next) && (!has_dyn_get_next) && ((type == kIteratorType) || (type == kIteratorV2Type));
    if (dataset_init) {
      GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail.");
      GELOGI("MarkForFusion, Iterator init graph mark success.");
    }
  }
  return SUCCESS;
}

// find frameworkOP
// Walks the direct nodes of graph_ in iteration order and groups consecutive framework ops whose name does not
// contain "_RetVal" into clusters. Data nodes are skipped without ending the current run; any other node ends it.
// A run is recorded (keyed by the name of its first node) only when it holds more than one node.
Status ParserGraphOptimizer::FindFmkNodeCluser(
    std::unordered_map<std::string, std::vector<NodePtr>> &node_cluser_Map) const {
  GE_CHECK_NOTNULL(graph_);

  std::vector<NodePtr> current_run;
  // Stores the current run when it is big enough to be fused, then starts a new one.
  const auto flush_run = [&current_run, &node_cluser_Map]() {
    if (current_run.size() > 1) {
      node_cluser_Map[current_run[0]->GetName()] = current_run;
    }
    current_run.clear();
  };

  for (const auto &node : graph_->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    const OpDescPtr op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    if (op_desc->GetType() == ge::parser::DATA_TYPE) {
      continue;
    }
    const bool is_fusible = (op_desc->GetType() == ge::parser::FRAMEWORK_OP_TYPE) &&
                            (op_desc->GetName().find(RRTVAL_NODE_NAME_SUFFIX) == std::string::npos);
    if (is_fusible) {
      current_run.push_back(node);
    } else {
      flush_run();
    }
  }
  // The graph may end while a run is still open.
  flush_run();
  return SUCCESS;
}

// Merges into `library` every FunctionDefLibrary found in the ATTR_NAME_FRAMEWORK_FUNC_DEF attribute of `nodes`.
// Nodes without the attribute, or whose attribute bytes do not parse, are skipped.
// Returns a failure status when `library`, a node, or the attribute data pointer is null.
Status CollectNodeFuncs(std::vector<ge::NodePtr> &nodes, FunctionDefLibrary *library) {
  GE_CHECK_NOTNULL(library);
  for (const auto &node : nodes) {
    GE_CHECK_NOTNULL(node);
    OpDescPtr op_desc = node->GetOpDesc();
    ge::Buffer func_def_bytes;
    if (!AttrUtils::GetBytes(op_desc, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, func_def_bytes)) {
      continue;
    }
    GE_CHECK_NOTNULL(func_def_bytes.GetData());
    const std::string str(PtrToPtr<uint8_t, char_t>(func_def_bytes.GetData()), func_def_bytes.GetSize());
    GELOGI("FUNCDEF: Get function -> %s.", str.c_str());

    FunctionDefLibrary func_lib;
    // ParseFromArray takes an int length; make the narrowing explicit.
    if (func_lib.ParseFromArray(func_def_bytes.GetData(), static_cast<int>(func_def_bytes.GetSize()))) {
      library->MergeFrom(func_lib);
    }
  }
  return SUCCESS;
}

// Fuses one cluster of nodes into a single node added to graph_.
// The cluster is copied into a temporary sub graph, converted to a FunctionDef/NodeDef pair, and the serialized
// results are stored as attributes of the new fused op. The edges that crossed the cluster boundary are then
// moved to the fused node. The original nodes are NOT removed here (see FusionFmkop).
// nodes: the cluster, must not be empty; the first node provides the op type, the original type and the
//        function name.
Status ParserGraphOptimizer::UpdateGraph(std::vector<NodePtr> &nodes) {
  // Validate the input before doing any work: everything below relies on nodes[0].
  if (nodes.empty()) {
    REPORT_INNER_ERROR("E19999", "Param nodes size must greater than 0");
    GELOGE(FAILED, "node size must greater than 0 .");
    return PARAM_INVALID;
  }
  GE_CHECK_NOTNULL(graph_);

  ComputeGraphPtr sub_graph = nullptr;
  GE_MAKE_SHARED(sub_graph = std::make_shared<ComputeGraph>("subGraph"), sub_graph = nullptr; return PARAM_INVALID);

  std::unordered_map<std::string, NodePtr> node_map;
  std::vector<InDataAnchorPtr> input_anchors;
  std::vector<OutDataAnchorPtr> output_anchors;
  std::map<OutDataAnchorPtr, std::vector<InDataAnchorPtr>> output_in_map;
  std::vector<InControlAnchorPtr> input_control_anchors;
  std::vector<OutControlAnchorPtr> output_control_anchors;

  GE_CHK_STATUS_RET(InsertNode(sub_graph, nodes, input_anchors, output_anchors, output_in_map, input_control_anchors,
                               output_control_anchors, node_map),
                    "insert node to sub_graph failed.");
  GE_CHK_STATUS_RET(LinkInnerAnchor(node_map), "Link inner anchor failed.");

  std::unique_ptr<NodeDef> node_def(new (std::nothrow) NodeDef());  // tensorflow NodeDef
  GE_CHECK_NOTNULL(node_def);
  std::unique_ptr<FunctionDefLibrary> func_def_lib(new (std::nothrow) FunctionDefLibrary());
  GE_CHECK_NOTNULL(func_def_lib);

  // convert graph to FunctionDef
  const NodePtr &head_node = nodes[0];
  GE_CHECK_NOTNULL(head_node);
  GE_CHECK_NOTNULL(head_node->GetOpDesc());
  GE_CHK_STATUS_RET(CollectNodeFuncs(nodes, func_def_lib.get()), "Collect functionDef in nodes failed.");
  GE_CHK_STATUS_RET(GraphToFunctionDef::BuildFunctionDef(sub_graph, head_node->GetName(), func_def_lib.get(),
                                                         node_def.get(), input_anchors, output_anchors),
                    "Build functiondef failed.");

  std::string nodefStr;
  std::string funcdefStr;
  if (!node_def->SerializeToString(&nodefStr)) {
    REPORT_CALL_ERROR("E19999", "Serialize nodedef to string failed");
    GELOGE(PARAM_INVALID, "Serialize nodedef to string failed.");
    return PARAM_INVALID;
  }
  if (!func_def_lib->SerializeToString(&funcdefStr)) {
    REPORT_CALL_ERROR("E19999", "Serialize func_def to string failed, ");
    GELOGE(PARAM_INVALID, "Serialize func_def to string failed.");
    return PARAM_INVALID;
  }

  // The fused op is named after all its members; fall back to the first member when that gets too long.
  std::string fusion_op_name;
  for (const auto &node : nodes) {
    GE_CHECK_NOTNULL(node);
    fusion_op_name += node->GetName();
  }
  const uint32_t kFusionOpNameMaxLen = 1024;
  if (fusion_op_name.size() > kFusionOpNameMaxLen) {
    fusion_op_name = head_node->GetName();
  }

  OpDescPtr fusion_node_opdef = nullptr;
  GE_MAKE_SHARED(fusion_node_opdef = std::make_shared<OpDesc>(fusion_op_name, head_node->GetOpDesc()->GetType()),
                 fusion_node_opdef = nullptr;
                 return FAILED);

  std::string type = "";
  GE_CHK_STATUS_RET(ge::parser::GetOriginalType(head_node, type));
  (void)AttrUtils::SetStr(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  (void)AttrUtils::SetZeroCopyBytes(
      fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF,
      Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(funcdefStr.data()), funcdefStr.length()));
  (void)AttrUtils::SetZeroCopyBytes(
      fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
      Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(nodefStr.data()), nodefStr.length()));
  (void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type);

  // reconstruct fusion_node and edges
  GE_CHK_STATUS_RET(RebuildOutputAnchors(output_anchors, fusion_node_opdef),
                    "rebuild output edges to fusion node failed.");
  GE_CHK_STATUS_RET(RebuildInputAnchors(input_anchors, fusion_node_opdef),
                    "rebuild input edges to fusion node failed.");
  const NodePtr fusion_node = graph_->AddNode(fusion_node_opdef);
  GE_CHECK_NOTNULL(fusion_node);

  // add Anchors
  GE_CHK_STATUS_RET(RebuildFusionNode(input_anchors, output_anchors, output_in_map, input_control_anchors,
                                      output_control_anchors, fusion_node),
                    "rebuild node failed!");
  return SUCCESS;
}

// Adds the op desc of every cluster node to `sub_graph` and records the anchors that cross the cluster boundary.
// input_anchors:           in data anchors fed by a node outside the cluster.
// output_anchors:          out data anchors with at least one consumer outside the cluster.
// output_in_map:           for each such out data anchor, its consumers outside the cluster.
// input_control_anchors:   in control anchor of a node, appended once per outside control predecessor.
// output_control_anchors:  out control anchor of a node, appended once per outside control successor.
// node_map:                original node name -> node created in sub_graph.
// Fails when a data input of a cluster node is not connected.
Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, std::vector<ge::NodePtr> &nodes,
                                        std::vector<ge::InDataAnchorPtr> &input_anchors,
                                        std::vector<ge::OutDataAnchorPtr> &output_anchors,
                                        std::map<ge::OutDataAnchorPtr, std::vector<ge::InDataAnchorPtr>> &output_in_map,
                                        std::vector<ge::InControlAnchorPtr> &input_control_anchors,
                                        std::vector<ge::OutControlAnchorPtr> &output_control_anchors,
                                        std::unordered_map<std::string, ge::NodePtr> &node_map) {
  GE_CHECK_NOTNULL(sub_graph);
  const auto in_cluster = [&nodes](const NodePtr &candidate) {
    return std::find(nodes.begin(), nodes.end(), candidate) != nodes.end();
  };

  for (const NodePtr &node : nodes) {
    GE_CHECK_NOTNULL(node);
    const OpDescPtr op_def = node->GetOpDesc();
    const NodePtr new_node = sub_graph->AddNode(op_def);
    GE_CHECK_NOTNULL(new_node);
    node_map[node->GetName()] = new_node;

    // Input
    for (const auto &in_anchor : node->GetAllInDataAnchors()) {
      GE_CHECK_NOTNULL(in_anchor);
      const OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor();
      GE_CHECK_NOTNULL(peer_out_anchor);
      if (!in_cluster(peer_out_anchor->GetOwnerNode())) {
        input_anchors.emplace_back(in_anchor);
      }
    }

    // Output
    for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
      GE_CHECK_NOTNULL(out_anchor);
      bool has_out_node = false;
      for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
        GE_CHECK_NOTNULL(peer_in_anchor);
        if (!in_cluster(peer_in_anchor->GetOwnerNode())) {
          output_in_map[out_anchor].emplace_back(peer_in_anchor);
          has_out_node = true;
        }
      }
      if (has_out_node) {
        output_anchors.emplace_back(out_anchor);
      }
    }

    // Control input
    const InControlAnchorPtr node_in_control = node->GetInControlAnchor();
    if (node_in_control != nullptr) {
      for (const auto &peer_out_anchor : node_in_control->GetPeerOutControlAnchors()) {
        GE_CHECK_NOTNULL(peer_out_anchor);
        if (!in_cluster(peer_out_anchor->GetOwnerNode())) {
          input_control_anchors.emplace_back(node_in_control);
        }
      }
    }

    // Control output
    const OutControlAnchorPtr node_out_control = node->GetOutControlAnchor();
    if (node_out_control != nullptr) {
      for (const auto &peer_in_control_anchor : node_out_control->GetPeerInControlAnchors()) {
        GE_CHECK_NOTNULL(peer_in_control_anchor);
        if (!in_cluster(peer_in_control_anchor->GetOwnerNode())) {
          output_control_anchors.emplace_back(node_out_control);
        }
      }
    }
  }
  return SUCCESS;
}

// Recreates, between the nodes stored in `node_map`, the data and control edges that exist in graph_ between
// their originals. Edges whose source is not part of `node_map` are ignored.
// node_map: original node name -> copy of that node.
Status ParserGraphOptimizer::LinkInnerAnchor(std::unordered_map<std::string, ge::NodePtr> &node_map) const {
  GE_CHECK_NOTNULL(graph_);
  for (const auto &node : graph_->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    const auto dst_iter = node_map.find(node->GetName());
    if (dst_iter == node_map.end()) {
      continue;
    }
    const NodePtr dst = dst_iter->second;
    GE_CHECK_NOTNULL(dst);

    for (const auto &in_anchor : node->GetAllInDataAnchors()) {
      GE_CHECK_NOTNULL(in_anchor);
      const OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor();
      GE_CHECK_NOTNULL(peer_out_anchor);
      const NodePtr peer_owner = peer_out_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(peer_owner);
      const auto src_iter = node_map.find(peer_owner->GetName());
      if (src_iter == node_map.end()) {
        continue;
      }
      const NodePtr src = src_iter->second;
      GE_CHECK_NOTNULL(src);
      if (ge::GraphUtils::AddEdge(src->GetOutDataAnchor(peer_out_anchor->GetIdx()),
                                  dst->GetInDataAnchor(in_anchor->GetIdx())) != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
                          src->GetName().c_str(), src->GetType().c_str(), peer_out_anchor->GetIdx(),
                          dst->GetName().c_str(), dst->GetType().c_str(), in_anchor->GetIdx());
        GELOGE(FAILED,
               "LinkInnerAnchor Link data anchor failed, src node: %s, "
               "dst node: %s.",
               src->GetName().c_str(), dst->GetName().c_str());
        return FAILED;
      }
    }

    const InControlAnchorPtr node_in_control = node->GetInControlAnchor();
    if (node_in_control == nullptr) {
      continue;
    }
    for (const auto &peer_out_ctl_anchor : node_in_control->GetPeerOutControlAnchors()) {
      GE_CHECK_NOTNULL(peer_out_ctl_anchor);
      const NodePtr peer_owner = peer_out_ctl_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(peer_owner);
      const auto src_iter = node_map.find(peer_owner->GetName());
      if (src_iter == node_map.end()) {
        continue;
      }
      const NodePtr src_ctrl = src_iter->second;
      GE_CHECK_NOTNULL(src_ctrl);
      if (ge::GraphUtils::AddEdge(src_ctrl->GetOutControlAnchor(), dst->GetInControlAnchor()) != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
                          src_ctrl->GetName().c_str(), src_ctrl->GetType().c_str(), dst->GetName().c_str(),
                          dst->GetType().c_str());
        GELOGE(FAILED,
               "LinkInnerAnchor Link control anchor failed, src node: "
               "%s, dst node: %s.",
               src_ctrl->GetName().c_str(), dst->GetName().c_str());
        return FAILED;
      }
    }
  }
  return SUCCESS;
}

// rebuild output anchor
// Adds one output desc to `fusion_op_desc` per entry of `output_anchors` (same order) and stores the mapped
// TensorFlow data types of those outputs in the T_OUT_DATATYPE attribute.
// Returns PARAM_INVALID when a data type has no entry in GE_TENSORFLOW_DATA_TYPE_MAP.
Status ParserGraphOptimizer::RebuildOutputAnchors(std::vector<ge::OutDataAnchorPtr> &output_anchors,
                                                  ge::OpDescPtr fusion_op_desc) {
  std::vector<int64_t> output_list;
  GE_CHECK_NOTNULL(fusion_op_desc);

  // create output desc
  for (const auto &out_anchor : output_anchors) {
    GE_CHECK_NOTNULL(out_anchor);
    const NodePtr src_node = out_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(src_node);
    GE_CHECK_NOTNULL(src_node->GetOpDesc());

    GeTensorDesc src_out_desc = src_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(out_anchor->GetIdx()));
    GE_CHK_BOOL_EXEC(fusion_op_desc->AddOutputDesc(src_out_desc) == ge::GRAPH_SUCCESS, return FAILED);

    const ge::DataType data_type = src_out_desc.GetDataType();
    const auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(static_cast<int32_t>(data_type));
    if (iter == GE_TENSORFLOW_DATA_TYPE_MAP.cend()) {
      // Report the node as name:type (the type used to be printed as the name a second time).
      REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported", data_type,
                         out_anchor->GetIdx(), src_node->GetName().c_str(), src_node->GetType().c_str());
      GELOGE(PARAM_INVALID, "data_type %s not supported", ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
      return PARAM_INVALID;
    }

    const int32_t dtype = iter->second;
    output_list.push_back(static_cast<int64_t>(dtype));
    GELOGI("FUNCDEF: output_list push_back %d.", dtype);
  }
  if (!output_list.empty()) {
    (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_OUT_DATATYPE, output_list);
  }
  return SUCCESS;
}

// rebuild input desc
// Adds one input desc to `fusion_op_desc` per entry of `input_anchors` (same order) and stores the mapped
// TensorFlow data types of those inputs in the T_IN_DATATYPE attribute.
// Returns PARAM_INVALID when a data type has no entry in GE_TENSORFLOW_DATA_TYPE_MAP.
Status ParserGraphOptimizer::RebuildInputAnchors(std::vector<ge::InDataAnchorPtr> &input_anchors,
                                                 ge::OpDescPtr fusion_op_desc) {
  std::vector<int64_t> input_list;
  GE_CHECK_NOTNULL(fusion_op_desc);

  // add input desc
  for (const auto &in_anchor : input_anchors) {
    GE_CHECK_NOTNULL(in_anchor);
    const NodePtr dst_node = in_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(dst_node);
    GE_CHECK_NOTNULL(dst_node->GetOpDesc());

    const auto tensorDescPtr = dst_node->GetOpDesc()->GetInputDescPtr(in_anchor->GetIdx());
    GE_CHECK_NOTNULL_EXEC(tensorDescPtr, return domi::FAILED);
    if (fusion_op_desc->AddInputDesc(*tensorDescPtr) != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", fusion_op_desc->GetName().c_str(),
                        fusion_op_desc->GetType().c_str());
      GELOGE(FAILED, "Add fusion_op_desc AddInputDesc failed");
      return FAILED;
    }

    const ge::DataType data_type = tensorDescPtr->GetDataType();
    const auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(static_cast<int32_t>(data_type));
    if (iter == GE_TENSORFLOW_DATA_TYPE_MAP.cend()) {
      // Report the node as name:type (the type used to be printed as the name a second time).
      REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported", data_type,
                         in_anchor->GetIdx(), dst_node->GetName().c_str(), dst_node->GetType().c_str());
      GELOGE(PARAM_INVALID, "data_type %s not supported", ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
      return PARAM_INVALID;
    }

    const int32_t dtype = iter->second;
    input_list.push_back(static_cast<int64_t>(dtype));
    GELOGI("FUNCDEF: input_list push_back %d.", dtype);
  }
  if (!input_list.empty()) {
    (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_IN_DATATYPE, input_list);
  }
  return SUCCESS;
}

// Moves the edges recorded by InsertNode from the original cluster nodes to `fusion_node`.
// - The i-th entry of `output_anchors` is replaced by output i of the fused node for the consumers listed in
//   `output_in_map`.
// - The i-th entry of `input_anchors` is replaced by input i of the fused node.
// - Every peer of the recorded control anchors is re-linked to the control anchors of the fused node.
Status ParserGraphOptimizer::RebuildFusionNode(
    std::vector<ge::InDataAnchorPtr> &input_anchors, std::vector<ge::OutDataAnchorPtr> &output_anchors,
    std::map<ge::OutDataAnchorPtr, std::vector<ge::InDataAnchorPtr>> &output_in_map,
    std::vector<ge::InControlAnchorPtr> &input_control_anchors,
    std::vector<ge::OutControlAnchorPtr> &output_control_anchors, ge::NodePtr fusion_node) {
  GE_CHECK_NOTNULL(fusion_node);

  int32_t src_index = 0;
  for (const auto &out_anchor : output_anchors) {
    for (const auto &in_anchor : output_in_map[out_anchor]) {
      GE_CHECK_NOTNULL(in_anchor);
      (void)in_anchor->Unlink(out_anchor);
      GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(fusion_node->GetOutDataAnchor(src_index), in_anchor),
                                  "Add anchor between fusion node and in anchor node!");
    }
    src_index++;
  }

  src_index = 0;
  for (const auto &in_anchor : input_anchors) {
    GE_CHECK_NOTNULL(in_anchor);
    const OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(out_anchor);
    (void)out_anchor->Unlink(in_anchor);
    GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(out_anchor, fusion_node->GetInDataAnchor(src_index)),
                                "Add anchor between out anchor node and fusion node!");
    src_index++;
  }

  for (const auto &out_control_anchor : output_control_anchors) {
    GE_CHECK_NOTNULL(out_control_anchor);
    for (const auto &in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) {
      GE_CHECK_NOTNULL(in_control_anchor);
      (void)in_control_anchor->Unlink(out_control_anchor);
      GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(fusion_node->GetOutControlAnchor(), in_control_anchor),
                                  "Add anchor between fusion node and in control anchor node!");
    }
  }

  for (const auto &in_control_anchor : input_control_anchors) {
    GE_CHECK_NOTNULL(in_control_anchor);
    for (const auto &out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
      GE_CHECK_NOTNULL(out_control_anchor);
      (void)out_control_anchor->Unlink(in_control_anchor);
      GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(out_control_anchor, fusion_node->GetInControlAnchor()),
                                  "Add anchor between out control anchor node and fusion node!");
    }
  }
  return SUCCESS;
}
}  // namespace ge