| @@ -386,9 +386,15 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K | |||||
| auto new_fg = BasicClone(fg); | auto new_fg = BasicClone(fg); | ||||
| cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg)); | cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg)); | ||||
| } | } | ||||
| auto origin_inputs = cnode->inputs(); | |||||
| bool optimize_depend = false; | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && | |||||
| origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) { | |||||
| optimize_depend = true; | |||||
| } | |||||
| // if has multiple depends,only select first depend as parameter | // if has multiple depends,only select first depend as parameter | ||||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||||
| auto anf = cnode->inputs()[input_idx]; | |||||
| for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { | |||||
| auto anf = origin_inputs[input_idx]; | |||||
| MS_EXCEPTION_IF_NULL(anf); | MS_EXCEPTION_IF_NULL(anf); | ||||
| // anf has been created before | // anf has been created before | ||||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | ||||
| @@ -413,6 +419,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K | |||||
| (*other_graph_cnode)[anf] = new_parameter; | (*other_graph_cnode)[anf] = new_parameter; | ||||
| } | } | ||||
| continue; | continue; | ||||
| } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { | |||||
| cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); | |||||
| continue; | |||||
| } else if (anf->isa<AnfNode>()) { | } else if (anf->isa<AnfNode>()) { | ||||
| *from_other_graph = true; | *from_other_graph = true; | ||||
| // the input node is a cnode from other graph | // the input node is a cnode from other graph | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/utils.h" | |||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| @@ -85,7 +86,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||||
| if (lst.empty()) { | if (lst.empty()) { | ||||
| MS_LOG(EXCEPTION) << "Input anf node list is empty"; | MS_LOG(EXCEPTION) << "Input anf node list is empty"; | ||||
| } | } | ||||
| auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr { | auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr { | ||||
| if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) { | if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) { | ||||
| eqv[a] = a; | eqv[a] = a; | ||||
| @@ -95,17 +95,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||||
| eqv[a]->set_abstract(a->abstract()); | eqv[a]->set_abstract(a->abstract()); | ||||
| eqv[a]->set_kernel_info(a->kernel_info_ptr()); | eqv[a]->set_kernel_info(a->kernel_info_ptr()); | ||||
| } | } | ||||
| return eqv[a]; | return eqv[a]; | ||||
| }; | }; | ||||
| // Merge CNodes into a AnfGraph that represents a linear instruction segment | // Merge CNodes into a AnfGraph that represents a linear instruction segment | ||||
| for (auto n : lst) { | for (auto n : lst) { | ||||
| if (!n->isa<CNode>()) { | if (!n->isa<CNode>()) { | ||||
| MS_LOG(EXCEPTION) << "Inst is not CNode"; | MS_LOG(EXCEPTION) << "Inst is not CNode"; | ||||
| } | } | ||||
| auto &inps = n->cast<CNodePtr>()->inputs(); | auto &inps = n->cast<CNodePtr>()->inputs(); | ||||
| if (inps.empty()) { | if (inps.empty()) { | ||||
| MS_LOG(EXCEPTION) << "Input is empty"; | MS_LOG(EXCEPTION) << "Input is empty"; | ||||
| } | } | ||||
| @@ -114,21 +111,22 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||||
| inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) { | inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) { | ||||
| MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode"; | MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode"; | ||||
| } | } | ||||
| auto fn = inps[0]; | auto fn = inps[0]; | ||||
| std::vector<AnfNodePtr> args{fn}; | std::vector<AnfNodePtr> args{fn}; | ||||
| (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); | |||||
| if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && inps[kRealInputIndexInDepend]->isa<ValueNode>() && | |||||
| eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { | |||||
| args.emplace_back(inps[kRealInputIndexInDepend]); | |||||
| args.emplace_back(inps[kRealInputIndexInDepend]); | |||||
| } else { | |||||
| (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); | |||||
| } | |||||
| eqv[n] = fg->NewCNode(args); | eqv[n] = fg->NewCNode(args); | ||||
| eqv[n]->set_abstract(n->abstract()); | eqv[n]->set_abstract(n->abstract()); | ||||
| eqv[n]->set_kernel_info(n->kernel_info_ptr()); | eqv[n]->set_kernel_info(n->kernel_info_ptr()); | ||||
| } | } | ||||
| std::vector<AnfNodePtr> eqv_keys; | std::vector<AnfNodePtr> eqv_keys; | ||||
| (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys), | (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys), | ||||
| [](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; }); | [](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; }); | ||||
| auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); | auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); | ||||
| AnfNodePtr fg_output; | AnfNodePtr fg_output; | ||||
| if (outputs.size() > 1) { | if (outputs.size() > 1) { | ||||
| @@ -136,29 +136,12 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n | |||||
| } | } | ||||
| } | } | ||||
| bool IsGetItemNode(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<CNode>()) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (inputs.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||||
| } | |||||
| if (!IsValueNode<Primitive>(inputs[0])) { | |||||
| return true; | |||||
| } | |||||
| PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]); | |||||
| return node_prim->name() == prim::kPrimTupleGetItem->name(); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) { | |||||
| std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) { | |||||
| std::vector<AnfNodePtr> result; | std::vector<AnfNodePtr> result; | ||||
| std::map<size_t, std::vector<AnfNodePtr>> insert_positions; | std::map<size_t, std::vector<AnfNodePtr>> insert_positions; | ||||
| std::map<AnfNodePtr, size_t> node_positions; | std::map<AnfNodePtr, size_t> node_positions; | ||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| if (IsGetItemNode(node)) { | |||||
| if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto &inputs = cnode->inputs(); | auto &inputs = cnode->inputs(); | ||||
| @@ -241,7 +224,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & | |||||
| } | } | ||||
| } | } | ||||
| std::reverse(result.begin(), result.end()); | std::reverse(result.begin(), result.end()); | ||||
| return ReorderGetItemNode(result); | |||||
| return result; | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -309,19 +292,12 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||||
| VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto nodes = OptimizeGetItemOrder(input_nodes); | |||||
| VectorRef splits; | VectorRef splits; | ||||
| VectorRef split; | VectorRef split; | ||||
| auto nodes = TopoSort(graph->get_return()); | |||||
| if (ContainMultiTarget(nodes)) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string default_target = context_ptr->device_target(); | |||||
| nodes = SplitSort(graph, default_target); | |||||
| } | |||||
| std::string last_target; | std::string last_target; | ||||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (IsCut(node)) { | if (IsCut(node)) { | ||||
| @@ -343,6 +319,36 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||||
| return splits; | return splits; | ||||
| } | } | ||||
| VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto nodes = TopoSort(graph->get_return()); | |||||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||||
| if (ContainMultiTarget(nodes)) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string default_target = context_ptr->device_target(); | |||||
| nodes = SplitSort(graph, default_target); | |||||
| return SplitNodesWithTarget(nodes, graph); | |||||
| } | |||||
| VectorRef splits; | |||||
| VectorRef split; | |||||
| for (auto &node : nodes) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (IsCut(node)) { | |||||
| if (split.size() != 0) { | |||||
| splits.push_back(split); | |||||
| } | |||||
| splits.push_back(node); | |||||
| split.clear(); | |||||
| } else if (node->isa<CNode>()) { | |||||
| split.push_back(node); | |||||
| } | |||||
| } | |||||
| return splits; | |||||
| } | |||||
| // Push the value node on the stack. | // Push the value node on the stack. | ||||
| void CompileGraph::Push(const AnfNodePtr &node) { | void CompileGraph::Push(const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| @@ -78,6 +78,7 @@ class CompileGraph { | |||||
| } | } | ||||
| private: | private: | ||||
| VectorRef SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph); | |||||
| void PushParameters(const FuncGraphPtr &func_graph); | void PushParameters(const FuncGraphPtr &func_graph); | ||||
| bool SplitGraph(const FuncGraphPtr &func_graph); | bool SplitGraph(const FuncGraphPtr &func_graph); | ||||
| int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); | int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); | ||||