| @@ -89,7 +89,8 @@ bool TensorInVector(const VectorRef *outputs) { | |||||
| void CompileNodesTask::Run() { | void CompileNodesTask::Run() { | ||||
| MS_EXCEPTION_IF_NULL(session_); | MS_EXCEPTION_IF_NULL(session_); | ||||
| graph_id_ = session_->CompileGraphImpl(nodes_, output_nodes_); | |||||
| MS_EXCEPTION_IF_NULL(segment_); | |||||
| graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_); | |||||
| } | } | ||||
| void CompileGraphTask::Run() { | void CompileGraphTask::Run() { | ||||
| @@ -226,10 +227,11 @@ void Executor::SyncRunTask(const std::shared_ptr<Task> &task) { | |||||
| MsException::GetInstance().CheckException(); | MsException::GetInstance().CheckException(); | ||||
| } | } | ||||
| GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||||
| GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, | |||||
| const AnfNodePtrList &outputs) { | |||||
| auto task = std::make_shared<CompileNodesTask>(); | auto task = std::make_shared<CompileNodesTask>(); | ||||
| task->session_ = session; | task->session_ = session; | ||||
| task->nodes_ = lst; | |||||
| task->segment_ = segment; | |||||
| task->output_nodes_ = outputs; | task->output_nodes_ = outputs; | ||||
| SyncRunTask(task); | SyncRunTask(task); | ||||
| return task->graph_id_; | return task->graph_id_; | ||||
| @@ -63,7 +63,7 @@ class CompileNodesTask : public Task { | |||||
| CompileNodesTask() { type_ = kCompileNodes; } | CompileNodesTask() { type_ = kCompileNodes; } | ||||
| ~CompileNodesTask() override = default; | ~CompileNodesTask() override = default; | ||||
| void Run() override; | void Run() override; | ||||
| AnfNodePtrList nodes_; | |||||
| GraphSegmentPtr segment_; | |||||
| AnfNodePtrList output_nodes_; | AnfNodePtrList output_nodes_; | ||||
| GraphId graph_id_{0}; | GraphId graph_id_{0}; | ||||
| }; | }; | ||||
| @@ -151,7 +151,7 @@ class Executor { | |||||
| ~Executor(); | ~Executor(); | ||||
| void WorkerLoop(); | void WorkerLoop(); | ||||
| void WorkerJoin(); | void WorkerJoin(); | ||||
| GraphId CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | |||||
| GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); | |||||
| GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph); | GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph); | ||||
| void BuildGraph(const SessionPtr &session, GraphId graphId); | void BuildGraph(const SessionPtr &session, GraphId graphId); | ||||
| void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | ||||
| @@ -1388,9 +1388,9 @@ AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::ve | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| GraphId SessionBasic::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||||
| GraphId SessionBasic::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) { | |||||
| MS_EXCEPTION_IF_NULL(executor_); | MS_EXCEPTION_IF_NULL(executor_); | ||||
| return executor_->CompileGraph(shared_from_this(), lst, outputs); | |||||
| return executor_->CompileGraph(shared_from_this(), segment, outputs); | |||||
| } | } | ||||
| GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | ||||
| @@ -68,7 +68,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| virtual ~SessionBasic() { summary_callback_ = nullptr; } | virtual ~SessionBasic() { summary_callback_ = nullptr; } | ||||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | |||||
| GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); | |||||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph); | GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph); | ||||
| void BuildGraph(GraphId graphId); | void BuildGraph(GraphId graphId); | ||||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | ||||
| @@ -102,6 +102,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {} | virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {} | ||||
| std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id, | std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id, | ||||
| const std::vector<tensor::TensorPtr> &inputs); | const std::vector<tensor::TensorPtr> &inputs); | ||||
| // Get graph by graph id, if not exist return null ptr | |||||
| KernelGraphPtr GetGraph(GraphId graph_id) const; | |||||
| #ifdef ENABLE_DEBUGGER | #ifdef ENABLE_DEBUGGER | ||||
| // set debugger | // set debugger | ||||
| void SetDebugger() { | void SetDebugger() { | ||||
| @@ -147,8 +149,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {} | const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {} | ||||
| void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs); | void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs); | ||||
| // Get graph by graph id ,if not exist return null ptr | |||||
| KernelGraphPtr GetGraph(GraphId graph_id) const; | |||||
| virtual void SetSummaryNodes(KernelGraph *graph); | virtual void SetSummaryNodes(KernelGraph *graph); | ||||
| @@ -354,7 +354,7 @@ bool TaskEmitAction(const ResourcePtr &res) { | |||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| std::string backend = MsContext::GetInstance()->backend_policy(); | std::string backend = MsContext::GetInstance()->backend_policy(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| if (CompileGraphs::ContainMixedTarget(func_graph)) { | |||||
| if (func_graph->ContainMultiTarget()) { | |||||
| bc_ptr->set_is_multi_graph_sink(false); | bc_ptr->set_is_multi_graph_sink(false); | ||||
| context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false); | context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false); | ||||
| context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false); | context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false); | ||||
| @@ -923,7 +923,8 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||||
| MS_EXCEPTION_IF_NULL(convert_fn); | MS_EXCEPTION_IF_NULL(convert_fn); | ||||
| // Convert CNodeList to LinConvertResult. | // Convert CNodeList to LinConvertResult. | ||||
| ConfigManager::GetInstance().set_iter_num(1); | ConfigManager::GetInstance().set_iter_num(1); | ||||
| auto runner = convert_fn({app_init}, ""); | |||||
| auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false); | |||||
| auto runner = convert_fn(segment, ""); | |||||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | ||||
| backend->Link(runner.graph_id); | backend->Link(runner.graph_id); | ||||
| } | } | ||||
| @@ -34,30 +34,34 @@ namespace compile { | |||||
| bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } | bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } | ||||
| bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); } | bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); } | ||||
| LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) { | |||||
| Backend::Backend(const std::string &name) : name_(name) { | |||||
| MS_LOG(DEBUG) << "select backend:" << name; | |||||
| convert_fn_ = MsVmConvert; | |||||
| is_multi_graph_sink_ = false; | |||||
| } | |||||
| LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) { | |||||
| MS_LOG(DEBUG) << "MsConvert"; | MS_LOG(DEBUG) << "MsConvert"; | ||||
| MS_EXCEPTION_IF_NULL(segment); | |||||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | ||||
| auto cached = g_ConvertCache.find(lst); | |||||
| auto cached = g_ConvertCache.find(segment); | |||||
| if (cached != g_ConvertCache.end()) { | if (cached != g_ConvertCache.end()) { | ||||
| return cached->second; | return cached->second; | ||||
| } | } | ||||
| LinConvertResult result; | LinConvertResult result; | ||||
| FuncGraphPtr fg; | FuncGraphPtr fg; | ||||
| AnfNodePtrList inputs; | AnfNodePtrList inputs; | ||||
| AnfNodePtrList outputs; | AnfNodePtrList outputs; | ||||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst); | |||||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); | |||||
| result.inputs = inputs; | result.inputs = inputs; | ||||
| result.outputs = outputs; | result.outputs = outputs; | ||||
| result.graph_id = kInvalidGraphId; | result.graph_id = kInvalidGraphId; | ||||
| GraphId graph_id = kInvalidGraphId; | GraphId graph_id = kInvalidGraphId; | ||||
| if (target != target_device_ && !target.empty()) { | if (target != target_device_ && !target.empty()) { | ||||
| CreateOtherSession(target); | CreateOtherSession(target); | ||||
| graph_id = other_sess_->CompileGraph(lst, outputs); | |||||
| graph_id = other_sess_->CompileGraph(segment, outputs); | |||||
| } else { | } else { | ||||
| graph_id = target_sess_->CompileGraph(lst, outputs); | |||||
| graph_id = target_sess_->CompileGraph(segment, outputs); | |||||
| } | } | ||||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) { | if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) { | ||||
| @@ -79,7 +83,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri | |||||
| result.graph_id = graph_id; | result.graph_id = graph_id; | ||||
| graph_id_map_[graph_id] = result; | graph_id_map_[graph_id] = result; | ||||
| (void)g_ConvertCache.emplace(lst, result); | |||||
| (void)g_ConvertCache.emplace(segment, result); | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -154,12 +158,6 @@ void MsBackend::Link(GraphId graph_id) { | |||||
| target_sess_->BuildGraph(graph_id); | target_sess_->BuildGraph(graph_id); | ||||
| } | } | ||||
| Backend::Backend(const std::string &name) : name_(name) { | |||||
| MS_LOG(DEBUG) << "select backend:" << name; | |||||
| convert_fn_ = backends[name_]; | |||||
| is_multi_graph_sink_ = false; | |||||
| } | |||||
| MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { | MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { | ||||
| convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2); | convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2); | ||||
| target_sess_ = session::SessionFactory::Get().Create(target); | target_sess_ = session::SessionFactory::Get().Create(target); | ||||
| @@ -194,6 +192,5 @@ VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return | |||||
| #ifdef ENABLE_DEBUGGER | #ifdef ENABLE_DEBUGGER | ||||
| void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } | void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } | ||||
| #endif | #endif | ||||
| } // namespace compile | } // namespace compile | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "utils/contract.h" | #include "utils/contract.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "vm/segment_runner.h" | #include "vm/segment_runner.h" | ||||
| #include "vm/graph_partition.h" | |||||
| #include "vm/vm.h" | #include "vm/vm.h" | ||||
| #include "backend/session/session_basic.h" | #include "backend/session/session_basic.h" | ||||
| @@ -63,7 +64,7 @@ class MsBackend : public Backend { | |||||
| MsBackend(const std::string &name, const std::string &target, uint32_t device_id); | MsBackend(const std::string &name, const std::string &target, uint32_t device_id); | ||||
| ~MsBackend() override = default; | ~MsBackend() override = default; | ||||
| LinConvertResult MsConvert(const AnfNodePtrList &lst, const std::string &target = ""); | |||||
| LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = ""); | |||||
| VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); | VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); | ||||
| VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args); | VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args); | ||||
| @@ -0,0 +1,447 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "vm/graph_partition.h" | |||||
| #include <string> | |||||
| #include <functional> | |||||
| #include <utility> | |||||
| #include <map> | |||||
| #include <queue> | |||||
| #include <stack> | |||||
| #include <set> | |||||
| #include "base/core_ops.h" | |||||
| #include "utils/utils.h" | |||||
| #include "utils/ms_context.h" | |||||
| namespace mindspore { | |||||
| const char kMsConvert[] = "ms"; | |||||
| const char kMsVm[] = "vm"; | |||||
| const char kGeVm[] = "ge"; | |||||
| namespace compile { | |||||
| namespace { | |||||
| bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, | |||||
| std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(behind_node); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto &node_users = manager->node_users(); | |||||
| if (prior_node->isa<Parameter>()) { | |||||
| for (auto &user : node_users[prior_node]) { | |||||
| auto cnode = user.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| prior_nodes->emplace_back(cnode); | |||||
| } | |||||
| } | |||||
| } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { | |||||
| prior_nodes->emplace_back(prior_node); | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| if (behind_node->isa<Parameter>()) { | |||||
| for (auto &user : node_users[behind_node]) { | |||||
| auto cnode = user.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| depend_nodes->emplace_back(cnode); | |||||
| } | |||||
| } | |||||
| } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { | |||||
| depend_nodes->emplace_back(behind_node); | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges, | |||||
| std::map<AnfNodePtr, size_t> *nodes_ref) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto input_cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||||
| auto prior_node = input_cnode->input(kControlDependPriorIndex); | |||||
| auto depend_node = input_cnode->input(kControlDependBehindIndex); | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(depend_node); | |||||
| PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||||
| MS_EXCEPTION_IF_NULL(prim_ptr); | |||||
| ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); | |||||
| int64_t depend_mode = 0; | |||||
| if (mode_ptr != nullptr) { | |||||
| depend_mode = GetValue<int64_t>(mode_ptr); | |||||
| } | |||||
| if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) { | |||||
| return; | |||||
| } | |||||
| std::vector<AnfNodePtr> prior_nodes; | |||||
| std::vector<AnfNodePtr> behind_nodes; | |||||
| if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { | |||||
| return; | |||||
| } | |||||
| for (auto &first_node : prior_nodes) { | |||||
| for (auto &second_node : behind_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(first_node); | |||||
| MS_EXCEPTION_IF_NULL(second_node); | |||||
| auto iter = control_edges->find(second_node); | |||||
| if (iter == control_edges->end()) { | |||||
| (void)control_edges->insert( | |||||
| std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node})); | |||||
| } else { | |||||
| iter->second.emplace_back(first_node); | |||||
| } | |||||
| auto ref_iter = nodes_ref->find(first_node); | |||||
| if (ref_iter != nodes_ref->end()) { | |||||
| ref_iter->second++; | |||||
| } else { | |||||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref, | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) { | |||||
| std::queue<AnfNodePtr> queue; | |||||
| queue.push(graph->get_return()); | |||||
| std::set<AnfNodePtr> visited; | |||||
| while (!queue.empty()) { | |||||
| auto &node = queue.front(); | |||||
| queue.pop(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| for (auto &input : cnode->inputs()) { | |||||
| if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { | |||||
| AddControlEdge(graph, input, control_edges, nodes_ref); | |||||
| } | |||||
| auto iter = nodes_ref->find(input); | |||||
| if (iter != nodes_ref->end()) { | |||||
| iter->second++; | |||||
| } else { | |||||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1)); | |||||
| } | |||||
| if (visited.find(input) != visited.end()) { | |||||
| continue; | |||||
| } | |||||
| visited.insert(input); | |||||
| queue.push(input); | |||||
| } | |||||
| } | |||||
| } | |||||
| std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) { | |||||
| std::vector<AnfNodePtr> result; | |||||
| std::map<size_t, std::vector<AnfNodePtr>> insert_positions; | |||||
| std::map<AnfNodePtr, size_t> node_positions; | |||||
| for (auto &node : nodes) { | |||||
| if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (inputs.size() < 2) { | |||||
| MS_LOG(EXCEPTION) << "Invalid get item node"; | |||||
| } | |||||
| auto &parent = inputs[1]; | |||||
| auto iter = node_positions.find(parent); | |||||
| if (iter != node_positions.end()) { | |||||
| size_t position = iter->second; | |||||
| auto iter_nodes = insert_positions.find(position); | |||||
| if (iter_nodes != insert_positions.end()) { | |||||
| iter_nodes->second.push_back(node); | |||||
| } else { | |||||
| (void)insert_positions.insert( | |||||
| std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node})); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| } | |||||
| result.emplace_back(node); | |||||
| node_positions[node] = result.size(); | |||||
| } | |||||
| size_t insert_num = 0; | |||||
| for (auto &item : insert_positions) { | |||||
| size_t position = item.first + insert_num; | |||||
| (void)result.insert(result.begin() + position, item.second.begin(), item.second.end()); | |||||
| insert_num += item.second.size(); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { | |||||
| std::vector<AnfNodePtr> result; | |||||
| std::stack<AnfNodePtr> to_visit; | |||||
| std::stack<AnfNodePtr> next_to_visit; | |||||
| std::map<AnfNodePtr, size_t> nodes_ref; | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges; | |||||
| CalcNodeRefCount(graph, &nodes_ref, &control_edges); | |||||
| std::string handle_target = default_target; | |||||
| std::string next_target = ""; | |||||
| to_visit.push(graph->get_return()); | |||||
| while (!to_visit.empty() || !next_to_visit.empty()) { | |||||
| if (to_visit.empty()) { | |||||
| to_visit.swap(next_to_visit); | |||||
| handle_target = next_target; | |||||
| } | |||||
| auto &node = to_visit.top(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| to_visit.pop(); | |||||
| result.emplace_back(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto node_inputs = cnode->inputs(); | |||||
| std::reverse(node_inputs.begin(), node_inputs.end()); | |||||
| auto ctrl_inputs = control_edges.find(node); | |||||
| if (ctrl_inputs != control_edges.end()) { | |||||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | |||||
| } | |||||
| for (auto &input : node_inputs) { | |||||
| auto iter = nodes_ref.find(input); | |||||
| if (iter != nodes_ref.end()) { | |||||
| iter->second--; | |||||
| if (iter->second != 0) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| if (!input->isa<CNode>()) { | |||||
| to_visit.push(input); | |||||
| continue; | |||||
| } | |||||
| std::string input_target = GetCNodeTarget(input); | |||||
| if (input_target == handle_target) { | |||||
| to_visit.push(input); | |||||
| } else if (next_to_visit.empty() || input_target == next_target) { | |||||
| next_to_visit.push(input); | |||||
| next_target = input_target; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Only support two different target"; | |||||
| } | |||||
| } | |||||
| } | |||||
| std::reverse(result.begin(), result.end()); | |||||
| return result; | |||||
| } | |||||
| std::vector<AnfNodePtr> ParallelSplitSort(const FuncGraphPtr &graph, const std::string &default_target) { | |||||
| std::vector<AnfNodePtr> result; | |||||
| std::stack<AnfNodePtr> handle_nodes; | |||||
| std::stack<AnfNodePtr> next_handle_nodes; | |||||
| std::stack<AnfNodePtr> wait_handle_nodes; | |||||
| std::map<AnfNodePtr, size_t> nodes_ref; | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges; | |||||
| CalcNodeRefCount(graph, &nodes_ref, &control_edges); | |||||
| std::string handle_target = default_target; | |||||
| std::string next_target = ""; | |||||
| handle_nodes.push(graph->get_return()); | |||||
| while (!handle_nodes.empty() || !next_handle_nodes.empty() || !wait_handle_nodes.empty()) { | |||||
| if (handle_nodes.empty()) { | |||||
| handle_nodes.swap(next_handle_nodes); | |||||
| handle_target.swap(next_target); | |||||
| next_handle_nodes.swap(wait_handle_nodes); | |||||
| } | |||||
| auto &node = handle_nodes.top(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| handle_nodes.pop(); | |||||
| result.emplace_back(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto node_inputs = cnode->inputs(); | |||||
| std::reverse(node_inputs.begin(), node_inputs.end()); | |||||
| auto ctrl_inputs = control_edges.find(node); | |||||
| if (ctrl_inputs != control_edges.end()) { | |||||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | |||||
| } | |||||
| std::vector<AnfNodePtr> same_target_ready_inputs; | |||||
| std::vector<AnfNodePtr> diff_target_ready_inputs; | |||||
| for (auto &input : node_inputs) { | |||||
| auto iter = nodes_ref.find(input); | |||||
| if (iter != nodes_ref.end()) { | |||||
| iter->second--; | |||||
| if (iter->second != 0) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| if (!input->isa<CNode>()) { | |||||
| handle_nodes.push(input); | |||||
| continue; | |||||
| } | |||||
| std::string input_target = GetCNodeTarget(input); | |||||
| if (input_target == handle_target) { | |||||
| same_target_ready_inputs.emplace_back(input); | |||||
| } else { | |||||
| if (next_target.empty()) { | |||||
| next_target = input_target; | |||||
| } | |||||
| if (input_target != next_target) { | |||||
| MS_LOG(EXCEPTION) << "Only support two different target"; | |||||
| } | |||||
| diff_target_ready_inputs.emplace_back(input); | |||||
| } | |||||
| } | |||||
| if (diff_target_ready_inputs.size() == 0) { | |||||
| for (auto &input : same_target_ready_inputs) { | |||||
| handle_nodes.push(input); | |||||
| } | |||||
| } else { | |||||
| for (auto &input : same_target_ready_inputs) { | |||||
| wait_handle_nodes.push(input); | |||||
| } | |||||
| for (auto &input : diff_target_ready_inputs) { | |||||
| next_handle_nodes.push(input); | |||||
| } | |||||
| } | |||||
| } | |||||
| std::reverse(result.begin(), result.end()); | |||||
| return result; | |||||
| } | |||||
| bool IsSubGraph(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<CNode>()) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (inputs.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||||
| } | |||||
| AnfNodePtr fn = inputs[0]; | |||||
| if (!IsValueNode<Primitive>(fn)) { | |||||
| return false; | |||||
| } | |||||
| auto node_prim = GetValueNode<PrimitivePtr>(fn); | |||||
| if (node_prim->name() == prim::kPrimPartial->name()) { | |||||
| return true; | |||||
| } | |||||
| } else if (IsValueNode<FuncGraph>(node)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace | |||||
| GraphPartition::GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name) | |||||
| : cut_list_(cut_list), backend_name_(backend_name) {} | |||||
| bool GraphPartition::IsCut(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<CNode>()) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (inputs.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||||
| } | |||||
| AnfNodePtr fn = inputs[0]; | |||||
| if (IsValueNode<FuncGraph>(fn)) { | |||||
| auto fg = GetValueNode<FuncGraphPtr>(fn); | |||||
| if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if (!IsValueNode<Primitive>(fn)) { | |||||
| return true; | |||||
| } | |||||
| PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn); | |||||
| for (auto &prim : cut_list_) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (prim->name() == node_prim->name()) { | |||||
| if (prim->name() == prim::kPrimBpropCut->name()) { | |||||
| auto ms_context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | |||||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true); | |||||
| } | |||||
| if (backend_name_ == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { | |||||
| if (inputs.size() < 2) { | |||||
| return false; | |||||
| } | |||||
| auto ret = IsSubGraph(inputs[1]); | |||||
| return ret; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } | |||||
| #ifdef ENABLE_GE | |||||
| if (backend_name_ == kGeVm) { | |||||
| auto name = GetCNodeFuncName(cnode); | |||||
| auto adpt = transform::DfGraphConvertor::FindAdapter(name); | |||||
| if (adpt == nullptr) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
| return false; | |||||
| } | |||||
| std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto nodes = TopoSort(graph->get_return()); | |||||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||||
| bool contain_multi_target = ContainMultiTarget(nodes); | |||||
| if (contain_multi_target) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||||
| if (graph != nullptr) { | |||||
| nodes = SplitSort(graph, default_target); | |||||
| } else { | |||||
| nodes = ParallelSplitSort(graph, default_target); | |||||
| } | |||||
| nodes = OptimizeGetItemOrder(nodes); | |||||
| } | |||||
| std::vector<GraphSegmentPtr> segments; | |||||
| std::vector<AnfNodePtr> segment_nodes; | |||||
| std::string last_target; | |||||
| for (auto &node : nodes) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (IsCut(node)) { | |||||
| if (segment_nodes.size() != 0) { | |||||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, false); | |||||
| segments.emplace_back(segment); | |||||
| segment_nodes.clear(); | |||||
| } | |||||
| segment_nodes.emplace_back(node); | |||||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, true); | |||||
| segments.push_back(segment); | |||||
| segment_nodes.clear(); | |||||
| } else if (node->isa<CNode>()) { | |||||
| if (contain_multi_target) { | |||||
| std::string cur_target = GetCNodeTarget(node); | |||||
| if (cur_target != last_target && !last_target.empty() && segment_nodes.size() != 0) { | |||||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, false); | |||||
| segments.emplace_back(segment); | |||||
| segment_nodes.clear(); | |||||
| } | |||||
| last_target = cur_target; | |||||
| } | |||||
| segment_nodes.emplace_back(node); | |||||
| } | |||||
| } | |||||
| MS_LOG(DEBUG) << "Segment size:" << segments.size(); | |||||
| return segments; | |||||
| } | |||||
| } // namespace compile | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ | |||||
| #define MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ir/anf.h" | |||||
| #include "ir/func_graph.h" | |||||
| #include "ir/graph_utils.h" | |||||
| #include "base/base_ref.h" | |||||
| namespace mindspore { | |||||
| extern const char kMsVm[]; | |||||
| extern const char kGeVm[]; | |||||
| extern const char kMsConvert[]; | |||||
| namespace compile { | |||||
| class GraphPartition { | |||||
| public: | |||||
| explicit GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name); | |||||
| ~GraphPartition() = default; | |||||
| std::vector<GraphSegmentPtr> Partition(const FuncGraphPtr &func_graph); | |||||
| private: | |||||
| bool IsCut(const AnfNodePtr &node); | |||||
| std::vector<PrimitivePtr> cut_list_; | |||||
| std::string backend_name_; | |||||
| }; | |||||
| using GraphPartitionPtr = std::shared_ptr<GraphPartition>; | |||||
| } // namespace compile | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ | |||||
| @@ -34,10 +34,6 @@ | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| const char kMsConvert[] = "ms"; | |||||
| const char kMsVm[] = "vm"; | |||||
| const char kGeVm[] = "ge"; | |||||
| namespace compile { | namespace compile { | ||||
| // cached conversion | // cached conversion | ||||
| ConvertCache g_ConvertCache; | ConvertCache g_ConvertCache; | ||||
| @@ -194,8 +190,9 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||||
| // This implementation will convert the nodes into a subgraph | // This implementation will convert the nodes into a subgraph | ||||
| // that will run using the MsVM. | // that will run using the MsVM. | ||||
| template <typename T> | template <typename T> | ||||
| LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { | |||||
| auto cached = g_ConvertCache.find(lst); | |||||
| LinConvertResult Convert(const GraphSegmentPtr &segment, const std::string &) { | |||||
| MS_EXCEPTION_IF_NULL(segment); | |||||
| auto cached = g_ConvertCache.find(segment); | |||||
| if (cached != g_ConvertCache.end()) { | if (cached != g_ConvertCache.end()) { | ||||
| return cached->second; | return cached->second; | ||||
| } | } | ||||
| @@ -206,7 +203,7 @@ LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { | |||||
| AnfNodePtrList inputs; | AnfNodePtrList inputs; | ||||
| AnfNodePtrList outputs; | AnfNodePtrList outputs; | ||||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst); | |||||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); | |||||
| // Clone in case g contains subgraphs that have a different manager | // Clone in case g contains subgraphs that have a different manager | ||||
| fg = BasicClone(fg); | fg = BasicClone(fg); | ||||
| @@ -219,18 +216,15 @@ LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { | |||||
| result.outputs = outputs; | result.outputs = outputs; | ||||
| result.graph_id = UINT32_MAX; | result.graph_id = UINT32_MAX; | ||||
| (void)g_ConvertCache.emplace(lst, result); | |||||
| (void)g_ConvertCache.emplace(segment, result); | |||||
| return result; | return result; | ||||
| } | } | ||||
| LinkFuncType MsVmConvert = Convert<VM>; | LinkFuncType MsVmConvert = Convert<VM>; | ||||
| std::unordered_map<std::string, LinkFuncType> backends = {{kMsVm, MsVmConvert}}; | |||||
| std::set<std::string> backend_list = { | std::set<std::string> backend_list = { | ||||
| kMsConvert, | kMsConvert, | ||||
| kMsVm, | kMsVm, | ||||
| }; | }; | ||||
| } // namespace compile | } // namespace compile | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,14 +27,10 @@ | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "vm/vmimpl.h" | #include "vm/vmimpl.h" | ||||
| #include "vm/graph_partition.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| extern const char kMsVm[]; | |||||
| extern const char kGeVm[]; | |||||
| extern const char kMsConvert[]; | |||||
| namespace compile { | namespace compile { | ||||
| struct LinConvertResult { | struct LinConvertResult { | ||||
| RunFuncPtr run; | RunFuncPtr run; | ||||
| RunFuncPtr simu_run; | RunFuncPtr simu_run; | ||||
| @@ -43,11 +39,9 @@ struct LinConvertResult { | |||||
| uint32_t graph_id; | uint32_t graph_id; | ||||
| }; | }; | ||||
| using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &, const std::string &)>; | |||||
| using ConvertCache = std::unordered_map<BaseRef, LinConvertResult, BaseRefHash>; | |||||
| using LinkFuncType = std::function<LinConvertResult(const GraphSegmentPtr &, const std::string &)>; | |||||
| using ConvertCache = std::unordered_map<GraphSegmentPtr, LinConvertResult>; | |||||
| extern LinkFuncType MsVmConvert; | extern LinkFuncType MsVmConvert; | ||||
| extern LinkFuncType GeVmConvert; | |||||
| extern std::unordered_map<std::string, LinkFuncType> backends; | |||||
| extern ConvertCache g_ConvertCache; | extern ConvertCache g_ConvertCache; | ||||
| extern std::set<std::string> backend_list; | extern std::set<std::string> backend_list; | ||||
| @@ -21,8 +21,6 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <map> | #include <map> | ||||
| #include <queue> | #include <queue> | ||||
| #include <stack> | |||||
| #include <set> | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -52,386 +50,13 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() { | |||||
| return ms_nonlinear_ops; | return ms_nonlinear_ops; | ||||
| } | } | ||||
| namespace { | |||||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||||
| for (auto &node : nodes) { | |||||
| if (node->isa<CNode>()) { | |||||
| std::string cur_target = GetCNodeTarget(node); | |||||
| if (last_target != cur_target) { | |||||
| return true; | |||||
| } | |||||
| last_target = cur_target; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, | |||||
| std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(behind_node); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto &node_users = manager->node_users(); | |||||
| if (prior_node->isa<Parameter>()) { | |||||
| for (auto &user : node_users[prior_node]) { | |||||
| auto cnode = user.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| prior_nodes->emplace_back(cnode); | |||||
| } | |||||
| } | |||||
| } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { | |||||
| prior_nodes->emplace_back(prior_node); | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| if (behind_node->isa<Parameter>()) { | |||||
| for (auto &user : node_users[behind_node]) { | |||||
| auto cnode = user.first->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||||
| depend_nodes->emplace_back(cnode); | |||||
| } | |||||
| } | |||||
| } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { | |||||
| depend_nodes->emplace_back(behind_node); | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges, | |||||
| std::map<AnfNodePtr, size_t> *nodes_ref) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto input_cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||||
| auto prior_node = input_cnode->input(kControlDependPriorIndex); | |||||
| auto depend_node = input_cnode->input(kControlDependBehindIndex); | |||||
| MS_EXCEPTION_IF_NULL(prior_node); | |||||
| MS_EXCEPTION_IF_NULL(depend_node); | |||||
| PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||||
| MS_EXCEPTION_IF_NULL(prim_ptr); | |||||
| ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); | |||||
| int64_t depend_mode = 0; | |||||
| if (mode_ptr != nullptr) { | |||||
| depend_mode = GetValue<int64_t>(mode_ptr); | |||||
| } | |||||
| if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) { | |||||
| return; | |||||
| } | |||||
| std::vector<AnfNodePtr> prior_nodes; | |||||
| std::vector<AnfNodePtr> behind_nodes; | |||||
| if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { | |||||
| return; | |||||
| } | |||||
| for (auto &first_node : prior_nodes) { | |||||
| for (auto &second_node : behind_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(first_node); | |||||
| MS_EXCEPTION_IF_NULL(second_node); | |||||
| auto iter = control_edges->find(second_node); | |||||
| if (iter == control_edges->end()) { | |||||
| (void)control_edges->insert( | |||||
| std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node})); | |||||
| } else { | |||||
| iter->second.emplace_back(first_node); | |||||
| } | |||||
| auto ref_iter = nodes_ref->find(first_node); | |||||
| if (ref_iter != nodes_ref->end()) { | |||||
| ref_iter->second++; | |||||
| } else { | |||||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref, | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) { | |||||
| std::queue<AnfNodePtr> queue; | |||||
| queue.push(graph->get_return()); | |||||
| std::set<AnfNodePtr> visited; | |||||
| while (!queue.empty()) { | |||||
| auto &node = queue.front(); | |||||
| queue.pop(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| for (auto &input : cnode->inputs()) { | |||||
| if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { | |||||
| AddControlEdge(graph, input, control_edges, nodes_ref); | |||||
| } | |||||
| auto iter = nodes_ref->find(input); | |||||
| if (iter != nodes_ref->end()) { | |||||
| iter->second++; | |||||
| } else { | |||||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1)); | |||||
| } | |||||
| if (visited.find(input) != visited.end()) { | |||||
| continue; | |||||
| } | |||||
| visited.insert(input); | |||||
| queue.push(input); | |||||
| } | |||||
| } | |||||
| } | |||||
| std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) { | |||||
| std::vector<AnfNodePtr> result; | |||||
| std::map<size_t, std::vector<AnfNodePtr>> insert_positions; | |||||
| std::map<AnfNodePtr, size_t> node_positions; | |||||
| for (auto &node : nodes) { | |||||
| if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (inputs.size() < 2) { | |||||
| MS_LOG(EXCEPTION) << "Invalid get item node"; | |||||
| } | |||||
| auto &parent = inputs[1]; | |||||
| auto iter = node_positions.find(parent); | |||||
| if (iter != node_positions.end()) { | |||||
| size_t position = iter->second; | |||||
| auto iter_nodes = insert_positions.find(position); | |||||
| if (iter_nodes != insert_positions.end()) { | |||||
| iter_nodes->second.push_back(node); | |||||
| } else { | |||||
| (void)insert_positions.insert( | |||||
| std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node})); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| } | |||||
| result.emplace_back(node); | |||||
| node_positions[node] = result.size(); | |||||
| } | |||||
| size_t insert_num = 0; | |||||
| for (auto &item : insert_positions) { | |||||
| size_t position = item.first + insert_num; | |||||
| (void)result.insert(result.begin() + position, item.second.begin(), item.second.end()); | |||||
| insert_num += item.second.size(); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { | |||||
| std::vector<AnfNodePtr> result; | |||||
| std::stack<AnfNodePtr> to_visit; | |||||
| std::stack<AnfNodePtr> next_to_visit; | |||||
| std::map<AnfNodePtr, size_t> nodes_ref; | |||||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges; | |||||
| CalcNodeRefCount(graph, &nodes_ref, &control_edges); | |||||
| std::string handle_target = default_target; | |||||
| std::string next_target = ""; | |||||
| to_visit.push(graph->get_return()); | |||||
| while (!to_visit.empty() || !next_to_visit.empty()) { | |||||
| if (to_visit.empty()) { | |||||
| to_visit.swap(next_to_visit); | |||||
| handle_target = next_target; | |||||
| } | |||||
| auto &node = to_visit.top(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| to_visit.pop(); | |||||
| result.emplace_back(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto node_inputs = cnode->inputs(); | |||||
| std::reverse(node_inputs.begin(), node_inputs.end()); | |||||
| auto ctrl_inputs = control_edges.find(node); | |||||
| if (ctrl_inputs != control_edges.end()) { | |||||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | |||||
| } | |||||
| for (auto &input : node_inputs) { | |||||
| auto iter = nodes_ref.find(input); | |||||
| if (iter != nodes_ref.end()) { | |||||
| iter->second--; | |||||
| if (iter->second != 0) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| if (!input->isa<CNode>()) { | |||||
| to_visit.push(input); | |||||
| continue; | |||||
| } | |||||
| std::string input_target = GetCNodeTarget(input); | |||||
| if (input_target == handle_target) { | |||||
| to_visit.push(input); | |||||
| } else if (next_to_visit.empty() || input_target == next_target) { | |||||
| next_to_visit.push(input); | |||||
| next_target = input_target; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "only support two different target"; | |||||
| } | |||||
| } | |||||
| } | |||||
| std::reverse(result.begin(), result.end()); | |||||
| return result; | |||||
| } | |||||
| bool IsSubGraph(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<CNode>()) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (inputs.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||||
| } | |||||
| AnfNodePtr fn = inputs[0]; | |||||
| if (!IsValueNode<Primitive>(fn)) { | |||||
| return false; | |||||
| } | |||||
| auto node_prim = GetValueNode<PrimitivePtr>(fn); | |||||
| if (node_prim->name() == prim::kPrimPartial->name()) { | |||||
| return true; | |||||
| } | |||||
| } else if (IsValueNode<FuncGraph>(node)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace | |||||
| CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) | |||||
| : backend_(backend), cut_list_(cut_list) { | |||||
| CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) { | |||||
| MS_EXCEPTION_IF_NULL(backend_); | MS_EXCEPTION_IF_NULL(backend_); | ||||
| lin_convert_ = backend_->convert_fn(); | lin_convert_ = backend_->convert_fn(); | ||||
| if (lin_convert_ == nullptr) { | if (lin_convert_ == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name(); | MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name(); | ||||
| } | } | ||||
| is_gevm_convert_ = false; | |||||
| if (backend->name() == kGeVm) { | |||||
| MS_LOG(INFO) << "Attribute 'is_gevm_convert' is true"; | |||||
| is_gevm_convert_ = true; | |||||
| } | |||||
| } | |||||
| bool CompileGraph::IsCut(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->isa<CNode>()) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| auto &inputs = cnode->inputs(); | |||||
| if (inputs.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||||
| } | |||||
| AnfNodePtr fn = inputs[0]; | |||||
| if (IsValueNode<FuncGraph>(fn)) { | |||||
| auto fg = GetValueNode<FuncGraphPtr>(fn); | |||||
| if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if (!IsValueNode<Primitive>(fn)) { | |||||
| return true; | |||||
| } | |||||
| PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn); | |||||
| for (auto &prim : cut_list_) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (prim->name() == node_prim->name()) { | |||||
| if (prim->name() == prim::kPrimBpropCut->name()) { | |||||
| auto ms_context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | |||||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true); | |||||
| } | |||||
| if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { | |||||
| if (inputs.size() < 2) { | |||||
| return false; | |||||
| } | |||||
| auto ret = IsSubGraph(inputs[1]); | |||||
| return ret; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } | |||||
| #ifdef ENABLE_GE | |||||
| if (is_gevm_convert_) { | |||||
| auto name = GetCNodeFuncName(cnode); | |||||
| auto adpt = transform::DfGraphConvertor::FindAdapter(name); | |||||
| if (adpt == nullptr) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
| return false; | |||||
| } | |||||
| VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto nodes = OptimizeGetItemOrder(input_nodes); | |||||
| VectorRef splits; | |||||
| VectorRef split; | |||||
| std::string last_target; | |||||
| for (auto &node : nodes) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (IsCut(node)) { | |||||
| if (split.size() != 0) { | |||||
| splits.push_back(split); | |||||
| } | |||||
| splits.push_back(node); | |||||
| split.clear(); | |||||
| } else if (node->isa<CNode>()) { | |||||
| std::string cur_target = GetCNodeTarget(node); | |||||
| if (cur_target != last_target && !last_target.empty() && split.size() != 0) { | |||||
| splits.push_back(split); | |||||
| split.clear(); | |||||
| } | |||||
| last_target = cur_target; | |||||
| split.push_back(node); | |||||
| } | |||||
| } | |||||
| return splits; | |||||
| } | |||||
| VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto nodes = TopoSort(graph->get_return()); | |||||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||||
| if (ContainMultiTarget(nodes)) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||||
| nodes = SplitSort(graph, default_target); | |||||
| return SplitNodesWithTarget(nodes, graph); | |||||
| } | |||||
| VectorRef splits; | |||||
| VectorRef split; | |||||
| for (auto &node : nodes) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (IsCut(node)) { | |||||
| if (split.size() != 0) { | |||||
| splits.push_back(split); | |||||
| } | |||||
| splits.push_back(node); | |||||
| split.clear(); | |||||
| } else if (node->isa<CNode>()) { | |||||
| split.push_back(node); | |||||
| } | |||||
| } | |||||
| return splits; | |||||
| graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name()); | |||||
| } | } | ||||
| // Push the value node on the stack. | // Push the value node on the stack. | ||||
| @@ -512,12 +137,12 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) { | |||||
| } | } | ||||
| } | } | ||||
| int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, | |||||
| const std::string &target) { | |||||
| int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) { | |||||
| MS_EXCEPTION_IF_NULL(segment); | |||||
| MS_LOG(DEBUG) << "LinConvert start"; | MS_LOG(DEBUG) << "LinConvert start"; | ||||
| LinConvertResult result; | LinConvertResult result; | ||||
| result = lin_convert_(node_list, target); | |||||
| result = lin_convert_(segment, target); | |||||
| if (result.run == nullptr) { | if (result.run == nullptr) { | ||||
| MS_LOG(ERROR) << "LinConvert failed"; | MS_LOG(ERROR) << "LinConvert failed"; | ||||
| @@ -583,25 +208,23 @@ int64_t CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &n | |||||
| return RET_SUCCESS; | return RET_SUCCESS; | ||||
| } | } | ||||
| bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { | |||||
| bool CompileGraph::Compile(const FuncGraphPtr &graph) { | |||||
| MS_LOG(DEBUG) << "Start split graph"; | MS_LOG(DEBUG) << "Start split graph"; | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| VectorRef splits = SplitNodes(graph); | |||||
| MS_EXCEPTION_IF_NULL(graph_partition_); | |||||
| auto segments = graph_partition_->Partition(graph); | |||||
| MS_LOG(DEBUG) << "Split nodes size:" << splits.size(); | |||||
| for (auto &split : splits) { | |||||
| MS_LOG(DEBUG) << "Split nodes size:" << segments.size(); | |||||
| for (auto &segment : segments) { | |||||
| MS_EXCEPTION_IF_NULL(segment); | |||||
| int64_t ret = RET_SUCCESS; | int64_t ret = RET_SUCCESS; | ||||
| if (utils::isa<VectorRef>(split)) { | |||||
| if (!segment->is_cut_) { | |||||
| MS_LOG(DEBUG) << "Start a extern LinConvert"; | MS_LOG(DEBUG) << "Start a extern LinConvert"; | ||||
| std::vector<AnfNodePtr> args; | |||||
| auto vec_ref = utils::cast<VectorRef>(split); | |||||
| (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args), | |||||
| [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); }); | |||||
| if (args.size() > 0) { | |||||
| std::string cur_target = GetCNodeTarget(args[0]); | |||||
| ret = LinConvert(graph, args, cur_target); | |||||
| if (segment->nodes_.size() > 0) { | |||||
| std::string cur_target = GetCNodeTarget(segment->nodes_[0]); | |||||
| ret = LinConvert(graph, segment, cur_target); | |||||
| } else { | } else { | ||||
| ret = LinConvert(graph, args); | |||||
| ret = LinConvert(graph, segment); | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "End a extern LinConvert"; | MS_LOG(DEBUG) << "End a extern LinConvert"; | ||||
| if (ret == RET_FAILED) { | if (ret == RET_FAILED) { | ||||
| @@ -612,10 +235,11 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { | |||||
| } | } | ||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "Start a cut node"; | MS_LOG(DEBUG) << "Start a cut node"; | ||||
| if (!(utils::isa<AnfNodePtr>(split) && utils::cast<AnfNodePtr>(split)->isa<CNode>())) { | |||||
| auto &cut_node = segment->nodes_[0]; | |||||
| if (!cut_node->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info()); | MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info()); | ||||
| } | } | ||||
| CNodePtr node = utils::cast<AnfNodePtr>(split)->cast<CNodePtr>(); | |||||
| CNodePtr node = cut_node->cast<CNodePtr>(); | |||||
| ret = InterpretNode(graph, node); | ret = InterpretNode(graph, node); | ||||
| MS_LOG(DEBUG) << "End a cut node"; | MS_LOG(DEBUG) << "End a cut node"; | ||||
| if (ret == RET_BREAK) { | if (ret == RET_BREAK) { | ||||
| @@ -635,7 +259,7 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) { | |||||
| int64_t param_height = height_; | int64_t param_height = height_; | ||||
| MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true); | MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true); | ||||
| if (!SplitGraph(graph)) { | |||||
| if (!Compile(graph)) { | |||||
| return inst_; | return inst_; | ||||
| } | } | ||||
| @@ -897,20 +521,6 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { | |||||
| return rt; | return rt; | ||||
| } | } | ||||
| bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto graph_manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(graph_manager); | |||||
| FuncGraphSet graphs = graph_manager->func_graphs(); | |||||
| for (auto &g : graphs) { | |||||
| auto nodes = TopoSort(g->get_return()); | |||||
| if (ContainMultiTarget(nodes)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| BackendPtr CreateBackend() { | BackendPtr CreateBackend() { | ||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| @@ -31,6 +31,7 @@ | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "vm/segment_runner.h" | #include "vm/segment_runner.h" | ||||
| #include "vm/backend.h" | #include "vm/backend.h" | ||||
| #include "vm/graph_partition.h" | |||||
| // mindspore namespace is the top level namespace of MindSpore project. | // mindspore namespace is the top level namespace of MindSpore project. | ||||
| // Other namespace should be a sub namespace of mindspore namespace in the ME project. | // Other namespace should be a sub namespace of mindspore namespace in the ME project. | ||||
| @@ -59,7 +60,6 @@ class CompileGraph { | |||||
| void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } | void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } | ||||
| void Ret(int64_t nargs); | void Ret(int64_t nargs); | ||||
| int64_t Ref(const AnfNodePtr &node); | int64_t Ref(const AnfNodePtr &node); | ||||
| VectorRef SplitNodes(const FuncGraphPtr &func_graph); | |||||
| void set_height(int64_t h) { | void set_height(int64_t h) { | ||||
| height_ = h; | height_ = h; | ||||
| @@ -76,10 +76,9 @@ class CompileGraph { | |||||
| } | } | ||||
| private: | private: | ||||
| VectorRef SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph); | |||||
| void PushParameters(const FuncGraphPtr &func_graph); | void PushParameters(const FuncGraphPtr &func_graph); | ||||
| bool SplitGraph(const FuncGraphPtr &func_graph); | |||||
| int64_t LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); | |||||
| bool Compile(const FuncGraphPtr &func_graph); | |||||
| int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = ""); | |||||
| int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); | int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); | ||||
| int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node); | int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node); | ||||
| void AddPadStack(int64_t param_height); | void AddPadStack(int64_t param_height); | ||||
| @@ -97,11 +96,12 @@ class CompileGraph { | |||||
| void AddInst(const Instruction &inst, const VectorRef &args); | void AddInst(const Instruction &inst, const VectorRef &args); | ||||
| BackendPtr backend_; | BackendPtr backend_; | ||||
| GraphPartitionPtr graph_partition_; | |||||
| LinkFuncType lin_convert_; | LinkFuncType lin_convert_; | ||||
| bool is_gevm_convert_; | |||||
| int64_t height_{0}; | int64_t height_{0}; | ||||
| int64_t max_height_{0}; | int64_t max_height_{0}; | ||||
| std::vector<PrimitivePtr> cut_list_; | |||||
| std::unordered_map<AnfNodePtr, int64_t> slots_; | std::unordered_map<AnfNodePtr, int64_t> slots_; | ||||
| InstSet inst_; | InstSet inst_; | ||||
| }; | }; | ||||
| @@ -123,7 +123,6 @@ class CompileGraphs { | |||||
| void Compile(const FuncGraphPtr &func_graph); | void Compile(const FuncGraphPtr &func_graph); | ||||
| FinalVMPtr Link(const FuncGraphPtr &func_graph); | FinalVMPtr Link(const FuncGraphPtr &func_graph); | ||||
| FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); | FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); | ||||
| static bool ContainMixedTarget(const FuncGraphPtr &graph); | |||||
| private: | private: | ||||
| InstSet insts_; | InstSet insts_; | ||||
| @@ -301,4 +301,20 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { | |||||
| } | } | ||||
| return default_target; | return default_target; | ||||
| } | } | ||||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||||
| for (auto &node : nodes) { | |||||
| if (node->isa<CNode>()) { | |||||
| std::string cur_target = GetCNodeTarget(node); | |||||
| if (last_target != cur_target) { | |||||
| return true; | |||||
| } | |||||
| last_target = cur_target; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -482,6 +482,13 @@ void reset_id(); | |||||
| using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>; | using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>; | ||||
| using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>; | using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>; | ||||
| std::string GetCNodeTarget(const AnfNodePtr &node); | std::string GetCNodeTarget(const AnfNodePtr &node); | ||||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes); | |||||
| struct GraphSegment { | |||||
| GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {} | |||||
| std::vector<AnfNodePtr> nodes_; | |||||
| bool is_cut_{false}; | |||||
| }; | |||||
| using GraphSegmentPtr = std::shared_ptr<GraphSegment>; | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_IR_ANF_H_ | #endif // MINDSPORE_CORE_IR_ANF_H_ | ||||
| @@ -647,6 +647,19 @@ ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) { | |||||
| return parameter; | return parameter; | ||||
| } | } | ||||
| bool FuncGraph::ContainMultiTarget() const { | |||||
| auto graph_manager = manager(); | |||||
| MS_EXCEPTION_IF_NULL(graph_manager); | |||||
| FuncGraphSet graphs = graph_manager->func_graphs(); | |||||
| for (auto &g : graphs) { | |||||
| auto nodes = TopoSort(g->get_return()); | |||||
| if (mindspore::ContainMultiTarget(nodes)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| size_t NewFgSeenGeneration() { | size_t NewFgSeenGeneration() { | ||||
| static size_t fg_seen_generation = 0; | static size_t fg_seen_generation = 0; | ||||
| return ++fg_seen_generation; | return ++fg_seen_generation; | ||||
| @@ -354,6 +354,7 @@ class FuncGraph : public FuncGraphBase { | |||||
| static void set_drawer(Drawer drawer) { drawer_ = drawer; } | static void set_drawer(Drawer drawer) { drawer_ = drawer; } | ||||
| std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; } | std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; } | ||||
| void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; } | void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; } | ||||
| bool ContainMultiTarget() const; | |||||
| private: | private: | ||||
| // graph is manipulated by manager and others | // graph is manipulated by manager and others | ||||
| @@ -52,21 +52,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) { | |||||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); | std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); | ||||
| BackendPtr b = std::make_shared<Backend>("vm"); | BackendPtr b = std::make_shared<Backend>("vm"); | ||||
| CompileGraph transform_(b); | |||||
| auto splits = transform_.SplitNodes(g); | |||||
| auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name()); | |||||
| auto segments = graph_partition->Partition(g); | |||||
| VectorRef args({1.0, 2.0}); | VectorRef args({1.0, 2.0}); | ||||
| std::vector<BaseRef> todos(splits.size()); | |||||
| auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), | |||||
| [](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||||
| todos.resize(std::distance(todos.begin(), it)); | |||||
| ASSERT_EQ(todos.size(), 1); | |||||
| AnfNodePtrList anf_list; | |||||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||||
| } | |||||
| auto convertResult = MsVmConvert(anf_list, ""); | |||||
| auto convertResult = MsVmConvert(segments[0], ""); | |||||
| auto runResult = (*(convertResult.run))(args); | auto runResult = (*(convertResult.run))(args); | ||||
| ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0); | ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0); | ||||
| } | } | ||||
| @@ -76,21 +66,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) { | |||||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); | std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); | ||||
| BackendPtr b = std::make_shared<Backend>("vm"); | BackendPtr b = std::make_shared<Backend>("vm"); | ||||
| CompileGraph transform_(b); | |||||
| auto splits = transform_.SplitNodes(g); | |||||
| auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name()); | |||||
| auto segments = graph_partition->Partition(g); | |||||
| VectorRef args({1.0, 2.0}); | VectorRef args({1.0, 2.0}); | ||||
| std::vector<BaseRef> todos(splits.size()); | |||||
| auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), | |||||
| [](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||||
| todos.resize(std::distance(todos.begin(), it)); | |||||
| ASSERT_EQ(todos.size(), 1); | |||||
| AnfNodePtrList anf_list; | |||||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||||
| } | |||||
| auto convertResult = MsVmConvert(anf_list, ""); | |||||
| auto convertResult = MsVmConvert(segments[0], ""); | |||||
| auto runResult = (*(convertResult.run))(args); | auto runResult = (*(convertResult.run))(args); | ||||
| ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0); | ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0); | ||||
| } | } | ||||
| @@ -100,21 +80,11 @@ TEST_F(TestCompileSegmentRunner, test_if) { | |||||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); | std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); | ||||
| BackendPtr b = std::make_shared<Backend>("vm"); | BackendPtr b = std::make_shared<Backend>("vm"); | ||||
| CompileGraph transform_(b); | |||||
| auto splits = transform_.SplitNodes(g); | |||||
| auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name()); | |||||
| auto segments = graph_partition->Partition(g); | |||||
| VectorRef args({1.0, 2.0}); | VectorRef args({1.0, 2.0}); | ||||
| std::vector<BaseRef> todos(splits.size()); | |||||
| auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), | |||||
| [](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||||
| todos.resize(std::distance(todos.begin(), it)); | |||||
| ASSERT_EQ(todos.size(), 1); | |||||
| AnfNodePtrList anf_list; | |||||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||||
| } | |||||
| auto convertResult = MsVmConvert(anf_list, ""); | |||||
| auto convertResult = MsVmConvert(segments[0], ""); | |||||
| auto runResult = (*(convertResult.run))(args); | auto runResult = (*(convertResult.run))(args); | ||||
| auto result = py::cast<bool>(BaseRefToPyData(runResult[0])); | auto result = py::cast<bool>(BaseRefToPyData(runResult[0])); | ||||