diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index e8af57d764..ee4d67d5af 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -65,8 +65,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri result.outputs = outputs; result.graph_id = kInvalidGraphId; GraphId graph_id = kInvalidGraphId; - if (target == kCPUDevice) { - graph_id = cpu_sess_->CompileGraph(lst, outputs); + if (target != target_device_ && target != "") { + CreateOtherSession(target); + graph_id = other_sess_->CompileGraph(lst, outputs); } else { graph_id = target_sess_->CompileGraph(lst, outputs); } @@ -75,8 +76,8 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri MS_LOG(INFO) << "PrecompileOnly, stop run graph"; return result; } - if (target == kCPUDevice) { - cpu_sess_->BuildGraph(graph_id); + if (target != target_device_ && target != "") { + other_sess_->BuildGraph(graph_id); } else if (!is_multi_graph_sink_) { target_sess_->BuildGraph(graph_id); } @@ -278,8 +279,8 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s VectorRef outputs; // call ms rungraph (graphId, input ,output) - if (target == kCPUDevice) { - cpu_sess_->RunGraph(g, inputs, &outputs); + if (target != target_device_ && target != "") { + other_sess_->RunGraph(g, inputs, &outputs); } else { target_sess_->RunGraph(g, inputs, &outputs); } @@ -341,16 +342,20 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_ } target_sess_->Init(device_id); target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); - if (target == kCPUDevice) { - cpu_sess_ = target_sess_; - } else { - cpu_sess_ = session::SessionFactory::Get().Create(kCPUDevice); - if (cpu_sess_ == nullptr) { - MS_LOG(EXCEPTION) << "Create cpu session failed with target " << target << "."; - } - cpu_sess_->Init(0); - cpu_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); + target_device_ = target; +} + +void MsBackend::CreateOtherSession(const std::string &target) { + if (other_sess_ != nullptr && other_device_ == target) { + return; + } + other_sess_ = session::SessionFactory::Get().Create(kCPUDevice); + if (other_sess_ == nullptr) { + MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available."; } + other_sess_->Init(0); + other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); + other_device_ = target; } GraphId MsBackend::CompileGraph(NotNull fg) { return target_sess_->CompileGraph(fg); } diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index 1ff345a43b..0e0b02c055 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -107,10 +107,13 @@ class MsBackend : public Backend { LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override; GraphId CompileGraph(NotNull fg) override; VectorRef RunGraph(GraphId graph_id, const VectorRef &args); + void CreateOtherSession(const std::string &target); private: session::SessionPtr target_sess_; - session::SessionPtr cpu_sess_; + session::SessionPtr other_sess_; + std::string target_device_; + std::string other_device_; std::unordered_map simu_cond_map_; std::unordered_map graph_id_map_; std::unordered_map>, BaseRefHash> graph_inputs_; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 90efc0ac5f..732107beb4 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -75,7 +76,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { return default_target; } auto primitive = value->cast(); - ValuePtr att_target = primitive->GetAttr("target"); + ValuePtr att_target = primitive->GetAttr("primitive_target"); if (att_target != nullptr) { std::string target = GetValue(att_target); return target; @@ -127,6 +128,58 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *n } } } + +std::vector SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { + std::vector result; + std::stack to_visit; + std::stack next_to_visit; + std::map nodes_ref; + CalcNodeRefCount(graph, &nodes_ref); + std::string handle_target = default_target; + std::string next_target = ""; + to_visit.push(graph->get_return()); + while (!to_visit.empty() || !next_to_visit.empty()) { + if (to_visit.empty()) { + to_visit.swap(next_to_visit); + handle_target = next_target; + } + auto &node = to_visit.top(); + to_visit.pop(); + MS_EXCEPTION_IF_NULL(node); + result.emplace_back(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto node_inputs = cnode->inputs(); + std::reverse(node_inputs.begin(), node_inputs.end()); + for (auto &input : node_inputs) { + auto iter = nodes_ref.find(input); + if (iter != nodes_ref.end()) { + iter->second--; + if (iter->second != 0) { + continue; + } + } + if (!input->isa()) { + to_visit.push(input); + continue; + } + std::string input_target = GetCNodeTarget(input); + if (input_target == handle_target) { + to_visit.push(input); + } else if (next_to_visit.empty() || input_target == next_target) { + next_to_visit.push(input); + next_target = input_target; + } else { + MS_LOG(EXCEPTION) << "only support two different target"; + } + } + } + std::reverse(result.begin(), result.end()); + return result; +} } // namespace CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector &cut_list) @@ -180,65 +233,16 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { return false; } -std::vector CompileGraph::SplitSort(const FuncGraphPtr &graph) { - std::vector result; - std::queue queue; - std::queue next_queue; - std::map nodes_ref; - CalcNodeRefCount(graph, &nodes_ref); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - std::string queue_target = context_ptr->device_target(); - std::string next_target = ""; - queue.push(graph->get_return()); - while (!queue.empty() || !next_queue.empty()) { - if (queue.empty()) { - queue.swap(next_queue); - queue_target = next_target; - } - auto &node = queue.front(); - queue.pop(); - MS_EXCEPTION_IF_NULL(node); - result.emplace_back(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - for (auto &input : cnode->inputs()) { - auto iter = nodes_ref.find(input); - if (iter != nodes_ref.end()) { - iter->second--; - if (iter->second != 0) { - continue; - } - } - if (!input->isa()) { - queue.push(input); - continue; - } - std::string input_target = GetCNodeTarget(input); - if (input_target == queue_target) { - queue.push(input); - } else if (next_queue.empty() || input_target == next_target) { - next_queue.push(input); - next_target = input_target; - } else { - MS_LOG(EXCEPTION) << "only support two different target"; - } - } - } - std::reverse(result.begin(), result.end()); - return result; -} - VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); VectorRef splits; VectorRef split; auto nodes = TopoSort(graph->get_return()); if (ContainMultiTarget(nodes)) { - nodes = SplitSort(graph); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string default_target = context_ptr->device_target(); + nodes = SplitSort(graph, default_target); } std::string last_target; MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 7505a52ed1..3a1da0ff42 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -79,7 +79,6 @@ class CompileGraph { private: void PushParameters(const FuncGraphPtr &func_graph); - std::vector SplitSort(const FuncGraphPtr &graph); bool SplitGraph(const FuncGraphPtr &func_graph); int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);