| @@ -39,6 +39,27 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include) { | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| for (auto &root : roots) { | |||||
| auto tmp = DeepLinkedGraphSearch(root, include); | |||||
| inputs.insert(inputs.end(), tmp.begin(), tmp.end()); | |||||
| } | |||||
| return inputs; | |||||
| } | |||||
| std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include, | |||||
| const FuncGraphManagerPtr &mng) { | |||||
| std::vector<AnfNodePtr> users; | |||||
| for (auto &root : roots) { | |||||
| auto tmp = DeepUsersSearch(root, include, mng); | |||||
| users.insert(users.end(), tmp.begin(), tmp.end()); | |||||
| } | |||||
| return users; | |||||
| } | |||||
| } // namespace | |||||
| bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) { | bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) { | ||||
| #if ENABLE_D | #if ENABLE_D | ||||
| std::vector<PrimitivePtr> basic_ops = { | std::vector<PrimitivePtr> basic_ops = { | ||||
| @@ -186,31 +207,33 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI | |||||
| } | } | ||||
| bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node, | bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node, | ||||
| std::set<AnfNodePtr> *cached_unconnected_set, AnfNodePtr *circle_node) { | |||||
| std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) { | |||||
| if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) { | if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| circle_nodes->clear(); | |||||
| std::set<AnfNodePtr> cached_done_set; | |||||
| auto cnode = check_node->cast<CNodePtr>(); | auto cnode = check_node->cast<CNodePtr>(); | ||||
| const auto &inputs = cnode->inputs(); | const auto &inputs = cnode->inputs(); | ||||
| // there is a input not in fused_op_set, but the input depends on the fused_op_set | // there is a input not in fused_op_set, but the input depends on the fused_op_set | ||||
| bool has_circle = false; | |||||
| for (auto input : inputs) { | for (auto input : inputs) { | ||||
| if (input->isa<CNode>() && !fused_op_set.count(input)) { | if (input->isa<CNode>() && !fused_op_set.count(input)) { | ||||
| bool has_circle = false; | |||||
| std::set<AnfNodePtr> done; | std::set<AnfNodePtr> done; | ||||
| std::vector<AnfNodePtr> todos = {input}; | std::vector<AnfNodePtr> todos = {input}; | ||||
| while (!todos.empty()) { | while (!todos.empty()) { | ||||
| auto node = todos.back(); | auto node = todos.back(); | ||||
| todos.pop_back(); | todos.pop_back(); | ||||
| if (done.count(node) || cached_unconnected_set->count(node)) { | |||||
| if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| done.insert(node); | done.insert(node); | ||||
| if (fused_op_set.count(node)) { | if (fused_op_set.count(node)) { | ||||
| has_circle = true; | has_circle = true; | ||||
| *circle_node = node; | |||||
| break; | |||||
| circle_nodes->push_back(node); | |||||
| continue; | |||||
| } | } | ||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| @@ -224,13 +247,15 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che | |||||
| } | } | ||||
| if (has_circle) { | if (has_circle) { | ||||
| return true; | |||||
| cached_done_set.insert(done.begin(), done.end()); | |||||
| } else { | |||||
| cached_unconnected_set->insert(done.begin(), done.end()); | |||||
| } | } | ||||
| cached_unconnected_set->insert(done.begin(), done.end()); | |||||
| done.clear(); | |||||
| } | } | ||||
| } | } | ||||
| return false; | |||||
| return !circle_nodes->empty(); | |||||
| } | } | ||||
| std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) { | std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) { | ||||
| @@ -242,17 +267,19 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo | |||||
| } | } | ||||
| return EXCLUDE; | return EXCLUDE; | ||||
| }; | }; | ||||
| std::vector<AnfNodePtr> circle_nodes; | |||||
| for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { | for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { | ||||
| AnfNodePtr circle_node; | |||||
| bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node); | |||||
| circle_nodes.clear(); | |||||
| bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes); | |||||
| // delete the circle node and the node which depend on the circle node in fused op | // delete the circle node and the node which depend on the circle node in fused op | ||||
| if (has_circle) { | if (has_circle) { | ||||
| auto mng = (*iter)->func_graph()->manager(); | auto mng = (*iter)->func_graph()->manager(); | ||||
| std::vector<AnfNodePtr> erase_nodes; | std::vector<AnfNodePtr> erase_nodes; | ||||
| if (is_backward) { | if (is_backward) { | ||||
| erase_nodes = DeepUsersSearch(circle_node, include, mng); | |||||
| erase_nodes = DeepUsersSearch(circle_nodes, include, mng); | |||||
| } else { | } else { | ||||
| erase_nodes = DeepLinkedGraphSearch(circle_node, include); | |||||
| erase_nodes = DeepLinkedGraphSearch(circle_nodes, include); | |||||
| } | } | ||||
| for (auto erase_node : erase_nodes) { | for (auto erase_node : erase_nodes) { | ||||
| fused_op_set.erase(erase_node); | fused_op_set.erase(erase_node); | ||||