diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 61d4d85eb6..51a2d3b64f 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -727,7 +727,8 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, MS_EXCEPTION_IF_NULL(other_graph_cnode); MS_EXCEPTION_IF_NULL(cnode_inputs); auto origin_inputs = cnode->inputs(); - bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() >= 3; + const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend); + const bool is_updatestate = IsPrimitiveCNode(cnode, prim::kPrimUpdateState); // if has multiple depends,only select first depend as parameter for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { auto anf = origin_inputs[input_idx]; @@ -736,7 +737,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf)); continue; - } else if (optimize_depend && input_idx > 1) { + } else if ((is_depend && input_idx > 1) || (is_updatestate && input_idx > 2)) { cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx)))); continue; } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index 60d0e35d06..482453231f 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -39,43 +40,39 @@ namespace compile { ConvertCache g_ConvertCache; void ClearConvertCache() { g_ConvertCache.clear(); } +namespace { // Return the list of nodes whose values are required beyond this segment. // Arguments: -// lst: list of nodes (the segment) +// nodes: list of nodes in the segment // users: dict mapping each node to its users (globally) // seen: set of nodes that are part of the segment -AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector &seen) { +AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users, + const std::unordered_set &seen) { AnfNodePtrList output; if (users.size() == 0) { return output; } - - (void)std::transform( - std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr { - auto usersn = users.find(n); - bool is_referred_out_of_segment = std::any_of( - std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair &u) -> bool { - return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen); - }); - if (n->isa() && is_referred_out_of_segment) { - return n; - } - return nullptr; - }); - - // remove nullptr - for (auto it = output.begin(); it != output.end();) { - if (*it == nullptr) { - it = output.erase(it); - } else { - ++it; + for (auto &node : nodes) { + if (!node->isa()) { + continue; + } + auto iter = users.find(node); + if (iter == users.end()) { + continue; + } + auto &node_users = iter->second; + const bool has_outer_user = std::any_of( + std::begin(node_users), std::end(node_users), [&seen](const std::pair &u) -> bool { + const bool is_outer_user = (seen.find(u.first) == seen.end()); + return is_outer_user && !(IsPrimitiveCNode(u.first, prim::kPrimUpdateState) && u.second > 2); + }); + if (has_outer_user) { + output.emplace_back(node); } } - return output; } -namespace { AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr, AnfNodePtrToAnfNodePtrMap *eqv_ptr) { MS_EXCEPTION_IF_NULL(fg); @@ -129,6 +126,15 @@ std::tuple TransformSegmentToAnfGr for (size_t i = 2; i < inps.size(); ++i) { args.emplace_back(NewValueNode(MakeValue(0))); } + } else if (IsPrimitive(fn, prim::kPrimUpdateState)) { + args.emplace_back(RefSubGraphNode(fg, inps[1], &inputs, &eqv)); + args.emplace_back(RefSubGraphNode(fg, inps[2], &inputs, &eqv)); + for (size_t i = 3; i < inps.size(); ++i) { + auto &input = inps[i]; + if (eqv.find(input) != eqv.end()) { + args.emplace_back(RefSubGraphNode(fg, input, &inputs, &eqv)); + } + } } else { (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); @@ -138,8 +144,8 @@ std::tuple TransformSegmentToAnfGr eqv[n]->set_abstract(n->abstract()); eqv[n]->set_kernel_info(n->kernel_info_ptr()); } - std::vector eqv_keys; - (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys), + std::unordered_set eqv_keys; + (void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()), [](const std::pair &elem) -> AnfNodePtr { return elem.first; }); auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); AnfNodePtr fg_output; diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index b7324bb4ea..97849a6a16 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -465,7 +465,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { } } else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) { auto &inputs = cnode->inputs(); - if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) { + if (inputs.size() >= 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) { return GetCNodeTarget(inputs[2]); } } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {