diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h index 06ccb1e612..be6b8462e3 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -375,7 +375,7 @@ class IncorporateGetitemSwitch : public AnfVisitor { } auto tuple_getitem = node->cast(); MS_EXCEPTION_IF_NULL(tuple_getitem); - if (MultipleUseOfSwitch(tuple_getitem->input(1), fg)) { + if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg)) { return nullptr; } auto new_g1 = getitem_transform_(g1_, idx_); @@ -446,6 +446,14 @@ class IncorporateGetitemSwitch : public AnfVisitor { return tuple_getitem_num > 1; } + static bool inline ExistEnvNode(const FuncGraphPtr &fg) { + MS_EXCEPTION_IF_NULL(fg); + auto &nodes = fg->value_nodes(); + return std::any_of(nodes.begin(), nodes.end(), [](const auto &node) { + return IsPrimitive(node.first, prim::kPrimEnvSetItem) || IsPrimitive(node.first, prim::kPrimEnvGetItem); + }); + } + int64_t idx_{-1}; AnfNodePtr switch_{nullptr}, x_{nullptr}; FuncGraphPtr g1_{nullptr}, g2_{nullptr};