From 5d0b3597ff2454999dfc830d08a6a8a475d45502 Mon Sep 17 00:00:00 2001 From: chenfei Date: Fri, 30 Oct 2020 19:40:23 +0800 Subject: [PATCH] incorporate switch pass should handle mutiple getitem --- .../optimizer/irpass/incorporate_getitem.h | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h index a59a569a74..ee62593d62 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" @@ -372,7 +373,11 @@ class IncorporateGetitemSwitch : public AnfVisitor { if (g2_ == nullptr) { return nullptr; } - + auto tuple_getitem = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + if (MultipleUseOfSwitch(tuple_getitem->input(1), fg)) { + return nullptr; + } auto new_g1 = getitem_transform_(g1_, idx_); auto new_g2 = getitem_transform_(g2_, idx_); auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); @@ -423,6 +428,24 @@ class IncorporateGetitemSwitch : public AnfVisitor { } private: + bool MultipleUseOfSwitch(const AnfNodePtr &switch_call, const FuncGraphPtr &fg) const { + auto switch_call_cnode = switch_call->cast(); + MS_EXCEPTION_IF_NULL(switch_call_cnode); + auto manager = fg->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto node_users_map = manager->node_users(); + auto it = node_users_map.find(switch_call); + if (it == node_users_map.end()) { + return false; + } + auto node_users = it->second; + // If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.s + auto tuple_getitem_num = std::count_if(node_users.begin(), node_users.end(), [](std::pair &user) { + return IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem); + }); + return tuple_getitem_num > 1; + } + int idx_{-1}; AnfNodePtr switch_{nullptr}, x_{nullptr}; FuncGraphPtr g1_{nullptr}, g2_{nullptr};