| @@ -106,7 +106,11 @@ void BuildGraphTask::Run() { | |||
| void RunGraphTask::Run() { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| try { | |||
| auto graph = session_->GetGraph(graph_id_); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| graph->ResetGraphRunningStatus(); | |||
| session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); | |||
| graph->OnRunGraphFinished(); | |||
| UpdateOutputTensors(&outputs_, tensor_to_node_); | |||
| } catch (const std::exception &e) { | |||
| MsException::GetInstance().SetException(); | |||
| @@ -205,6 +209,7 @@ void Executor::OnRunGraphFinished() { | |||
| if (new_ready_tasks.size() > 0) { | |||
| task_cond_var_.notify_all(); | |||
| } | |||
| reenter_cond_var_.notify_all(); | |||
| } | |||
| bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) { | |||
| @@ -215,6 +220,12 @@ bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) { | |||
| return false; | |||
| } | |||
| } | |||
| auto session = task->session_; | |||
| MS_EXCEPTION_IF_NULL(session); | |||
| auto graph = session->GetGraph(task->graph_id_); | |||
| if (graph != nullptr) { | |||
| return graph->IsPreGraphFinished(); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -300,6 +311,14 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, | |||
| SyncRunTask(task); | |||
| return; | |||
| } | |||
| auto graph = session->GetGraph(task->graph_id_); | |||
| if (graph != nullptr) { | |||
| if (!graph->IsPostGraphFinished()) { | |||
| mindspore::ScopedLongRunning long_running; | |||
| std::unique_lock<std::mutex> lock(reenter_mutex_); | |||
| reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); }); | |||
| } | |||
| } | |||
| bool ready = IsTaskReady(task); | |||
| if (!ready) { | |||
| @@ -179,8 +179,10 @@ class Executor { | |||
| std::string device_name_; | |||
| std::mutex task_mutex_; | |||
| std::mutex pending_task_mutex_; | |||
| std::mutex reenter_mutex_; | |||
| std::condition_variable task_cond_var_; | |||
| std::condition_variable sync_cond_var_; | |||
| std::condition_variable reenter_cond_var_; | |||
| std::queue<std::shared_ptr<Task>> ready_tasks_; | |||
| std::list<std::shared_ptr<RunGraphTask>> pending_tasks_; | |||
| std::vector<std::shared_ptr<Task>> done_tasks_; | |||
| @@ -17,15 +17,16 @@ | |||
| #define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <string> | |||
| #include <queue> | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include <set> | |||
| #include <stack> | |||
| #include <unordered_set> | |||
| #include <stack> | |||
| #include <atomic> | |||
| #include "ir/func_graph.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/graph_utils.h" | |||
| @@ -50,6 +51,51 @@ class KernelGraph : public FuncGraph { | |||
| summary_node_exist_ = false; | |||
| stream_distinction_label_ = kInvalidDistincLabel; | |||
| } | |||
| KernelGraph(const KernelGraph &graph) : FuncGraph(graph) { | |||
| inputs_ = graph.inputs_; | |||
| child_graph_result_ = graph.child_graph_result_; | |||
| execution_order_ = graph.execution_order_; | |||
| graph_id_ = graph.graph_id_; | |||
| stream_distinction_label_ = graph.stream_distinction_label_; | |||
| front_backend_anf_map_ = graph.front_backend_anf_map_; | |||
| backend_front_anf_map_ = graph.backend_front_anf_map_; | |||
| tensor_to_value_node_map_ = graph.tensor_to_value_node_map_; | |||
| graph_value_nodes_ = graph.graph_value_nodes_; | |||
| node_input_num_ = graph.node_input_num_; | |||
| node_input_edges_ = graph.node_input_edges_; | |||
| ref_out_in_map_ = graph.ref_out_in_map_; | |||
| node_output_edges_ = graph.node_output_edges_; | |||
| summary_nodes_ = graph.summary_nodes_; | |||
| executable_ = graph.executable_; | |||
| summary_node_exist_ = graph.summary_node_exist_; | |||
| valid_inputs_ = graph.valid_inputs_; | |||
| child_graph_order_ = graph.child_graph_order_; | |||
| input_ctrl_tensors_ = graph.input_ctrl_tensors_; | |||
| parent_graph_ = graph.parent_graph_; | |||
| start_label_ = graph.start_label_; | |||
| end_goto_ = graph.end_goto_; | |||
| null_output_ = graph.null_output_; | |||
| front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_; | |||
| internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_; | |||
| internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_; | |||
| current_epoch_ = graph.current_epoch_; | |||
| tuple_parameter_to_make_tuple_map_ = graph.tuple_parameter_to_make_tuple_map_; | |||
| visited_nodes_ = graph.visited_nodes_; | |||
| edge_to_ = graph.edge_to_; | |||
| loop_nodes_ = graph.loop_nodes_; | |||
| input_nodes_ = graph.input_nodes_; | |||
| pre_graphs_ = graph.pre_graphs_; | |||
| post_graphs_ = graph.post_graphs_; | |||
| size_t pre_graph_finished_count = graph.pre_graph_finished_count_; | |||
| pre_graph_finished_count_ = pre_graph_finished_count; | |||
| size_t post_graph_finished_count = graph.post_graph_finished_count_; | |||
| post_graph_finished_count_ = post_graph_finished_count; | |||
| first_step_ = graph.first_step_; | |||
| has_optimizer_ = graph.has_optimizer_; | |||
| is_dynamic_shape_ = graph.is_dynamic_shape_; | |||
| } | |||
| ~KernelGraph() override; | |||
| MS_DECLARE_PARENT(KernelGraph, FuncGraph); | |||
| @@ -189,6 +235,47 @@ class KernelGraph : public FuncGraph { | |||
| void SetInputNodes(); | |||
| const std::vector<AnfNodePtr> &input_nodes() const { return input_nodes_; } | |||
| bool has_optimizer() const { return has_optimizer_; } | |||
| // handle graph dependency | |||
| void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) { | |||
| if (graph != nullptr) { | |||
| pre_graphs_[graph->graph_id()] = graph; | |||
| } | |||
| } | |||
| void AddPostGraph(const std::shared_ptr<session::KernelGraph> &graph) { | |||
| if (graph != nullptr) { | |||
| post_graphs_[graph->graph_id()] = graph; | |||
| } | |||
| } | |||
| bool IsPreGraphFinished() { return pre_graphs_.size() == pre_graph_finished_count_; } | |||
| bool IsPostGraphFinished() { | |||
| if (first_step_) { | |||
| return true; | |||
| } | |||
| return post_graphs_.size() == post_graph_finished_count_; | |||
| } | |||
| void IncPreGraphFinishedCount() { pre_graph_finished_count_++; } | |||
| void IncPostGraphFinishedCount() { post_graph_finished_count_++; } | |||
| void ResetGraphRunningStatus() { | |||
| first_step_ = false; | |||
| post_graph_finished_count_ = 0; | |||
| pre_graph_finished_count_ = 0; | |||
| } | |||
| void OnRunGraphFinished() { | |||
| for (auto post_graph : post_graphs_) { | |||
| auto post_graph_ptr = post_graph.second.lock(); | |||
| if (post_graph_ptr != nullptr) { | |||
| post_graph_ptr->IncPreGraphFinishedCount(); | |||
| } | |||
| } | |||
| for (auto pre_graph : pre_graphs_) { | |||
| auto pre_graph_ptr = pre_graph.second.lock(); | |||
| if (pre_graph_ptr != nullptr) { | |||
| pre_graph_ptr->IncPostGraphFinishedCount(); | |||
| } | |||
| } | |||
| } | |||
| // end of handle graph dependency | |||
| private: | |||
| // remove value node form graph | |||
| @@ -218,6 +305,7 @@ class KernelGraph : public FuncGraph { | |||
| uint32_t GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes); | |||
| void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num); | |||
| // members | |||
| std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | |||
| std::vector<AnfNodePtr> child_graph_result_; | |||
| std::vector<CNodePtr> execution_order_; | |||
| @@ -265,6 +353,11 @@ class KernelGraph : public FuncGraph { | |||
| std::map<AnfNodePtr, AnfNodePtr> edge_to_; | |||
| std::stack<AnfNodePtr> loop_nodes_; | |||
| std::vector<AnfNodePtr> input_nodes_; | |||
| std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_; | |||
| std::unordered_map<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_; | |||
| std::atomic<size_t> pre_graph_finished_count_{0}; | |||
| std::atomic<size_t> post_graph_finished_count_{0}; | |||
| bool first_step_{true}; | |||
| bool has_optimizer_{false}; | |||
| bool is_dynamic_shape_{false}; | |||
| }; | |||
| @@ -344,7 +344,7 @@ void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id | |||
| KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const { | |||
| auto it = graphs_.find(graph_id); | |||
| if (it == graphs_.end()) { | |||
| MS_LOG(WARNING) << "Can't find graph " << graph_id; | |||
| MS_LOG(INFO) << "Can't find graph " << graph_id; | |||
| return nullptr; | |||
| } | |||
| return it->second; | |||
| @@ -57,11 +57,25 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std: | |||
| result.outputs = outputs; | |||
| result.graph_id = kInvalidGraphId; | |||
| GraphId graph_id = kInvalidGraphId; | |||
| auto current_session = target_sess_; | |||
| if (target != target_device_ && !target.empty()) { | |||
| CreateOtherSession(target); | |||
| graph_id = other_sess_->CompileGraph(segment, outputs); | |||
| } else { | |||
| graph_id = target_sess_->CompileGraph(segment, outputs); | |||
| current_session = other_sess_; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(current_session); | |||
| graph_id = current_session->CompileGraph(segment, outputs); | |||
| segment->graph_id_ = graph_id; | |||
| auto graph = current_session->GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| for (auto &pre_segment : segment->pre_segments_) { | |||
| MS_EXCEPTION_IF_NULL(pre_segment); | |||
| auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_); | |||
| if (pre_graph == nullptr) { | |||
| pre_graph = other_sess_->GetGraph(pre_segment->graph_id_); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(pre_graph); | |||
| pre_graph->AddPostGraph(graph); | |||
| graph->AddPreGraph(pre_graph); | |||
| } | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) { | |||
| @@ -246,6 +246,55 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & | |||
| return result; | |||
| } | |||
| void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_target, | |||
| const std::map<AnfNodePtr, GraphSegmentPtr> &node_to_segment) { | |||
| std::stack<AnfNodePtr> to_visit; | |||
| std::map<AnfNodePtr, size_t> nodes_ref; | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges; | |||
| CalcNodeRefCount(graph, &nodes_ref, &control_edges); | |||
| to_visit.push(graph->get_return()); | |||
| while (!to_visit.empty()) { | |||
| auto &node = to_visit.top(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| to_visit.pop(); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto node_inputs = cnode->inputs(); | |||
| auto ctrl_inputs = control_edges.find(node); | |||
| if (ctrl_inputs != control_edges.end()) { | |||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | |||
| } | |||
| GraphSegmentPtr node_segment{nullptr}; | |||
| auto node_iter = node_to_segment.find(node); | |||
| if (node_iter != node_to_segment.end()) { | |||
| node_segment = node_iter->second; | |||
| } | |||
| for (auto &input : node_inputs) { | |||
| if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) { | |||
| GraphSegmentPtr input_segment{nullptr}; | |||
| auto input_iter = node_to_segment.find(input); | |||
| if (input_iter != node_to_segment.end()) { | |||
| input_segment = input_iter->second; | |||
| } | |||
| if (input_segment != nullptr && input_segment != node_segment && !input_segment->is_cut_) { | |||
| node_segment->AddPreSegment(input_segment); | |||
| } | |||
| } | |||
| auto ref_iter = nodes_ref.find(input); | |||
| if (ref_iter != nodes_ref.end()) { | |||
| ref_iter->second--; | |||
| if (ref_iter->second != 0) { | |||
| continue; | |||
| } | |||
| } | |||
| to_visit.push(input); | |||
| } | |||
| } | |||
| } | |||
| std::vector<AnfNodePtr> ParallelSplitSort(const FuncGraphPtr &graph, const std::string &default_target) { | |||
| std::vector<AnfNodePtr> result; | |||
| std::stack<AnfNodePtr> handle_nodes; | |||
| @@ -404,10 +453,10 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||
| bool contain_multi_target = ContainMultiTarget(nodes); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (contain_multi_target) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (graph != nullptr) { | |||
| nodes = SplitSort(graph, default_target); | |||
| } else { | |||
| @@ -417,15 +466,22 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph | |||
| } | |||
| std::vector<GraphSegmentPtr> segments; | |||
| std::vector<AnfNodePtr> segment_nodes; | |||
| std::map<AnfNodePtr, GraphSegmentPtr> node_to_segment; | |||
| auto new_segment = [&segments, &segment_nodes, &node_to_segment]() { | |||
| if (segment_nodes.size() != 0) { | |||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, false); | |||
| segments.emplace_back(segment); | |||
| for (auto node : segment_nodes) { | |||
| node_to_segment[node] = segment; | |||
| } | |||
| segment_nodes.clear(); | |||
| } | |||
| }; | |||
| std::string last_target; | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsCut(node)) { | |||
| if (segment_nodes.size() != 0) { | |||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, false); | |||
| segments.emplace_back(segment); | |||
| segment_nodes.clear(); | |||
| } | |||
| new_segment(); | |||
| segment_nodes.emplace_back(node); | |||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, true); | |||
| segments.push_back(segment); | |||
| @@ -433,10 +489,8 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph | |||
| } else if (node->isa<CNode>()) { | |||
| if (contain_multi_target) { | |||
| std::string cur_target = GetCNodeTarget(node); | |||
| if (cur_target != last_target && !last_target.empty() && segment_nodes.size() != 0) { | |||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, false); | |||
| segments.emplace_back(segment); | |||
| segment_nodes.clear(); | |||
| if (cur_target != last_target && !last_target.empty()) { | |||
| new_segment(); | |||
| } | |||
| last_target = cur_target; | |||
| } | |||
| @@ -444,6 +498,9 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Segment size:" << segments.size(); | |||
| if (contain_multi_target) { | |||
| AddSegmentDependency(graph, default_target, node_to_segment); | |||
| } | |||
| return segments; | |||
| } | |||
| } // namespace compile | |||
| @@ -25,6 +25,7 @@ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <set> | |||
| #include "base/base.h" | |||
| #include "base/user_data.h" | |||
| @@ -485,8 +486,11 @@ std::string GetCNodeTarget(const AnfNodePtr &node); | |||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes); | |||
| struct GraphSegment { | |||
| GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {} | |||
| void AddPreSegment(const std::shared_ptr<GraphSegment> &segment) { (void)pre_segments_.insert(segment); } | |||
| std::vector<AnfNodePtr> nodes_; | |||
| std::set<std::shared_ptr<GraphSegment>> pre_segments_; | |||
| bool is_cut_{false}; | |||
| uint32_t graph_id_{0}; | |||
| }; | |||
| using GraphSegmentPtr = std::shared_ptr<GraphSegment>; | |||
| } // namespace mindspore | |||