|
|
@@ -129,6 +129,62 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool IsGetItemNode(const AnfNodePtr &node) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
|
|
if (node->isa<CNode>()) { |
|
|
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
|
|
if (inputs.empty()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; |
|
|
|
|
|
} |
|
|
|
|
|
if (!IsValueNode<Primitive>(inputs[0])) { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]); |
|
|
|
|
|
return node_prim->name() == prim::kPrimTupleGetItem->name(); |
|
|
|
|
|
} |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) { |
|
|
|
|
|
std::vector<AnfNodePtr> result; |
|
|
|
|
|
std::map<size_t, std::vector<AnfNodePtr>> insert_positions; |
|
|
|
|
|
std::map<AnfNodePtr, size_t> node_positions; |
|
|
|
|
|
for (auto &node : nodes) { |
|
|
|
|
|
if (IsGetItemNode(node)) { |
|
|
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
|
|
if (inputs.size() < 2) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid get item node"; |
|
|
|
|
|
} |
|
|
|
|
|
auto &parent = inputs[1]; |
|
|
|
|
|
auto iter = node_positions.find(parent); |
|
|
|
|
|
if (iter != node_positions.end()) { |
|
|
|
|
|
size_t position = iter->second; |
|
|
|
|
|
auto iter_nodes = insert_positions.find(position); |
|
|
|
|
|
if (iter_nodes != insert_positions.end()) { |
|
|
|
|
|
iter_nodes->second.push_back(node); |
|
|
|
|
|
} else { |
|
|
|
|
|
(void)insert_positions.insert( |
|
|
|
|
|
std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node})); |
|
|
|
|
|
} |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
result.emplace_back(node); |
|
|
|
|
|
node_positions[node] = result.size(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
size_t insert_num = 0; |
|
|
|
|
|
for (auto &item : insert_positions) { |
|
|
|
|
|
size_t position = item.first + insert_num; |
|
|
|
|
|
(void)result.insert(result.begin() + position, item.second.begin(), item.second.end()); |
|
|
|
|
|
insert_num += item.second.size(); |
|
|
|
|
|
} |
|
|
|
|
|
return result; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { |
|
|
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { |
|
|
std::vector<AnfNodePtr> result; |
|
|
std::vector<AnfNodePtr> result; |
|
|
std::stack<AnfNodePtr> to_visit; |
|
|
std::stack<AnfNodePtr> to_visit; |
|
|
@@ -144,8 +200,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & |
|
|
handle_target = next_target; |
|
|
handle_target = next_target; |
|
|
} |
|
|
} |
|
|
auto &node = to_visit.top(); |
|
|
auto &node = to_visit.top(); |
|
|
to_visit.pop(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
|
|
to_visit.pop(); |
|
|
result.emplace_back(node); |
|
|
result.emplace_back(node); |
|
|
if (!node->isa<CNode>()) { |
|
|
if (!node->isa<CNode>()) { |
|
|
continue; |
|
|
continue; |
|
|
@@ -178,7 +234,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
std::reverse(result.begin(), result.end()); |
|
|
std::reverse(result.begin(), result.end()); |
|
|
return result; |
|
|
|
|
|
|
|
|
return ReorderGetItemNode(result); |
|
|
} |
|
|
} |
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
|
|