| @@ -386,9 +386,15 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K | |||
| auto new_fg = BasicClone(fg); | |||
| cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg)); | |||
| } | |||
| auto origin_inputs = cnode->inputs(); | |||
| bool optimize_depend = false; | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && | |||
| origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) { | |||
| optimize_depend = true; | |||
| } | |||
| // if has multiple depends,only select first depend as parameter | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||
| auto anf = cnode->inputs()[input_idx]; | |||
| for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { | |||
| auto anf = origin_inputs[input_idx]; | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| // anf has been created before | |||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | |||
| @@ -413,6 +419,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K | |||
| (*other_graph_cnode)[anf] = new_parameter; | |||
| } | |||
| continue; | |||
| } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { | |||
| cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); | |||
| continue; | |||
| } else if (anf->isa<AnfNode>()) { | |||
| *from_other_graph = true; | |||
| // the input node is a cnode from other graph | |||
| @@ -28,6 +28,7 @@ | |||
| #include <string> | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/utils.h" | |||
| #include "ir/manager.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "operator/ops.h" | |||
| @@ -85,7 +86,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||
| if (lst.empty()) { | |||
| MS_LOG(EXCEPTION) << "Input anf node list is empty"; | |||
| } | |||
| auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr { | |||
| if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) { | |||
| eqv[a] = a; | |||
| @@ -95,17 +95,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||
| eqv[a]->set_abstract(a->abstract()); | |||
| eqv[a]->set_kernel_info(a->kernel_info_ptr()); | |||
| } | |||
| return eqv[a]; | |||
| }; | |||
| // Merge CNodes into a AnfGraph that represents a linear instruction segment | |||
| for (auto n : lst) { | |||
| if (!n->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Inst is not CNode"; | |||
| } | |||
| auto &inps = n->cast<CNodePtr>()->inputs(); | |||
| if (inps.empty()) { | |||
| MS_LOG(EXCEPTION) << "Input is empty"; | |||
| } | |||
| @@ -114,21 +111,22 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||
| inps[0]->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) { | |||
| MS_LOG(EXCEPTION) << "Input[0] Must be a Primitive valuenode"; | |||
| } | |||
| auto fn = inps[0]; | |||
| std::vector<AnfNodePtr> args{fn}; | |||
| (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); | |||
| if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && inps[kRealInputIndexInDepend]->isa<ValueNode>() && | |||
| eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { | |||
| args.emplace_back(inps[kRealInputIndexInDepend]); | |||
| args.emplace_back(inps[kRealInputIndexInDepend]); | |||
| } else { | |||
| (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); | |||
| } | |||
| eqv[n] = fg->NewCNode(args); | |||
| eqv[n]->set_abstract(n->abstract()); | |||
| eqv[n]->set_kernel_info(n->kernel_info_ptr()); | |||
| } | |||
| std::vector<AnfNodePtr> eqv_keys; | |||
| (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys), | |||
| [](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; }); | |||
| auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); | |||
| AnfNodePtr fg_output; | |||
| if (outputs.size() > 1) { | |||
| @@ -136,29 +136,12 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n | |||
| } | |||
| } | |||
| bool IsGetItemNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| if (!IsValueNode<Primitive>(inputs[0])) { | |||
| return true; | |||
| } | |||
| PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]); | |||
| return node_prim->name() == prim::kPrimTupleGetItem->name(); | |||
| } | |||
| return false; | |||
| } | |||
| std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) { | |||
| std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) { | |||
| std::vector<AnfNodePtr> result; | |||
| std::map<size_t, std::vector<AnfNodePtr>> insert_positions; | |||
| std::map<AnfNodePtr, size_t> node_positions; | |||
| for (auto &node : nodes) { | |||
| if (IsGetItemNode(node)) { | |||
| if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &inputs = cnode->inputs(); | |||
| @@ -241,7 +224,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & | |||
| } | |||
| } | |||
| std::reverse(result.begin(), result.end()); | |||
| return ReorderGetItemNode(result); | |||
| return result; | |||
| } | |||
| } // namespace | |||
| @@ -309,19 +292,12 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { | |||
| return false; | |||
| } | |||
| VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||
| VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto nodes = OptimizeGetItemOrder(input_nodes); | |||
| VectorRef splits; | |||
| VectorRef split; | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| if (ContainMultiTarget(nodes)) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->device_target(); | |||
| nodes = SplitSort(graph, default_target); | |||
| } | |||
| std::string last_target; | |||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsCut(node)) { | |||
| @@ -343,6 +319,36 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||
| return splits; | |||
| } | |||
| VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||
| if (ContainMultiTarget(nodes)) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->device_target(); | |||
| nodes = SplitSort(graph, default_target); | |||
| return SplitNodesWithTarget(nodes, graph); | |||
| } | |||
| VectorRef splits; | |||
| VectorRef split; | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsCut(node)) { | |||
| if (split.size() != 0) { | |||
| splits.push_back(split); | |||
| } | |||
| splits.push_back(node); | |||
| split.clear(); | |||
| } else if (node->isa<CNode>()) { | |||
| split.push_back(node); | |||
| } | |||
| } | |||
| return splits; | |||
| } | |||
| // Push the value node on the stack. | |||
| void CompileGraph::Push(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| @@ -78,6 +78,7 @@ class CompileGraph { | |||
| } | |||
| private: | |||
| VectorRef SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph); | |||
| void PushParameters(const FuncGraphPtr &func_graph); | |||
| bool SplitGraph(const FuncGraphPtr &func_graph); | |||
| int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); | |||