|
|
|
@@ -22,6 +22,7 @@ |
|
|
|
#include <unordered_map> |
|
|
|
#include <unordered_set> |
|
|
|
#include <vector> |
|
|
|
#include <utility> |
|
|
|
|
|
|
|
#include "ir/func_graph.h" |
|
|
|
#include "ir/func_graph_cloner.h" |
|
|
|
@@ -372,7 +373,11 @@ class IncorporateGetitemSwitch : public AnfVisitor { |
|
|
|
if (g2_ == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto tuple_getitem = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
if (MultipleUseOfSwitch(tuple_getitem->input(1), fg)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto new_g1 = getitem_transform_(g1_, idx_); |
|
|
|
auto new_g2 = getitem_transform_(g2_, idx_); |
|
|
|
auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); |
|
|
|
@@ -423,6 +428,24 @@ class IncorporateGetitemSwitch : public AnfVisitor { |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
bool MultipleUseOfSwitch(const AnfNodePtr &switch_call, const FuncGraphPtr &fg) const { |
|
|
|
auto switch_call_cnode = switch_call->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(switch_call_cnode); |
|
|
|
auto manager = fg->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto node_users_map = manager->node_users(); |
|
|
|
auto it = node_users_map.find(switch_call); |
|
|
|
if (it == node_users_map.end()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto node_users = it->second; |
|
|
|
// If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.s |
|
|
|
auto tuple_getitem_num = std::count_if(node_users.begin(), node_users.end(), [](std::pair<AnfNodePtr, int> &user) { |
|
|
|
return IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem); |
|
|
|
}); |
|
|
|
return tuple_getitem_num > 1; |
|
|
|
} |
|
|
|
|
|
|
|
int idx_{-1}; |
|
|
|
AnfNodePtr switch_{nullptr}, x_{nullptr}; |
|
|
|
FuncGraphPtr g1_{nullptr}, g2_{nullptr}; |
|
|
|
|