From ee1b15f0cddd50fcdb1fb384f536845fd9a46d65 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Mon, 31 Aug 2020 16:31:43 +0800 Subject: [PATCH] get the set of points which forming loop --- .../ccsrc/backend/session/kernel_graph.cc | 54 +++++++++++++++++++ .../ccsrc/backend/session/kernel_graph.h | 11 +++- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 9cba4ba6de..3093753091 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -279,6 +279,59 @@ std::vector KernelGraph::SortStartLabelAndEndGoto() { return re_order; } +void KernelGraph::GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num) { + MS_EXCEPTION_IF_NULL(node); + auto node_input_it = node_input_edges_.find(node); + if (node_input_it == node_input_edges_.end()) { + MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] don't have input edges."; + return; + } + visited_nodes_.insert(node); + for (auto input_edge : node_input_edges_[node]) { + size_t input_num = node_input_num_[input_edge.first]; + if (input_num == 0) { + continue; + } + if (find(visited_nodes_.begin(), visited_nodes_.end(), input_edge.first) == visited_nodes_.end()) { + MS_EXCEPTION_IF_NULL(input_edge.first); + edge_to_[input_edge.first] = node; + GetLoopNodesByDFS(input_edge.first, loop_num); + } else { + AnfNodePtr node_iter = node; + MS_EXCEPTION_IF_NULL(node_iter); + MS_LOG(DEBUG) << "Print loop nodes start:"; + for (; node_iter != input_edge.first; node_iter = edge_to_[node_iter]) { + MS_EXCEPTION_IF_NULL(node_iter); + loop_nodes_.push(node_iter); + node_input_num_[node_iter]--; + MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); + } + loop_nodes_.push(node_iter); + loop_nodes_.push(node); + (*loop_num)++; + node_input_num_[node_iter]--; + MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); + MS_LOG(DEBUG) << "Get loop node:" << node->DebugString(); + MS_LOG(DEBUG) << "Print loop nodes end, Loop num:" << *loop_num; + } + } +} + +uint32_t KernelGraph::GetLoopNum(std::map none_zero_nodes) { + uint32_t loop_num = 0; + for (auto iter = none_zero_nodes.begin(); iter != none_zero_nodes.end(); iter++) { + auto node = iter->first; + MS_EXCEPTION_IF_NULL(node); + if (node_input_num_[node] == 0) { + continue; + } + edge_to_.clear(); + visited_nodes_.clear(); + GetLoopNodesByDFS(node, &loop_num); + } + return loop_num; +} + void KernelGraph::CheckLoop() { std::map none_zero_nodes; if (node_input_edges_.size() != node_input_num_.size()) { @@ -303,6 +356,7 @@ void KernelGraph::CheckLoop() { } // if don't consider control depend and loop exit,a exception will be throw if (!none_zero_nodes.empty()) { + MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes); MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); } } diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 46e38b4ac6..f3bfb78ef6 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include "ir/func_graph.h" #include "ir/anf.h" @@ -90,8 +91,6 @@ class KernelGraph : public FuncGraph { void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair); // get map std::map GetRefMap() const { return ref_out_in_map_; } - // checkout whether loop exist in graph - void CheckLoop(); // check whether graph is executable bool executable() const { return executable_; } // set executable of graph @@ -199,6 +198,10 @@ class KernelGraph : public FuncGraph { AnfNodePtr TransCNodeTuple(const CNodePtr &node); AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx); std::vector SortStartLabelAndEndGoto(); + // checkout whether loop exist in graph + void CheckLoop(); + uint32_t GetLoopNum(std::map none_zero_nodes); + void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num); std::shared_ptr> inputs_; std::vector child_graph_result_; @@ -243,6 +246,10 @@ class KernelGraph : public FuncGraph { std::unordered_map> internal_outputs_tensor_map_; uint32_t current_epoch_; std::unordered_map tuple_parameter_to_make_tuple_map_; + + std::set visited_nodes_; + std::map edge_to_; + std::stack loop_nodes_; }; } // namespace session using KernelGraphPtr = std::shared_ptr;