|
|
@@ -16,6 +16,7 @@ |
|
|
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h" |
|
|
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h" |
|
|
#include <algorithm> |
|
|
#include <algorithm> |
|
|
#include <vector> |
|
|
#include <vector> |
|
|
|
|
|
#include <set> |
|
|
#include <string> |
|
|
#include <string> |
|
|
#include <unordered_set> |
|
|
#include <unordered_set> |
|
|
#include <utility> |
|
|
#include <utility> |
|
|
@@ -50,18 +51,24 @@ AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { |
|
|
void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { |
|
|
auto &users = mng->node_users(); |
|
|
|
|
|
AnfNodePtrList splitted_nodes; |
|
|
|
|
|
for (size_t i = 0; i < users[node].size(); ++i) { |
|
|
|
|
|
splitted_nodes.push_back(CloneCNode(node)); |
|
|
|
|
|
|
|
|
const auto &index_set = mng->node_users()[node]; |
|
|
|
|
|
std::map<AnfNodePtr, std::vector<int>> users_info; |
|
|
|
|
|
std::for_each(index_set.cbegin(), index_set.cend(), [&users_info](const std::pair<AnfNodePtr, int> &iter) { |
|
|
|
|
|
users_info[iter.first].push_back(iter.second); |
|
|
|
|
|
}); |
|
|
|
|
|
|
|
|
|
|
|
AnfNodePtrList split_nodes; |
|
|
|
|
|
for (size_t i = 0; i < users_info.size(); ++i) { |
|
|
|
|
|
split_nodes.push_back(CloneCNode(node)); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
const auto &index_set = users[node]; |
|
|
|
|
|
int i = 0; |
|
|
int i = 0; |
|
|
for (auto [user, index] : index_set) { |
|
|
|
|
|
|
|
|
for (auto [user, indices] : users_info) { |
|
|
auto user_node = user->cast<CNodePtr>(); |
|
|
auto user_node = user->cast<CNodePtr>(); |
|
|
MS_EXCEPTION_IF_NULL(user_node); |
|
|
MS_EXCEPTION_IF_NULL(user_node); |
|
|
user_node->set_input(index, splitted_nodes[i]); |
|
|
|
|
|
|
|
|
for (auto index : indices) { |
|
|
|
|
|
user_node->set_input(index, split_nodes[i]); |
|
|
|
|
|
} |
|
|
i++; |
|
|
i++; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -69,9 +76,11 @@ void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { |
|
|
|
|
|
|
|
|
bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { |
|
|
bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { |
|
|
auto &users = mng->node_users(); |
|
|
auto &users = mng->node_users(); |
|
|
return users[node].size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(), [&node](const PrimitivePtr &prim) { |
|
|
|
|
|
return IsPrimitiveCNode(node, prim); |
|
|
|
|
|
}); |
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> user_set; |
|
|
|
|
|
std::transform(users[node].cbegin(), users[node].cend(), std::inserter(user_set, user_set.end()), |
|
|
|
|
|
[](const std::pair<AnfNodePtr, int> &iter) { return iter.first; }); |
|
|
|
|
|
return user_set.size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(), |
|
|
|
|
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) { |
|
|
bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) { |
|
|
|