| @@ -605,7 +605,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, | |||||
| MS_EXCEPTION_IF_NULL(other_graph_cnode); | MS_EXCEPTION_IF_NULL(other_graph_cnode); | ||||
| MS_EXCEPTION_IF_NULL(cnode_inputs); | MS_EXCEPTION_IF_NULL(cnode_inputs); | ||||
| auto origin_inputs = cnode->inputs(); | auto origin_inputs = cnode->inputs(); | ||||
| bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3; | |||||
| bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() >= 3; | |||||
| bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3; | bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3; | ||||
| // if has multiple depends,only select first depend as parameter | // if has multiple depends,only select first depend as parameter | ||||
| for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { | for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { | ||||
| @@ -615,7 +615,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, | |||||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | ||||
| cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | ||||
| continue; | continue; | ||||
| } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { | |||||
| } else if (optimize_depend && input_idx > 1) { | |||||
| cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); | cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); | ||||
| continue; | continue; | ||||
| } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { | } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { | ||||
| @@ -214,7 +214,9 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto node_inputs = cnode->inputs(); | auto node_inputs = cnode->inputs(); | ||||
| std::reverse(node_inputs.begin(), node_inputs.end()); | |||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { | |||||
| std::reverse(node_inputs.begin(), node_inputs.end()); | |||||
| } | |||||
| auto ctrl_inputs = control_edges.find(node); | auto ctrl_inputs = control_edges.find(node); | ||||
| if (ctrl_inputs != control_edges.end()) { | if (ctrl_inputs != control_edges.end()) { | ||||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | ||||
| @@ -139,9 +139,11 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||||
| } | } | ||||
| auto fn = inps[0]; | auto fn = inps[0]; | ||||
| std::vector<AnfNodePtr> args{fn}; | std::vector<AnfNodePtr> args{fn}; | ||||
| if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { | |||||
| if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { | |||||
| args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv)); | args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv)); | ||||
| args.emplace_back(NewValueNode(MakeValue(0))); | |||||
| for (size_t i = 2; i < inps.size(); ++i) { | |||||
| args.emplace_back(NewValueNode(MakeValue(0))); | |||||
| } | |||||
| } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) { | } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) { | ||||
| for (size_t i = 1; i < inps.size(); ++i) { | for (size_t i = 1; i < inps.size(); ++i) { | ||||
| if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { | if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { | ||||