| @@ -376,7 +376,10 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||||
| } | } | ||||
| auto tuple_getitem = node->cast<CNodePtr>(); | auto tuple_getitem = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | MS_EXCEPTION_IF_NULL(tuple_getitem); | ||||
| if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg)) { | |||||
| // If exist env_getitem/env_setitem in this funcgraph or | |||||
| // if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem; | |||||
| if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) && | |||||
| !ExistEnvNodeInTupleItem(g2_)) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto new_g1 = getitem_transform_(g1_, idx_); | auto new_g1 = getitem_transform_(g1_, idx_); | ||||
| @@ -455,6 +458,23 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||||
| }); | }); | ||||
| } | } | ||||
| static bool inline ExistEnvNodeInTupleItem(const FuncGraphPtr &fg) { | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| const auto &output = fg->output(); | |||||
| if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { | |||||
| return false; | |||||
| } | |||||
| const auto &cnode = output->cast<CNodePtr>(); | |||||
| const auto &inputs = cnode->inputs(); | |||||
| return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &node) { | |||||
| auto sub_fg = GetValueNode<FuncGraphPtr>(node); | |||||
| if (sub_fg != nullptr && ExistEnvNode(sub_fg)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| }); | |||||
| } | |||||
| int64_t idx_{-1}; | int64_t idx_{-1}; | ||||
| AnfNodePtr switch_{nullptr}, x_{nullptr}; | AnfNodePtr switch_{nullptr}, x_{nullptr}; | ||||
| FuncGraphPtr g1_{nullptr}, g2_{nullptr}; | FuncGraphPtr g1_{nullptr}, g2_{nullptr}; | ||||