Merge pull request !5611 from Margaret_wangrui/get_looptags/v1.0.0
| @@ -279,6 +279,59 @@ std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() { | |||||
| return re_order; | return re_order; | ||||
| } | } | ||||
| void KernelGraph::GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto node_input_it = node_input_edges_.find(node); | |||||
| if (node_input_it == node_input_edges_.end()) { | |||||
| MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] don't have input edges."; | |||||
| return; | |||||
| } | |||||
| visited_nodes_.insert(node); | |||||
| for (auto input_edge : node_input_edges_[node]) { | |||||
| size_t input_num = node_input_num_[input_edge.first]; | |||||
| if (input_num == 0) { | |||||
| continue; | |||||
| } | |||||
| if (find(visited_nodes_.begin(), visited_nodes_.end(), input_edge.first) == visited_nodes_.end()) { | |||||
| MS_EXCEPTION_IF_NULL(input_edge.first); | |||||
| edge_to_[input_edge.first] = node; | |||||
| GetLoopNodesByDFS(input_edge.first, loop_num); | |||||
| } else { | |||||
| AnfNodePtr node_iter = node; | |||||
| MS_EXCEPTION_IF_NULL(node_iter); | |||||
| MS_LOG(DEBUG) << "Print loop nodes start:"; | |||||
| for (; node_iter != input_edge.first; node_iter = edge_to_[node_iter]) { | |||||
| MS_EXCEPTION_IF_NULL(node_iter); | |||||
| loop_nodes_.push(node_iter); | |||||
| node_input_num_[node_iter]--; | |||||
| MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); | |||||
| } | |||||
| loop_nodes_.push(node_iter); | |||||
| loop_nodes_.push(node); | |||||
| (*loop_num)++; | |||||
| node_input_num_[node_iter]--; | |||||
| MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString(); | |||||
| MS_LOG(DEBUG) << "Get loop node:" << node->DebugString(); | |||||
| MS_LOG(DEBUG) << "Print loop nodes end, Loop num:" << *loop_num; | |||||
| } | |||||
| } | |||||
| } | |||||
| uint32_t KernelGraph::GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes) { | |||||
| uint32_t loop_num = 0; | |||||
| for (auto iter = none_zero_nodes.begin(); iter != none_zero_nodes.end(); iter++) { | |||||
| auto node = iter->first; | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node_input_num_[node] == 0) { | |||||
| continue; | |||||
| } | |||||
| edge_to_.clear(); | |||||
| visited_nodes_.clear(); | |||||
| GetLoopNodesByDFS(node, &loop_num); | |||||
| } | |||||
| return loop_num; | |||||
| } | |||||
| void KernelGraph::CheckLoop() { | void KernelGraph::CheckLoop() { | ||||
| std::map<AnfNodePtr, size_t> none_zero_nodes; | std::map<AnfNodePtr, size_t> none_zero_nodes; | ||||
| if (node_input_edges_.size() != node_input_num_.size()) { | if (node_input_edges_.size() != node_input_num_.size()) { | ||||
| @@ -303,6 +356,7 @@ void KernelGraph::CheckLoop() { | |||||
| } | } | ||||
| // if don't consider control depend and loop exit,a exception will be throw | // if don't consider control depend and loop exit,a exception will be throw | ||||
| if (!none_zero_nodes.empty()) { | if (!none_zero_nodes.empty()) { | ||||
| MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes); | |||||
| MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); | MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include <queue> | #include <queue> | ||||
| #include <map> | #include <map> | ||||
| #include <set> | #include <set> | ||||
| #include <stack> | |||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| @@ -90,8 +91,6 @@ class KernelGraph : public FuncGraph { | |||||
| void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair); | void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair); | ||||
| // get map | // get map | ||||
| std::map<AnfWithOutIndex, AnfWithOutIndex> GetRefMap() const { return ref_out_in_map_; } | std::map<AnfWithOutIndex, AnfWithOutIndex> GetRefMap() const { return ref_out_in_map_; } | ||||
| // checkout whether loop exist in graph | |||||
| void CheckLoop(); | |||||
| // check whether graph is executable | // check whether graph is executable | ||||
| bool executable() const { return executable_; } | bool executable() const { return executable_; } | ||||
| // set executable of graph | // set executable of graph | ||||
| @@ -199,6 +198,10 @@ class KernelGraph : public FuncGraph { | |||||
| AnfNodePtr TransCNodeTuple(const CNodePtr &node); | AnfNodePtr TransCNodeTuple(const CNodePtr &node); | ||||
| AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx); | AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx); | ||||
| std::vector<CNodePtr> SortStartLabelAndEndGoto(); | std::vector<CNodePtr> SortStartLabelAndEndGoto(); | ||||
| // checkout whether loop exist in graph | |||||
| void CheckLoop(); | |||||
| uint32_t GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes); | |||||
| void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num); | |||||
| std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | ||||
| std::vector<AnfNodePtr> child_graph_result_; | std::vector<AnfNodePtr> child_graph_result_; | ||||
| @@ -243,6 +246,10 @@ class KernelGraph : public FuncGraph { | |||||
| std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_; | std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_; | ||||
| uint32_t current_epoch_; | uint32_t current_epoch_; | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_; | std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_; | ||||
| std::set<AnfNodePtr> visited_nodes_; | |||||
| std::map<AnfNodePtr, AnfNodePtr> edge_to_; | |||||
| std::stack<AnfNodePtr> loop_nodes_; | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | ||||