diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc index e79921f980..a83f327780 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc @@ -50,6 +50,32 @@ AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) { return result; } +AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) { + AnfNodePtrList result; + for (auto node : nodes) { + if (node->abstract()->isa()) { + auto node_abstract = node->abstract()->cast()->elements(); + auto num = node_abstract.size(); + for (size_t i = 0; i < num; i++) { + auto idx_val = SizeToLong(i); + + auto idx = NewValueNode(idx_val); + MS_EXCEPTION_IF_NULL(idx); + idx->set_abstract(std::make_shared(idx_val)); + + auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + tuple_getitem->set_fullname_with_scope(node->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i)); + tuple_getitem->set_abstract(node_abstract[i]); + tuple_getitem->set_kernel_info(std::make_shared()); + result.push_back(tuple_getitem); + } + } else { + result.push_back(node); + } + } + return result; +} bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) { auto todos = GetUpdateStateList(func_graph); bool changed = false; @@ -58,6 +84,8 @@ bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(cnode); if (cnode->size() <= kUpdateStateRealInput) continue; auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput); + // extend inputs of update if which have multiple outputs + inputs = ExtendInputsOfUpdate(inputs, func_graph); if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) { AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)}; node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end()); @@ -81,7 +109,7 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) { for (auto node : todos) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() <= kUpdateStateRealInput) continue; + if (cnode->size() <= kUpdateStateRealInput + 1) continue; AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput); AbstractBasePtrList abs_list; std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list), diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h index 3e8ccd0704..ef3be568e3 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h @@ -39,6 +39,7 @@ class SpreadUpdateState : public Pass { public: SpreadUpdateState() : Pass("spread_update_state") {} ~SpreadUpdateState() override = default; + AnfNodePtrList ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph); bool Run(const FuncGraphPtr &func_graph) override; };