|
|
|
@@ -39,6 +39,27 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include) { |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
for (auto &root : roots) { |
|
|
|
auto tmp = DeepLinkedGraphSearch(root, include); |
|
|
|
inputs.insert(inputs.end(), tmp.begin(), tmp.end()); |
|
|
|
} |
|
|
|
return inputs; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include, |
|
|
|
const FuncGraphManagerPtr &mng) { |
|
|
|
std::vector<AnfNodePtr> users; |
|
|
|
for (auto &root : roots) { |
|
|
|
auto tmp = DeepUsersSearch(root, include, mng); |
|
|
|
users.insert(users.end(), tmp.begin(), tmp.end()); |
|
|
|
} |
|
|
|
return users; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) { |
|
|
|
#if ENABLE_D |
|
|
|
std::vector<PrimitivePtr> basic_ops = { |
|
|
|
@@ -186,31 +207,33 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI |
|
|
|
} |
|
|
|
|
|
|
|
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node, |
|
|
|
std::set<AnfNodePtr> *cached_unconnected_set, AnfNodePtr *circle_node) { |
|
|
|
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) { |
|
|
|
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
circle_nodes->clear(); |
|
|
|
|
|
|
|
std::set<AnfNodePtr> cached_done_set; |
|
|
|
auto cnode = check_node->cast<CNodePtr>(); |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
// there is a input not in fused_op_set, but the input depends on the fused_op_set |
|
|
|
bool has_circle = false; |
|
|
|
for (auto input : inputs) { |
|
|
|
if (input->isa<CNode>() && !fused_op_set.count(input)) { |
|
|
|
bool has_circle = false; |
|
|
|
std::set<AnfNodePtr> done; |
|
|
|
std::vector<AnfNodePtr> todos = {input}; |
|
|
|
while (!todos.empty()) { |
|
|
|
auto node = todos.back(); |
|
|
|
todos.pop_back(); |
|
|
|
if (done.count(node) || cached_unconnected_set->count(node)) { |
|
|
|
if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
done.insert(node); |
|
|
|
if (fused_op_set.count(node)) { |
|
|
|
has_circle = true; |
|
|
|
*circle_node = node; |
|
|
|
break; |
|
|
|
circle_nodes->push_back(node); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (node->isa<CNode>()) { |
|
|
|
@@ -224,13 +247,15 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che |
|
|
|
} |
|
|
|
|
|
|
|
if (has_circle) { |
|
|
|
return true; |
|
|
|
cached_done_set.insert(done.begin(), done.end()); |
|
|
|
} else { |
|
|
|
cached_unconnected_set->insert(done.begin(), done.end()); |
|
|
|
} |
|
|
|
cached_unconnected_set->insert(done.begin(), done.end()); |
|
|
|
done.clear(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
return !circle_nodes->empty(); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) { |
|
|
|
@@ -242,17 +267,19 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo |
|
|
|
} |
|
|
|
return EXCLUDE; |
|
|
|
}; |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> circle_nodes; |
|
|
|
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { |
|
|
|
AnfNodePtr circle_node; |
|
|
|
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node); |
|
|
|
circle_nodes.clear(); |
|
|
|
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes); |
|
|
|
// delete the circle node and the node which depend on the circle node in fused op |
|
|
|
if (has_circle) { |
|
|
|
auto mng = (*iter)->func_graph()->manager(); |
|
|
|
std::vector<AnfNodePtr> erase_nodes; |
|
|
|
if (is_backward) { |
|
|
|
erase_nodes = DeepUsersSearch(circle_node, include, mng); |
|
|
|
erase_nodes = DeepUsersSearch(circle_nodes, include, mng); |
|
|
|
} else { |
|
|
|
erase_nodes = DeepLinkedGraphSearch(circle_node, include); |
|
|
|
erase_nodes = DeepLinkedGraphSearch(circle_nodes, include); |
|
|
|
} |
|
|
|
for (auto erase_node : erase_nodes) { |
|
|
|
fused_op_set.erase(erase_node); |
|
|
|
|