| @@ -89,7 +89,8 @@ bool TensorInVector(const VectorRef *outputs) { | |||
| void CompileNodesTask::Run() { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| graph_id_ = session_->CompileGraphImpl(nodes_, output_nodes_); | |||
| MS_EXCEPTION_IF_NULL(segment_); | |||
| graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_); | |||
| } | |||
| void CompileGraphTask::Run() { | |||
| @@ -226,10 +227,11 @@ void Executor::SyncRunTask(const std::shared_ptr<Task> &task) { | |||
| MsException::GetInstance().CheckException(); | |||
| } | |||
| GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, | |||
| const AnfNodePtrList &outputs) { | |||
| auto task = std::make_shared<CompileNodesTask>(); | |||
| task->session_ = session; | |||
| task->nodes_ = lst; | |||
| task->segment_ = segment; | |||
| task->output_nodes_ = outputs; | |||
| SyncRunTask(task); | |||
| return task->graph_id_; | |||
| @@ -63,7 +63,7 @@ class CompileNodesTask : public Task { | |||
| CompileNodesTask() { type_ = kCompileNodes; } | |||
| ~CompileNodesTask() override = default; | |||
| void Run() override; | |||
| AnfNodePtrList nodes_; | |||
| GraphSegmentPtr segment_; | |||
| AnfNodePtrList output_nodes_; | |||
| GraphId graph_id_{0}; | |||
| }; | |||
| @@ -151,7 +151,7 @@ class Executor { | |||
| ~Executor(); | |||
| void WorkerLoop(); | |||
| void WorkerJoin(); | |||
| GraphId CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | |||
| GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); | |||
| GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph); | |||
| void BuildGraph(const SessionPtr &session, GraphId graphId); | |||
| void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| @@ -1388,9 +1388,9 @@ AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::ve | |||
| return nullptr; | |||
| } | |||
| GraphId SessionBasic::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| GraphId SessionBasic::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) { | |||
| MS_EXCEPTION_IF_NULL(executor_); | |||
| return executor_->CompileGraph(shared_from_this(), lst, outputs); | |||
| return executor_->CompileGraph(shared_from_this(), segment, outputs); | |||
| } | |||
| GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| @@ -68,7 +68,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| virtual ~SessionBasic() { summary_callback_ = nullptr; } | |||
| GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); | |||
| GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); | |||
| GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph); | |||
| void BuildGraph(GraphId graphId); | |||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | |||
| @@ -102,6 +102,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {} | |||
| std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id, | |||
| const std::vector<tensor::TensorPtr> &inputs); | |||
| // Get graph by graph id, if not exist return null ptr | |||
| KernelGraphPtr GetGraph(GraphId graph_id) const; | |||
| #ifdef ENABLE_DEBUGGER | |||
| // set debugger | |||
| void SetDebugger() { | |||
| @@ -147,8 +149,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {} | |||
| void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs); | |||
| // Get graph by graph id ,if not exist return null ptr | |||
| KernelGraphPtr GetGraph(GraphId graph_id) const; | |||
| virtual void SetSummaryNodes(KernelGraph *graph); | |||
| @@ -354,7 +354,7 @@ bool TaskEmitAction(const ResourcePtr &res) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| std::string backend = MsContext::GetInstance()->backend_policy(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (CompileGraphs::ContainMixedTarget(func_graph)) { | |||
| if (func_graph->ContainMultiTarget()) { | |||
| bc_ptr->set_is_multi_graph_sink(false); | |||
| context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false); | |||
| context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false); | |||
| @@ -923,7 +923,8 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| MS_EXCEPTION_IF_NULL(convert_fn); | |||
| // Convert CNodeList to LinConvertResult. | |||
| ConfigManager::GetInstance().set_iter_num(1); | |||
| auto runner = convert_fn({app_init}, ""); | |||
| auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false); | |||
| auto runner = convert_fn(segment, ""); | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| backend->Link(runner.graph_id); | |||
| } | |||
| @@ -34,30 +34,34 @@ namespace compile { | |||
| bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } | |||
| bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); } | |||
| LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) { | |||
| Backend::Backend(const std::string &name) : name_(name) { | |||
| MS_LOG(DEBUG) << "select backend:" << name; | |||
| convert_fn_ = MsVmConvert; | |||
| is_multi_graph_sink_ = false; | |||
| } | |||
| LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) { | |||
| MS_LOG(DEBUG) << "MsConvert"; | |||
| MS_EXCEPTION_IF_NULL(segment); | |||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | |||
| auto cached = g_ConvertCache.find(lst); | |||
| auto cached = g_ConvertCache.find(segment); | |||
| if (cached != g_ConvertCache.end()) { | |||
| return cached->second; | |||
| } | |||
| LinConvertResult result; | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst); | |||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); | |||
| result.inputs = inputs; | |||
| result.outputs = outputs; | |||
| result.graph_id = kInvalidGraphId; | |||
| GraphId graph_id = kInvalidGraphId; | |||
| if (target != target_device_ && !target.empty()) { | |||
| CreateOtherSession(target); | |||
| graph_id = other_sess_->CompileGraph(lst, outputs); | |||
| graph_id = other_sess_->CompileGraph(segment, outputs); | |||
| } else { | |||
| graph_id = target_sess_->CompileGraph(lst, outputs); | |||
| graph_id = target_sess_->CompileGraph(segment, outputs); | |||
| } | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) { | |||
| @@ -79,7 +83,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri | |||
| result.graph_id = graph_id; | |||
| graph_id_map_[graph_id] = result; | |||
| (void)g_ConvertCache.emplace(lst, result); | |||
| (void)g_ConvertCache.emplace(segment, result); | |||
| return result; | |||
| } | |||
| @@ -154,12 +158,6 @@ void MsBackend::Link(GraphId graph_id) { | |||
| target_sess_->BuildGraph(graph_id); | |||
| } | |||
| Backend::Backend(const std::string &name) : name_(name) { | |||
| MS_LOG(DEBUG) << "select backend:" << name; | |||
| convert_fn_ = backends[name_]; | |||
| is_multi_graph_sink_ = false; | |||
| } | |||
| MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { | |||
| convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2); | |||
| target_sess_ = session::SessionFactory::Get().Create(target); | |||
| @@ -194,6 +192,5 @@ VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return | |||
| #ifdef ENABLE_DEBUGGER | |||
| void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } | |||
| #endif | |||
| } // namespace compile | |||
| } // namespace mindspore | |||
| @@ -25,6 +25,7 @@ | |||
| #include "utils/contract.h" | |||
| #include "ir/anf.h" | |||
| #include "vm/segment_runner.h" | |||
| #include "vm/graph_partition.h" | |||
| #include "vm/vm.h" | |||
| #include "backend/session/session_basic.h" | |||
| @@ -63,7 +64,7 @@ class MsBackend : public Backend { | |||
| MsBackend(const std::string &name, const std::string &target, uint32_t device_id); | |||
| ~MsBackend() override = default; | |||
| LinConvertResult MsConvert(const AnfNodePtrList &lst, const std::string &target = ""); | |||
| LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = ""); | |||
| VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); | |||
| VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args); | |||
| @@ -0,0 +1,447 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "vm/graph_partition.h" | |||
| #include <string> | |||
| #include <functional> | |||
| #include <utility> | |||
| #include <map> | |||
| #include <queue> | |||
| #include <stack> | |||
| #include <set> | |||
| #include "base/core_ops.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| const char kMsConvert[] = "ms"; | |||
| const char kMsVm[] = "vm"; | |||
| const char kGeVm[] = "ge"; | |||
| namespace compile { | |||
| namespace { | |||
| bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, | |||
| std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) { | |||
| MS_EXCEPTION_IF_NULL(prior_node); | |||
| MS_EXCEPTION_IF_NULL(behind_node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| if (prior_node->isa<Parameter>()) { | |||
| for (auto &user : node_users[prior_node]) { | |||
| auto cnode = user.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| prior_nodes->emplace_back(cnode); | |||
| } | |||
| } | |||
| } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { | |||
| prior_nodes->emplace_back(prior_node); | |||
| } else { | |||
| return false; | |||
| } | |||
| if (behind_node->isa<Parameter>()) { | |||
| for (auto &user : node_users[behind_node]) { | |||
| auto cnode = user.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| depend_nodes->emplace_back(cnode); | |||
| } | |||
| } | |||
| } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { | |||
| depend_nodes->emplace_back(behind_node); | |||
| } else { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges, | |||
| std::map<AnfNodePtr, size_t> *nodes_ref) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto input_cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||
| auto prior_node = input_cnode->input(kControlDependPriorIndex); | |||
| auto depend_node = input_cnode->input(kControlDependBehindIndex); | |||
| MS_EXCEPTION_IF_NULL(prior_node); | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim_ptr); | |||
| ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); | |||
| int64_t depend_mode = 0; | |||
| if (mode_ptr != nullptr) { | |||
| depend_mode = GetValue<int64_t>(mode_ptr); | |||
| } | |||
| if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) { | |||
| return; | |||
| } | |||
| std::vector<AnfNodePtr> prior_nodes; | |||
| std::vector<AnfNodePtr> behind_nodes; | |||
| if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { | |||
| return; | |||
| } | |||
| for (auto &first_node : prior_nodes) { | |||
| for (auto &second_node : behind_nodes) { | |||
| MS_EXCEPTION_IF_NULL(first_node); | |||
| MS_EXCEPTION_IF_NULL(second_node); | |||
| auto iter = control_edges->find(second_node); | |||
| if (iter == control_edges->end()) { | |||
| (void)control_edges->insert( | |||
| std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node})); | |||
| } else { | |||
| iter->second.emplace_back(first_node); | |||
| } | |||
| auto ref_iter = nodes_ref->find(first_node); | |||
| if (ref_iter != nodes_ref->end()) { | |||
| ref_iter->second++; | |||
| } else { | |||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref, | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) { | |||
| std::queue<AnfNodePtr> queue; | |||
| queue.push(graph->get_return()); | |||
| std::set<AnfNodePtr> visited; | |||
| while (!queue.empty()) { | |||
| auto &node = queue.front(); | |||
| queue.pop(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (auto &input : cnode->inputs()) { | |||
| if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { | |||
| AddControlEdge(graph, input, control_edges, nodes_ref); | |||
| } | |||
| auto iter = nodes_ref->find(input); | |||
| if (iter != nodes_ref->end()) { | |||
| iter->second++; | |||
| } else { | |||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1)); | |||
| } | |||
| if (visited.find(input) != visited.end()) { | |||
| continue; | |||
| } | |||
| visited.insert(input); | |||
| queue.push(input); | |||
| } | |||
| } | |||
| } | |||
| std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) { | |||
| std::vector<AnfNodePtr> result; | |||
| std::map<size_t, std::vector<AnfNodePtr>> insert_positions; | |||
| std::map<AnfNodePtr, size_t> node_positions; | |||
| for (auto &node : nodes) { | |||
| if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &inputs = cnode->inputs(); | |||
| if (inputs.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Invalid get item node"; | |||
| } | |||
| auto &parent = inputs[1]; | |||
| auto iter = node_positions.find(parent); | |||
| if (iter != node_positions.end()) { | |||
| size_t position = iter->second; | |||
| auto iter_nodes = insert_positions.find(position); | |||
| if (iter_nodes != insert_positions.end()) { | |||
| iter_nodes->second.push_back(node); | |||
| } else { | |||
| (void)insert_positions.insert( | |||
| std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node})); | |||
| } | |||
| continue; | |||
| } | |||
| } | |||
| result.emplace_back(node); | |||
| node_positions[node] = result.size(); | |||
| } | |||
| size_t insert_num = 0; | |||
| for (auto &item : insert_positions) { | |||
| size_t position = item.first + insert_num; | |||
| (void)result.insert(result.begin() + position, item.second.begin(), item.second.end()); | |||
| insert_num += item.second.size(); | |||
| } | |||
| return result; | |||
| } | |||
| std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { | |||
| std::vector<AnfNodePtr> result; | |||
| std::stack<AnfNodePtr> to_visit; | |||
| std::stack<AnfNodePtr> next_to_visit; | |||
| std::map<AnfNodePtr, size_t> nodes_ref; | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges; | |||
| CalcNodeRefCount(graph, &nodes_ref, &control_edges); | |||
| std::string handle_target = default_target; | |||
| std::string next_target = ""; | |||
| to_visit.push(graph->get_return()); | |||
| while (!to_visit.empty() || !next_to_visit.empty()) { | |||
| if (to_visit.empty()) { | |||
| to_visit.swap(next_to_visit); | |||
| handle_target = next_target; | |||
| } | |||
| auto &node = to_visit.top(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| to_visit.pop(); | |||
| result.emplace_back(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto node_inputs = cnode->inputs(); | |||
| std::reverse(node_inputs.begin(), node_inputs.end()); | |||
| auto ctrl_inputs = control_edges.find(node); | |||
| if (ctrl_inputs != control_edges.end()) { | |||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | |||
| } | |||
| for (auto &input : node_inputs) { | |||
| auto iter = nodes_ref.find(input); | |||
| if (iter != nodes_ref.end()) { | |||
| iter->second--; | |||
| if (iter->second != 0) { | |||
| continue; | |||
| } | |||
| } | |||
| if (!input->isa<CNode>()) { | |||
| to_visit.push(input); | |||
| continue; | |||
| } | |||
| std::string input_target = GetCNodeTarget(input); | |||
| if (input_target == handle_target) { | |||
| to_visit.push(input); | |||
| } else if (next_to_visit.empty() || input_target == next_target) { | |||
| next_to_visit.push(input); | |||
| next_target = input_target; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Only support two different target"; | |||
| } | |||
| } | |||
| } | |||
| std::reverse(result.begin(), result.end()); | |||
| return result; | |||
| } | |||
| std::vector<AnfNodePtr> ParallelSplitSort(const FuncGraphPtr &graph, const std::string &default_target) { | |||
| std::vector<AnfNodePtr> result; | |||
| std::stack<AnfNodePtr> handle_nodes; | |||
| std::stack<AnfNodePtr> next_handle_nodes; | |||
| std::stack<AnfNodePtr> wait_handle_nodes; | |||
| std::map<AnfNodePtr, size_t> nodes_ref; | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges; | |||
| CalcNodeRefCount(graph, &nodes_ref, &control_edges); | |||
| std::string handle_target = default_target; | |||
| std::string next_target = ""; | |||
| handle_nodes.push(graph->get_return()); | |||
| while (!handle_nodes.empty() || !next_handle_nodes.empty() || !wait_handle_nodes.empty()) { | |||
| if (handle_nodes.empty()) { | |||
| handle_nodes.swap(next_handle_nodes); | |||
| handle_target.swap(next_target); | |||
| next_handle_nodes.swap(wait_handle_nodes); | |||
| } | |||
| auto &node = handle_nodes.top(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| handle_nodes.pop(); | |||
| result.emplace_back(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto node_inputs = cnode->inputs(); | |||
| std::reverse(node_inputs.begin(), node_inputs.end()); | |||
| auto ctrl_inputs = control_edges.find(node); | |||
| if (ctrl_inputs != control_edges.end()) { | |||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | |||
| } | |||
| std::vector<AnfNodePtr> same_target_ready_inputs; | |||
| std::vector<AnfNodePtr> diff_target_ready_inputs; | |||
| for (auto &input : node_inputs) { | |||
| auto iter = nodes_ref.find(input); | |||
| if (iter != nodes_ref.end()) { | |||
| iter->second--; | |||
| if (iter->second != 0) { | |||
| continue; | |||
| } | |||
| } | |||
| if (!input->isa<CNode>()) { | |||
| handle_nodes.push(input); | |||
| continue; | |||
| } | |||
| std::string input_target = GetCNodeTarget(input); | |||
| if (input_target == handle_target) { | |||
| same_target_ready_inputs.emplace_back(input); | |||
| } else { | |||
| if (next_target.empty()) { | |||
| next_target = input_target; | |||
| } | |||
| if (input_target != next_target) { | |||
| MS_LOG(EXCEPTION) << "Only support two different target"; | |||
| } | |||
| diff_target_ready_inputs.emplace_back(input); | |||
| } | |||
| } | |||
| if (diff_target_ready_inputs.size() == 0) { | |||
| for (auto &input : same_target_ready_inputs) { | |||
| handle_nodes.push(input); | |||
| } | |||
| } else { | |||
| for (auto &input : same_target_ready_inputs) { | |||
| wait_handle_nodes.push(input); | |||
| } | |||
| for (auto &input : diff_target_ready_inputs) { | |||
| next_handle_nodes.push(input); | |||
| } | |||
| } | |||
| } | |||
| std::reverse(result.begin(), result.end()); | |||
| return result; | |||
| } | |||
| bool IsSubGraph(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| AnfNodePtr fn = inputs[0]; | |||
| if (!IsValueNode<Primitive>(fn)) { | |||
| return false; | |||
| } | |||
| auto node_prim = GetValueNode<PrimitivePtr>(fn); | |||
| if (node_prim->name() == prim::kPrimPartial->name()) { | |||
| return true; | |||
| } | |||
| } else if (IsValueNode<FuncGraph>(node)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| GraphPartition::GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name) | |||
| : cut_list_(cut_list), backend_name_(backend_name) {} | |||
| bool GraphPartition::IsCut(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| AnfNodePtr fn = inputs[0]; | |||
| if (IsValueNode<FuncGraph>(fn)) { | |||
| auto fg = GetValueNode<FuncGraphPtr>(fn); | |||
| if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| return false; | |||
| } | |||
| } | |||
| if (!IsValueNode<Primitive>(fn)) { | |||
| return true; | |||
| } | |||
| PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn); | |||
| for (auto &prim : cut_list_) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim->name() == node_prim->name()) { | |||
| if (prim->name() == prim::kPrimBpropCut->name()) { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true); | |||
| } | |||
| if (backend_name_ == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { | |||
| if (inputs.size() < 2) { | |||
| return false; | |||
| } | |||
| auto ret = IsSubGraph(inputs[1]); | |||
| return ret; | |||
| } | |||
| return true; | |||
| } | |||
| } | |||
| #ifdef ENABLE_GE | |||
| if (backend_name_ == kGeVm) { | |||
| auto name = GetCNodeFuncName(cnode); | |||
| auto adpt = transform::DfGraphConvertor::FindAdapter(name); | |||
| if (adpt == nullptr) { | |||
| return true; | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| return false; | |||
| } | |||
| std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||
| bool contain_multi_target = ContainMultiTarget(nodes); | |||
| if (contain_multi_target) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| if (graph != nullptr) { | |||
| nodes = SplitSort(graph, default_target); | |||
| } else { | |||
| nodes = ParallelSplitSort(graph, default_target); | |||
| } | |||
| nodes = OptimizeGetItemOrder(nodes); | |||
| } | |||
| std::vector<GraphSegmentPtr> segments; | |||
| std::vector<AnfNodePtr> segment_nodes; | |||
| std::string last_target; | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsCut(node)) { | |||
| if (segment_nodes.size() != 0) { | |||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, false); | |||
| segments.emplace_back(segment); | |||
| segment_nodes.clear(); | |||
| } | |||
| segment_nodes.emplace_back(node); | |||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, true); | |||
| segments.push_back(segment); | |||
| segment_nodes.clear(); | |||
| } else if (node->isa<CNode>()) { | |||
| if (contain_multi_target) { | |||
| std::string cur_target = GetCNodeTarget(node); | |||
| if (cur_target != last_target && !last_target.empty() && segment_nodes.size() != 0) { | |||
| auto segment = std::make_shared<GraphSegment>(segment_nodes, false); | |||
| segments.emplace_back(segment); | |||
| segment_nodes.clear(); | |||
| } | |||
| last_target = cur_target; | |||
| } | |||
| segment_nodes.emplace_back(node); | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Segment size:" << segments.size(); | |||
| return segments; | |||
| } | |||
| } // namespace compile | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ | |||
| #define MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/graph_utils.h" | |||
| #include "base/base_ref.h" | |||
| namespace mindspore { | |||
| extern const char kMsVm[]; | |||
| extern const char kGeVm[]; | |||
| extern const char kMsConvert[]; | |||
| namespace compile { | |||
| class GraphPartition { | |||
| public: | |||
| explicit GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name); | |||
| ~GraphPartition() = default; | |||
| std::vector<GraphSegmentPtr> Partition(const FuncGraphPtr &func_graph); | |||
| private: | |||
| bool IsCut(const AnfNodePtr &node); | |||
| std::vector<PrimitivePtr> cut_list_; | |||
| std::string backend_name_; | |||
| }; | |||
| using GraphPartitionPtr = std::shared_ptr<GraphPartition>; | |||
| } // namespace compile | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ | |||
| @@ -34,10 +34,6 @@ | |||
| #include "frontend/operator/ops.h" | |||
| namespace mindspore { | |||
| const char kMsConvert[] = "ms"; | |||
| const char kMsVm[] = "vm"; | |||
| const char kGeVm[] = "ge"; | |||
| namespace compile { | |||
| // cached conversion | |||
| ConvertCache g_ConvertCache; | |||
| @@ -194,8 +190,9 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr | |||
| // This implementation will convert the nodes into a subgraph | |||
| // that will run using the MsVM. | |||
| template <typename T> | |||
| LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { | |||
| auto cached = g_ConvertCache.find(lst); | |||
| LinConvertResult Convert(const GraphSegmentPtr &segment, const std::string &) { | |||
| MS_EXCEPTION_IF_NULL(segment); | |||
| auto cached = g_ConvertCache.find(segment); | |||
| if (cached != g_ConvertCache.end()) { | |||
| return cached->second; | |||
| } | |||
| @@ -206,7 +203,7 @@ LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst); | |||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); | |||
| // Clone in case g contains subgraphs that have a different manager | |||
| fg = BasicClone(fg); | |||
| @@ -219,18 +216,15 @@ LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { | |||
| result.outputs = outputs; | |||
| result.graph_id = UINT32_MAX; | |||
| (void)g_ConvertCache.emplace(lst, result); | |||
| (void)g_ConvertCache.emplace(segment, result); | |||
| return result; | |||
| } | |||
| LinkFuncType MsVmConvert = Convert<VM>; | |||
| std::unordered_map<std::string, LinkFuncType> backends = {{kMsVm, MsVmConvert}}; | |||
| std::set<std::string> backend_list = { | |||
| kMsConvert, | |||
| kMsVm, | |||
| }; | |||
| } // namespace compile | |||
| } // namespace mindspore | |||
| @@ -27,14 +27,10 @@ | |||
| #include "ir/anf.h" | |||
| #include "vm/vmimpl.h" | |||
| #include "vm/graph_partition.h" | |||
| namespace mindspore { | |||
| extern const char kMsVm[]; | |||
| extern const char kGeVm[]; | |||
| extern const char kMsConvert[]; | |||
| namespace compile { | |||
| struct LinConvertResult { | |||
| RunFuncPtr run; | |||
| RunFuncPtr simu_run; | |||
| @@ -43,11 +39,9 @@ struct LinConvertResult { | |||
| uint32_t graph_id; | |||
| }; | |||
| using LinkFuncType = std::function<LinConvertResult(const AnfNodePtrList &, const std::string &)>; | |||
| using ConvertCache = std::unordered_map<BaseRef, LinConvertResult, BaseRefHash>; | |||
| using LinkFuncType = std::function<LinConvertResult(const GraphSegmentPtr &, const std::string &)>; | |||
| using ConvertCache = std::unordered_map<GraphSegmentPtr, LinConvertResult>; | |||
| extern LinkFuncType MsVmConvert; | |||
| extern LinkFuncType GeVmConvert; | |||
| extern std::unordered_map<std::string, LinkFuncType> backends; | |||
| extern ConvertCache g_ConvertCache; | |||
| extern std::set<std::string> backend_list; | |||
| @@ -21,8 +21,6 @@ | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <queue> | |||
| #include <stack> | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -52,386 +50,13 @@ const std::vector<PrimitivePtr> &GetMsNonlinearOps() { | |||
| return ms_nonlinear_ops; | |||
| } | |||
| namespace { | |||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| for (auto &node : nodes) { | |||
| if (node->isa<CNode>()) { | |||
| std::string cur_target = GetCNodeTarget(node); | |||
| if (last_target != cur_target) { | |||
| return true; | |||
| } | |||
| last_target = cur_target; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, | |||
| std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) { | |||
| MS_EXCEPTION_IF_NULL(prior_node); | |||
| MS_EXCEPTION_IF_NULL(behind_node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| if (prior_node->isa<Parameter>()) { | |||
| for (auto &user : node_users[prior_node]) { | |||
| auto cnode = user.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| prior_nodes->emplace_back(cnode); | |||
| } | |||
| } | |||
| } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { | |||
| prior_nodes->emplace_back(prior_node); | |||
| } else { | |||
| return false; | |||
| } | |||
| if (behind_node->isa<Parameter>()) { | |||
| for (auto &user : node_users[behind_node]) { | |||
| auto cnode = user.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| depend_nodes->emplace_back(cnode); | |||
| } | |||
| } | |||
| } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { | |||
| depend_nodes->emplace_back(behind_node); | |||
| } else { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges, | |||
| std::map<AnfNodePtr, size_t> *nodes_ref) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto input_cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(input_cnode); | |||
| auto prior_node = input_cnode->input(kControlDependPriorIndex); | |||
| auto depend_node = input_cnode->input(kControlDependBehindIndex); | |||
| MS_EXCEPTION_IF_NULL(prior_node); | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim_ptr); | |||
| ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); | |||
| int64_t depend_mode = 0; | |||
| if (mode_ptr != nullptr) { | |||
| depend_mode = GetValue<int64_t>(mode_ptr); | |||
| } | |||
| if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) { | |||
| return; | |||
| } | |||
| std::vector<AnfNodePtr> prior_nodes; | |||
| std::vector<AnfNodePtr> behind_nodes; | |||
| if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { | |||
| return; | |||
| } | |||
| for (auto &first_node : prior_nodes) { | |||
| for (auto &second_node : behind_nodes) { | |||
| MS_EXCEPTION_IF_NULL(first_node); | |||
| MS_EXCEPTION_IF_NULL(second_node); | |||
| auto iter = control_edges->find(second_node); | |||
| if (iter == control_edges->end()) { | |||
| (void)control_edges->insert( | |||
| std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node})); | |||
| } else { | |||
| iter->second.emplace_back(first_node); | |||
| } | |||
| auto ref_iter = nodes_ref->find(first_node); | |||
| if (ref_iter != nodes_ref->end()) { | |||
| ref_iter->second++; | |||
| } else { | |||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref, | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) { | |||
| std::queue<AnfNodePtr> queue; | |||
| queue.push(graph->get_return()); | |||
| std::set<AnfNodePtr> visited; | |||
| while (!queue.empty()) { | |||
| auto &node = queue.front(); | |||
| queue.pop(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| for (auto &input : cnode->inputs()) { | |||
| if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { | |||
| AddControlEdge(graph, input, control_edges, nodes_ref); | |||
| } | |||
| auto iter = nodes_ref->find(input); | |||
| if (iter != nodes_ref->end()) { | |||
| iter->second++; | |||
| } else { | |||
| (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(input, 1)); | |||
| } | |||
| if (visited.find(input) != visited.end()) { | |||
| continue; | |||
| } | |||
| visited.insert(input); | |||
| queue.push(input); | |||
| } | |||
| } | |||
| } | |||
| std::vector<AnfNodePtr> OptimizeGetItemOrder(const std::vector<AnfNodePtr> &nodes) { | |||
| std::vector<AnfNodePtr> result; | |||
| std::map<size_t, std::vector<AnfNodePtr>> insert_positions; | |||
| std::map<AnfNodePtr, size_t> node_positions; | |||
| for (auto &node : nodes) { | |||
| if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &inputs = cnode->inputs(); | |||
| if (inputs.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Invalid get item node"; | |||
| } | |||
| auto &parent = inputs[1]; | |||
| auto iter = node_positions.find(parent); | |||
| if (iter != node_positions.end()) { | |||
| size_t position = iter->second; | |||
| auto iter_nodes = insert_positions.find(position); | |||
| if (iter_nodes != insert_positions.end()) { | |||
| iter_nodes->second.push_back(node); | |||
| } else { | |||
| (void)insert_positions.insert( | |||
| std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node})); | |||
| } | |||
| continue; | |||
| } | |||
| } | |||
| result.emplace_back(node); | |||
| node_positions[node] = result.size(); | |||
| } | |||
| size_t insert_num = 0; | |||
| for (auto &item : insert_positions) { | |||
| size_t position = item.first + insert_num; | |||
| (void)result.insert(result.begin() + position, item.second.begin(), item.second.end()); | |||
| insert_num += item.second.size(); | |||
| } | |||
| return result; | |||
| } | |||
| std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { | |||
| std::vector<AnfNodePtr> result; | |||
| std::stack<AnfNodePtr> to_visit; | |||
| std::stack<AnfNodePtr> next_to_visit; | |||
| std::map<AnfNodePtr, size_t> nodes_ref; | |||
| std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges; | |||
| CalcNodeRefCount(graph, &nodes_ref, &control_edges); | |||
| std::string handle_target = default_target; | |||
| std::string next_target = ""; | |||
| to_visit.push(graph->get_return()); | |||
| while (!to_visit.empty() || !next_to_visit.empty()) { | |||
| if (to_visit.empty()) { | |||
| to_visit.swap(next_to_visit); | |||
| handle_target = next_target; | |||
| } | |||
| auto &node = to_visit.top(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| to_visit.pop(); | |||
| result.emplace_back(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto node_inputs = cnode->inputs(); | |||
| std::reverse(node_inputs.begin(), node_inputs.end()); | |||
| auto ctrl_inputs = control_edges.find(node); | |||
| if (ctrl_inputs != control_edges.end()) { | |||
| node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); | |||
| } | |||
| for (auto &input : node_inputs) { | |||
| auto iter = nodes_ref.find(input); | |||
| if (iter != nodes_ref.end()) { | |||
| iter->second--; | |||
| if (iter->second != 0) { | |||
| continue; | |||
| } | |||
| } | |||
| if (!input->isa<CNode>()) { | |||
| to_visit.push(input); | |||
| continue; | |||
| } | |||
| std::string input_target = GetCNodeTarget(input); | |||
| if (input_target == handle_target) { | |||
| to_visit.push(input); | |||
| } else if (next_to_visit.empty() || input_target == next_target) { | |||
| next_to_visit.push(input); | |||
| next_target = input_target; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "only support two different target"; | |||
| } | |||
| } | |||
| } | |||
| std::reverse(result.begin(), result.end()); | |||
| return result; | |||
| } | |||
| bool IsSubGraph(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| AnfNodePtr fn = inputs[0]; | |||
| if (!IsValueNode<Primitive>(fn)) { | |||
| return false; | |||
| } | |||
| auto node_prim = GetValueNode<PrimitivePtr>(fn); | |||
| if (node_prim->name() == prim::kPrimPartial->name()) { | |||
| return true; | |||
| } | |||
| } else if (IsValueNode<FuncGraph>(node)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) | |||
| : backend_(backend), cut_list_(cut_list) { | |||
| CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) { | |||
| MS_EXCEPTION_IF_NULL(backend_); | |||
| lin_convert_ = backend_->convert_fn(); | |||
| if (lin_convert_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name(); | |||
| } | |||
| is_gevm_convert_ = false; | |||
| if (backend->name() == kGeVm) { | |||
| MS_LOG(INFO) << "Attribute 'is_gevm_convert' is true"; | |||
| is_gevm_convert_ = true; | |||
| } | |||
| } | |||
| bool CompileGraph::IsCut(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| AnfNodePtr fn = inputs[0]; | |||
| if (IsValueNode<FuncGraph>(fn)) { | |||
| auto fg = GetValueNode<FuncGraphPtr>(fn); | |||
| if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| return false; | |||
| } | |||
| } | |||
| if (!IsValueNode<Primitive>(fn)) { | |||
| return true; | |||
| } | |||
| PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn); | |||
| for (auto &prim : cut_list_) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim->name() == node_prim->name()) { | |||
| if (prim->name() == prim::kPrimBpropCut->name()) { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true); | |||
| } | |||
| if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { | |||
| if (inputs.size() < 2) { | |||
| return false; | |||
| } | |||
| auto ret = IsSubGraph(inputs[1]); | |||
| return ret; | |||
| } | |||
| return true; | |||
| } | |||
| } | |||
| #ifdef ENABLE_GE | |||
| if (is_gevm_convert_) { | |||
| auto name = GetCNodeFuncName(cnode); | |||
| auto adpt = transform::DfGraphConvertor::FindAdapter(name); | |||
| if (adpt == nullptr) { | |||
| return true; | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| return false; | |||
| } | |||
| VectorRef CompileGraph::SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto nodes = OptimizeGetItemOrder(input_nodes); | |||
| VectorRef splits; | |||
| VectorRef split; | |||
| std::string last_target; | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsCut(node)) { | |||
| if (split.size() != 0) { | |||
| splits.push_back(split); | |||
| } | |||
| splits.push_back(node); | |||
| split.clear(); | |||
| } else if (node->isa<CNode>()) { | |||
| std::string cur_target = GetCNodeTarget(node); | |||
| if (cur_target != last_target && !last_target.empty() && split.size() != 0) { | |||
| splits.push_back(split); | |||
| split.clear(); | |||
| } | |||
| last_target = cur_target; | |||
| split.push_back(node); | |||
| } | |||
| } | |||
| return splits; | |||
| } | |||
| VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); | |||
| if (ContainMultiTarget(nodes)) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| nodes = SplitSort(graph, default_target); | |||
| return SplitNodesWithTarget(nodes, graph); | |||
| } | |||
| VectorRef splits; | |||
| VectorRef split; | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsCut(node)) { | |||
| if (split.size() != 0) { | |||
| splits.push_back(split); | |||
| } | |||
| splits.push_back(node); | |||
| split.clear(); | |||
| } else if (node->isa<CNode>()) { | |||
| split.push_back(node); | |||
| } | |||
| } | |||
| return splits; | |||
| graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name()); | |||
| } | |||
| // Push the value node on the stack. | |||
| @@ -512,12 +137,12 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) { | |||
| } | |||
| } | |||
| int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, | |||
| const std::string &target) { | |||
| int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) { | |||
| MS_EXCEPTION_IF_NULL(segment); | |||
| MS_LOG(DEBUG) << "LinConvert start"; | |||
| LinConvertResult result; | |||
| result = lin_convert_(node_list, target); | |||
| result = lin_convert_(segment, target); | |||
| if (result.run == nullptr) { | |||
| MS_LOG(ERROR) << "LinConvert failed"; | |||
| @@ -583,25 +208,23 @@ int64_t CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &n | |||
| return RET_SUCCESS; | |||
| } | |||
| bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { | |||
| bool CompileGraph::Compile(const FuncGraphPtr &graph) { | |||
| MS_LOG(DEBUG) << "Start split graph"; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| VectorRef splits = SplitNodes(graph); | |||
| MS_EXCEPTION_IF_NULL(graph_partition_); | |||
| auto segments = graph_partition_->Partition(graph); | |||
| MS_LOG(DEBUG) << "Split nodes size:" << splits.size(); | |||
| for (auto &split : splits) { | |||
| MS_LOG(DEBUG) << "Split nodes size:" << segments.size(); | |||
| for (auto &segment : segments) { | |||
| MS_EXCEPTION_IF_NULL(segment); | |||
| int64_t ret = RET_SUCCESS; | |||
| if (utils::isa<VectorRef>(split)) { | |||
| if (!segment->is_cut_) { | |||
| MS_LOG(DEBUG) << "Start a extern LinConvert"; | |||
| std::vector<AnfNodePtr> args; | |||
| auto vec_ref = utils::cast<VectorRef>(split); | |||
| (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args), | |||
| [](const BaseRef &v) { return utils::cast<AnfNodePtr>(v); }); | |||
| if (args.size() > 0) { | |||
| std::string cur_target = GetCNodeTarget(args[0]); | |||
| ret = LinConvert(graph, args, cur_target); | |||
| if (segment->nodes_.size() > 0) { | |||
| std::string cur_target = GetCNodeTarget(segment->nodes_[0]); | |||
| ret = LinConvert(graph, segment, cur_target); | |||
| } else { | |||
| ret = LinConvert(graph, args); | |||
| ret = LinConvert(graph, segment); | |||
| } | |||
| MS_LOG(DEBUG) << "End a extern LinConvert"; | |||
| if (ret == RET_FAILED) { | |||
| @@ -612,10 +235,11 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "Start a cut node"; | |||
| if (!(utils::isa<AnfNodePtr>(split) && utils::cast<AnfNodePtr>(split)->isa<CNode>())) { | |||
| auto &cut_node = segment->nodes_[0]; | |||
| if (!cut_node->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info()); | |||
| } | |||
| CNodePtr node = utils::cast<AnfNodePtr>(split)->cast<CNodePtr>(); | |||
| CNodePtr node = cut_node->cast<CNodePtr>(); | |||
| ret = InterpretNode(graph, node); | |||
| MS_LOG(DEBUG) << "End a cut node"; | |||
| if (ret == RET_BREAK) { | |||
| @@ -635,7 +259,7 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) { | |||
| int64_t param_height = height_; | |||
| MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true); | |||
| if (!SplitGraph(graph)) { | |||
| if (!Compile(graph)) { | |||
| return inst_; | |||
| } | |||
| @@ -897,20 +521,6 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { | |||
| return rt; | |||
| } | |||
| bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto graph_manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(graph_manager); | |||
| FuncGraphSet graphs = graph_manager->func_graphs(); | |||
| for (auto &g : graphs) { | |||
| auto nodes = TopoSort(g->get_return()); | |||
| if (ContainMultiTarget(nodes)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| BackendPtr CreateBackend() { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -31,6 +31,7 @@ | |||
| #include "frontend/operator/ops.h" | |||
| #include "vm/segment_runner.h" | |||
| #include "vm/backend.h" | |||
| #include "vm/graph_partition.h" | |||
| // mindspore namespace is the top level namespace of MindSpore project. | |||
| // Other namespace should be a sub namespace of mindspore namespace in the ME project. | |||
| @@ -59,7 +60,6 @@ class CompileGraph { | |||
| void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } | |||
| void Ret(int64_t nargs); | |||
| int64_t Ref(const AnfNodePtr &node); | |||
| VectorRef SplitNodes(const FuncGraphPtr &func_graph); | |||
| void set_height(int64_t h) { | |||
| height_ = h; | |||
| @@ -76,10 +76,9 @@ class CompileGraph { | |||
| } | |||
| private: | |||
| VectorRef SplitNodesWithTarget(const std::vector<AnfNodePtr> &input_nodes, const FuncGraphPtr &graph); | |||
| void PushParameters(const FuncGraphPtr &func_graph); | |||
| bool SplitGraph(const FuncGraphPtr &func_graph); | |||
| int64_t LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); | |||
| bool Compile(const FuncGraphPtr &func_graph); | |||
| int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = ""); | |||
| int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); | |||
| int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node); | |||
| void AddPadStack(int64_t param_height); | |||
| @@ -97,11 +96,12 @@ class CompileGraph { | |||
| void AddInst(const Instruction &inst, const VectorRef &args); | |||
| BackendPtr backend_; | |||
| GraphPartitionPtr graph_partition_; | |||
| LinkFuncType lin_convert_; | |||
| bool is_gevm_convert_; | |||
| int64_t height_{0}; | |||
| int64_t max_height_{0}; | |||
| std::vector<PrimitivePtr> cut_list_; | |||
| std::unordered_map<AnfNodePtr, int64_t> slots_; | |||
| InstSet inst_; | |||
| }; | |||
| @@ -123,7 +123,6 @@ class CompileGraphs { | |||
| void Compile(const FuncGraphPtr &func_graph); | |||
| FinalVMPtr Link(const FuncGraphPtr &func_graph); | |||
| FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); | |||
| static bool ContainMixedTarget(const FuncGraphPtr &graph); | |||
| private: | |||
| InstSet insts_; | |||
| @@ -301,4 +301,20 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { | |||
| } | |||
| return default_target; | |||
| } | |||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| for (auto &node : nodes) { | |||
| if (node->isa<CNode>()) { | |||
| std::string cur_target = GetCNodeTarget(node); | |||
| if (last_target != cur_target) { | |||
| return true; | |||
| } | |||
| last_target = cur_target; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -482,6 +482,13 @@ void reset_id(); | |||
| using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>; | |||
| using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>; | |||
| std::string GetCNodeTarget(const AnfNodePtr &node); | |||
| bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes); | |||
| struct GraphSegment { | |||
| GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {} | |||
| std::vector<AnfNodePtr> nodes_; | |||
| bool is_cut_{false}; | |||
| }; | |||
| using GraphSegmentPtr = std::shared_ptr<GraphSegment>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_ANF_H_ | |||
| @@ -647,6 +647,19 @@ ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) { | |||
| return parameter; | |||
| } | |||
| bool FuncGraph::ContainMultiTarget() const { | |||
| auto graph_manager = manager(); | |||
| MS_EXCEPTION_IF_NULL(graph_manager); | |||
| FuncGraphSet graphs = graph_manager->func_graphs(); | |||
| for (auto &g : graphs) { | |||
| auto nodes = TopoSort(g->get_return()); | |||
| if (mindspore::ContainMultiTarget(nodes)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| size_t NewFgSeenGeneration() { | |||
| static size_t fg_seen_generation = 0; | |||
| return ++fg_seen_generation; | |||
| @@ -354,6 +354,7 @@ class FuncGraph : public FuncGraphBase { | |||
| static void set_drawer(Drawer drawer) { drawer_ = drawer; } | |||
| std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; } | |||
| void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; } | |||
| bool ContainMultiTarget() const; | |||
| private: | |||
| // graph is manipulated by manager and others | |||
| @@ -52,21 +52,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) { | |||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); | |||
| BackendPtr b = std::make_shared<Backend>("vm"); | |||
| CompileGraph transform_(b); | |||
| auto splits = transform_.SplitNodes(g); | |||
| auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name()); | |||
| auto segments = graph_partition->Partition(g); | |||
| VectorRef args({1.0, 2.0}); | |||
| std::vector<BaseRef> todos(splits.size()); | |||
| auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), | |||
| [](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||
| todos.resize(std::distance(todos.begin(), it)); | |||
| ASSERT_EQ(todos.size(), 1); | |||
| AnfNodePtrList anf_list; | |||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||
| } | |||
| auto convertResult = MsVmConvert(anf_list, ""); | |||
| auto convertResult = MsVmConvert(segments[0], ""); | |||
| auto runResult = (*(convertResult.run))(args); | |||
| ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 3.0); | |||
| } | |||
| @@ -76,21 +66,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) { | |||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); | |||
| BackendPtr b = std::make_shared<Backend>("vm"); | |||
| CompileGraph transform_(b); | |||
| auto splits = transform_.SplitNodes(g); | |||
| auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name()); | |||
| auto segments = graph_partition->Partition(g); | |||
| VectorRef args({1.0, 2.0}); | |||
| std::vector<BaseRef> todos(splits.size()); | |||
| auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), | |||
| [](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||
| todos.resize(std::distance(todos.begin(), it)); | |||
| ASSERT_EQ(todos.size(), 1); | |||
| AnfNodePtrList anf_list; | |||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||
| } | |||
| auto convertResult = MsVmConvert(anf_list, ""); | |||
| auto convertResult = MsVmConvert(segments[0], ""); | |||
| auto runResult = (*(convertResult.run))(args); | |||
| ASSERT_TRUE(runResult.size() == 1 && py::cast<double>(BaseRefToPyData(runResult[0])) == 2.0); | |||
| } | |||
| @@ -100,21 +80,11 @@ TEST_F(TestCompileSegmentRunner, test_if) { | |||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(g); | |||
| BackendPtr b = std::make_shared<Backend>("vm"); | |||
| CompileGraph transform_(b); | |||
| auto splits = transform_.SplitNodes(g); | |||
| auto graph_partition = std::make_shared<GraphPartition>(nonlinear_ops, b->name()); | |||
| auto segments = graph_partition->Partition(g); | |||
| VectorRef args({1.0, 2.0}); | |||
| std::vector<BaseRef> todos(splits.size()); | |||
| auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), | |||
| [](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); }); | |||
| todos.resize(std::distance(todos.begin(), it)); | |||
| ASSERT_EQ(todos.size(), 1); | |||
| AnfNodePtrList anf_list; | |||
| for (auto &item : utils::cast<VectorRef>(todos[0])) { | |||
| anf_list.push_back(utils::cast<AnfNodePtr>(item)); | |||
| } | |||
| auto convertResult = MsVmConvert(anf_list, ""); | |||
| auto convertResult = MsVmConvert(segments[0], ""); | |||
| auto runResult = (*(convertResult.run))(args); | |||
| auto result = py::cast<bool>(BaseRefToPyData(runResult[0])); | |||