| @@ -39,6 +39,27 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include) { | |||
| std::vector<AnfNodePtr> inputs; | |||
| for (auto &root : roots) { | |||
| auto tmp = DeepLinkedGraphSearch(root, include); | |||
| inputs.insert(inputs.end(), tmp.begin(), tmp.end()); | |||
| } | |||
| return inputs; | |||
| } | |||
| std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include, | |||
| const FuncGraphManagerPtr &mng) { | |||
| std::vector<AnfNodePtr> users; | |||
| for (auto &root : roots) { | |||
| auto tmp = DeepUsersSearch(root, include, mng); | |||
| users.insert(users.end(), tmp.begin(), tmp.end()); | |||
| } | |||
| return users; | |||
| } | |||
| } // namespace | |||
| bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) { | |||
| #if ENABLE_D | |||
| std::vector<PrimitivePtr> basic_ops = { | |||
| @@ -186,31 +207,33 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI | |||
| } | |||
| bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node, | |||
| std::set<AnfNodePtr> *cached_unconnected_set, AnfNodePtr *circle_node) { | |||
| std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) { | |||
| if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) { | |||
| return false; | |||
| } | |||
| circle_nodes->clear(); | |||
| std::set<AnfNodePtr> cached_done_set; | |||
| auto cnode = check_node->cast<CNodePtr>(); | |||
| const auto &inputs = cnode->inputs(); | |||
| // there is a input not in fused_op_set, but the input depends on the fused_op_set | |||
| bool has_circle = false; | |||
| for (auto input : inputs) { | |||
| if (input->isa<CNode>() && !fused_op_set.count(input)) { | |||
| bool has_circle = false; | |||
| std::set<AnfNodePtr> done; | |||
| std::vector<AnfNodePtr> todos = {input}; | |||
| while (!todos.empty()) { | |||
| auto node = todos.back(); | |||
| todos.pop_back(); | |||
| if (done.count(node) || cached_unconnected_set->count(node)) { | |||
| if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) { | |||
| continue; | |||
| } | |||
| done.insert(node); | |||
| if (fused_op_set.count(node)) { | |||
| has_circle = true; | |||
| *circle_node = node; | |||
| break; | |||
| circle_nodes->push_back(node); | |||
| continue; | |||
| } | |||
| if (node->isa<CNode>()) { | |||
| @@ -224,13 +247,15 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che | |||
| } | |||
| if (has_circle) { | |||
| return true; | |||
| cached_done_set.insert(done.begin(), done.end()); | |||
| } else { | |||
| cached_unconnected_set->insert(done.begin(), done.end()); | |||
| } | |||
| cached_unconnected_set->insert(done.begin(), done.end()); | |||
| done.clear(); | |||
| } | |||
| } | |||
| return false; | |||
| return !circle_nodes->empty(); | |||
| } | |||
| std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) { | |||
| @@ -242,17 +267,19 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo | |||
| } | |||
| return EXCLUDE; | |||
| }; | |||
| std::vector<AnfNodePtr> circle_nodes; | |||
| for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { | |||
| AnfNodePtr circle_node; | |||
| bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node); | |||
| circle_nodes.clear(); | |||
| bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes); | |||
| // delete the circle node and the node which depend on the circle node in fused op | |||
| if (has_circle) { | |||
| auto mng = (*iter)->func_graph()->manager(); | |||
| std::vector<AnfNodePtr> erase_nodes; | |||
| if (is_backward) { | |||
| erase_nodes = DeepUsersSearch(circle_node, include, mng); | |||
| erase_nodes = DeepUsersSearch(circle_nodes, include, mng); | |||
| } else { | |||
| erase_nodes = DeepLinkedGraphSearch(circle_node, include); | |||
| erase_nodes = DeepLinkedGraphSearch(circle_nodes, include); | |||
| } | |||
| for (auto erase_node : erase_nodes) { | |||
| fused_op_set.erase(erase_node); | |||