Merge pull request !3110 from zhoufeng/delete-deprecated-codestags/v0.7.0-beta
| @@ -51,26 +51,16 @@ class AscendSession : public SessionBasic { | |||
| py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors) override; | |||
| // set parameters of final graph | |||
| GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &args) override; | |||
| // set output of final graph | |||
| void SetFinalGraphOutput(const BaseRef &output) override; | |||
| // insert switch and set the relative active ops | |||
| void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; | |||
| // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter | |||
| void SetChildGraphInput(GraphId g, const VectorRef &args) override; | |||
| // get graph id in child graphs by ME front anf node pointer | |||
| GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; | |||
| // get graph id of final graph | |||
| GraphId GetFinalRunGraph() const override { return final_graph_id_; } | |||
| // insert active to graph | |||
| void SetActive(GraphId, GraphId) override; | |||
| // compile child graph when session have multiple child graphs | |||
| void CompileChildGraph(const KernelGraphPtr &child_graph); | |||
| void RecurseGetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary); | |||
| void GetSummaryNodes(KernelGraph *graph); | |||
| private: | |||
| void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary); | |||
| void SetSummaryNodes(KernelGraph *graph) override; | |||
| void InitRuntimeResource(); | |||
| void SelectKernel(const KernelGraph &kernel_graph) const; | |||
| void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| @@ -92,63 +82,21 @@ class AscendSession : public SessionBasic { | |||
| void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const; | |||
| void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index); | |||
| size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index); | |||
| size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index); | |||
| void SetFinalGraphOutput(const AnfNodePtr &node); | |||
| void SetFinalGraphOutput(const ValuePtr &value); | |||
| void SetFinalGraphOutput(const VectorRef &vec_output); | |||
| void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| // split graphs with recurse from root graph | |||
| void SplitGraphs(NotNull<KernelGraphPtr> root_graph); | |||
| void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs); | |||
| void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | |||
| static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs); | |||
| static void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | |||
| void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph); | |||
| std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, | |||
| const std::vector<CNodePtr> &list); | |||
| void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| AnfNodePtr BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, const std::vector<CNodePtr> &child_graph_list); | |||
| // merge execution order list of child graphs | |||
| void MergeGraphExecOrder(); | |||
| // insert assion op to sync data bettween different graphs | |||
| void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); | |||
| // insert mutiple assigns to graph | |||
| void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); | |||
| // insert active op to graph | |||
| void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream); | |||
| // get execute index of graph | |||
| size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph); | |||
| // handle condition graph from vm | |||
| void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id); | |||
| // insert depend to graph, used to attch control nodes to graph | |||
| void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node); | |||
| // insert depend to graph, used to attch control nodes to graph | |||
| void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node); | |||
| // set child graph parameter if front arg is a anf | |||
| void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx); | |||
| // set child graph parameter if front arg is a tensor | |||
| void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx); | |||
| // update the execution order of all child graphs | |||
| void UpdateGraphOrder(GraphId to_graph); | |||
| // handle switch when merge | |||
| void MergeSwitchCompile(); | |||
| // get graph order vector by graph id | |||
| std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id); | |||
| const std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id) const; | |||
| // get graph order type vector by graph id | |||
| std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id); | |||
| // copy output of if and else | |||
| void CopyOutputOfIf(GraphId false_graph_id); | |||
| const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const; | |||
| // check if graph cache exist | |||
| bool GraphCacheExist(const GraphInfo &graph_info) const; | |||
| // insert all assign to child graph | |||
| void InsertAllAssigns(); | |||
| // create fake output of final graph | |||
| AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); | |||
| // sync intial tensors' data to device | |||
| void SyncInitialTenosrToDevice(); | |||
| void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph); | |||
| @@ -162,16 +110,10 @@ class AscendSession : public SessionBasic { | |||
| void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const; | |||
| void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const; | |||
| // member variables | |||
| // key is final_graph_id,value is child graph execute order of final graph | |||
| std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_; | |||
| // key is final_graph_id,value is the graph types of child graphs | |||
| std::unordered_map<GraphId, std::vector<GraphType>> graph_order_types_; | |||
| // record condition graph of while | |||
| std::unordered_map<GraphId, GraphId> while_condition_graphs_; | |||
| // record all conditions | |||
| std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_; | |||
| std::unordered_map<GraphId, AnfNodePtr> condition_output_; | |||
| // share parameters | |||
| std::vector<std::tuple<AnfNodePtr, GraphId, size_t>> assigns_; | |||
| // initial tensors, these tensor will sync data to device before run graph | |||
| @@ -108,7 +108,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten | |||
| kernel_graph->set_execution_order(execution_order); | |||
| NamedSummaryOutputs summary_outputs; | |||
| if (enable_summary) { | |||
| GetSummaryNodes(kernel_graph.get()); | |||
| SetSummaryNodes(kernel_graph.get()); | |||
| summary_outputs = kernel_graph->summary_nodes(); | |||
| runtime_.IncreaseSummaryRefCount(summary_outputs); | |||
| } | |||
| @@ -217,7 +217,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList | |||
| Reorder(&execution_order); | |||
| graph->set_execution_order(execution_order); | |||
| // Get summary nodes. | |||
| GetSummaryNodes(graph.get()); | |||
| SetSummaryNodes(graph.get()); | |||
| // Remove NoOp from execution graph | |||
| opt::RemoveNopNode(graph.get()); | |||
| // Set graph manager. | |||
| @@ -898,27 +898,6 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP | |||
| std::queue<AnfNodePtr> seed_nodes; | |||
| UpdateNodeEdgeList(&seed_nodes); | |||
| } | |||
| // update graph inputs in child graph | |||
| auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(), | |||
| [&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { | |||
| return n.first == old_anf_node.get(); | |||
| }); | |||
| if (it_real_inputs != real_inputs_.end()) { | |||
| // erase old parameter in map | |||
| auto old_args = it_real_inputs->second; | |||
| real_inputs_.erase(it_real_inputs); | |||
| // insert new parameter to map | |||
| auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(), | |||
| [&new_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { | |||
| return n.first == new_anf_node.get(); | |||
| }); | |||
| if (iter != real_inputs_.end()) { | |||
| MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited."; | |||
| iter->second = old_args; | |||
| } else { | |||
| real_inputs_.emplace_back(new_anf_node, old_args); | |||
| } | |||
| } | |||
| } | |||
| void KernelGraph::UpdateExecuteKernelStreamLabel() { | |||
| @@ -953,56 +932,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi | |||
| return result; | |||
| } | |||
| void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString(); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto iter = std::find_if( | |||
| real_inputs_.begin(), real_inputs_.end(), | |||
| [¶meter](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { return n.first == parameter; }); | |||
| if (iter != real_inputs_.end()) { | |||
| auto &args = iter->second; | |||
| args.push_back(arg); | |||
| } else { | |||
| real_inputs_.emplace_back(parameter, std::vector<AnfNodePtr>(1, arg)); | |||
| } | |||
| } | |||
| void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph) { | |||
| unreuse_args_[arg] = from_graph; | |||
| } | |||
| void KernelGraph::UpdateCallRealInput() { | |||
| MS_LOG(INFO) << "Update graph id: " << graph_id_; | |||
| std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map; | |||
| for (auto &it : real_inputs_) { | |||
| auto parameter = it.first; | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| auto real_inputs = it.second; | |||
| std::vector<AnfNodePtr> new_real_inputs; | |||
| for (auto &real_input : real_inputs) { | |||
| // if real input is a call node ,find the child graph output act as the new real input | |||
| auto tmp_real_input = GetCallRealOutputs(real_input); | |||
| std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs)); | |||
| // replace the call in unreuse_args_ | |||
| auto unreuse_arg_it = unreuse_args_.find(real_input); | |||
| if (unreuse_arg_it != unreuse_args_.end()) { | |||
| auto old_graph = unreuse_arg_it->second; | |||
| for (auto new_real_input : new_real_inputs) { | |||
| // if call reference graph output is parameter, it will be allowed to reuse | |||
| if (!new_real_input->isa<Parameter>()) { | |||
| unreuse_args_[new_real_input] = old_graph; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| real_inputs_map.emplace_back(parameter, new_real_inputs); | |||
| } | |||
| real_inputs_ = real_inputs_map; | |||
| } | |||
| void KernelGraph::PrintGraphExecuteOrder() const { | |||
| MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; | |||
| for (size_t i = 0; i < execution_order_.size(); i++) { | |||
| @@ -131,16 +131,8 @@ class KernelGraph : public FuncGraph { | |||
| void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; } | |||
| // find anf node in graph | |||
| std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; | |||
| // get real inputs | |||
| const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; } | |||
| void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); | |||
| // mark unreused args | |||
| void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph); | |||
| const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; } | |||
| // used to dump ir | |||
| std::string ToString() const override; | |||
| // update the real input if the node is a call | |||
| void UpdateCallRealInput(); | |||
| void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } | |||
| CNodePtr get_start_label() { return start_label_; } | |||
| @@ -212,9 +204,6 @@ class KernelGraph : public FuncGraph { | |||
| // valid inputs | |||
| std::vector<bool> valid_inputs_; | |||
| // new members for control sink process | |||
| // all child grahs refers to partial node | |||
| std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_; | |||
| // child graph execute order in root graph | |||
| std::vector<std::shared_ptr<KernelGraph>> child_graph_order_; | |||
| @@ -223,9 +212,6 @@ class KernelGraph : public FuncGraph { | |||
| // parameter graph | |||
| std::shared_ptr<KernelGraph> parent_graph_; | |||
| // record real parameters,inputs_ is the formal parameters | |||
| std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_; | |||
| std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_; | |||
| CNodePtr start_label_; | |||
| CNodePtr end_goto_; | |||
| @@ -890,7 +890,7 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { | |||
| void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } | |||
| void SessionBasic::GetSummaryNodes(KernelGraph *graph) { | |||
| void SessionBasic::SetSummaryNodes(KernelGraph *graph) { | |||
| MS_LOG(DEBUG) << "Update summary Start"; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| if (!graph->summary_node_exist()) { | |||
| @@ -930,7 +930,7 @@ void SessionBasic::Summary(KernelGraph *graph) { | |||
| if (!exist_summary) { | |||
| return; | |||
| } | |||
| GetSummaryNodes(graph); | |||
| SetSummaryNodes(graph); | |||
| auto summary_outputs = graph->summary_nodes(); | |||
| std::map<std::string, tensor::TensorPtr> params_list; | |||
| // fetch outputs apply kernel in session & run callback functions | |||
| @@ -92,19 +92,9 @@ class SessionBasic { | |||
| CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph); | |||
| std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); | |||
| // set parameters of final graph | |||
| virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; } | |||
| // set output of final graph | |||
| virtual void SetFinalGraphOutput(const BaseRef &) {} | |||
| // insert switch and set the relative active ops | |||
| virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {} | |||
| // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter | |||
| virtual void SetChildGraphInput(GraphId, const VectorRef &) {} | |||
| // get graph id in child graphs by ME front anf node pointer | |||
| virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } | |||
| virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } | |||
| virtual void SetActive(GraphId, GraphId) {} | |||
| virtual void GetSummaryNodes(KernelGraph *graph); | |||
| void AssignParamKey(const KernelGraphPtr &kernel_graph); | |||
| void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const); | |||
| virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const { return true; } | |||
| @@ -120,6 +110,7 @@ class SessionBasic { | |||
| #endif | |||
| protected: | |||
| virtual void SetSummaryNodes(KernelGraph *graph); | |||
| // Get graph by graph id ,if not exist return null ptr | |||
| KernelGraphPtr GetGraph(GraphId graph_id) const; | |||
| virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| @@ -1,4 +1,4 @@ | |||
| file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base.cc" "pynative_execute.cc") | |||
| file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute.cc") | |||
| if (ENABLE_GE) | |||
| file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc") | |||
| @@ -21,7 +21,6 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "ir/anf.h" | |||
| #include "utils/callbacks.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "utils/base_ref_extends.h" | |||
| #include "backend/session/session_factory.h" | |||
| #include "common/utils.h" | |||
| @@ -34,19 +33,6 @@ namespace compile { | |||
| bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } | |||
| bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); } | |||
| LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) { | |||
| // multi_graph merge to one, big graph have paramters in begin and only have one output | |||
| MS_LOG(DEBUG) << "graph:" << g->ToString() << " parameter size:" << g->parameters().size(); | |||
| multi_result_.inputs = g->parameters(); | |||
| final_output_ = NewValueNode("fake_output"); | |||
| multi_result_.outputs = {final_output_}; | |||
| GraphId final_g = target_sess_->GetFinalRunGraph(); | |||
| multi_result_.run = std::make_shared<RunFunc>( | |||
| [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); }); | |||
| return multi_result_; | |||
| } | |||
| LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) { | |||
| MS_LOG(DEBUG) << "MsConvert"; | |||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | |||
| @@ -96,149 +82,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri | |||
| return result; | |||
| } | |||
| void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) { | |||
| GraphId active_g = simu_cond_map_[c].cond_graph_map[cond]; | |||
| GraphId cond_g = kInvalidGraphId; | |||
| if (utils::isa<AnfNodePtr>(c)) { | |||
| cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString(); | |||
| } | |||
| auto before_cond = curr_switch_; | |||
| if (curr_switch_.hash() != c.hash()) { | |||
| // invoke while false->before true call | |||
| if (simu_cond_map_[before_cond].cond_graph_map.count(false)) { | |||
| active_g = simu_cond_map_[before_cond].cond_graph_map[false]; | |||
| } else { | |||
| active_g = kInvalidGraphId; | |||
| } | |||
| // while x < y: | |||
| // z = y + 1 | |||
| // while z < c2: | |||
| // out = out + 1 | |||
| // z = z + 1 | |||
| if (active_g == cond_g) { | |||
| active_g = kInvalidGraphId; | |||
| simu_cond_map_[before_cond].cond_graph_map[false] = kInvalidGraphId; | |||
| } | |||
| MS_LOG(DEBUG) << "invoke set active:" << active_g; | |||
| } | |||
| MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g; | |||
| target_sess_->SetActive(active_g, cond_g); | |||
| } | |||
| void MsBackend::SetSwitchGraph() { | |||
| MS_LOG(DEBUG) << "SetSwitchGraph curr_switch:" << curr_switch_.ToString(); | |||
| if (is_switch_call_) { | |||
| GraphId false_g = kInvalidGraphId; | |||
| GraphId true_g = kInvalidGraphId; | |||
| MS_LOG(DEBUG) << "start SetSwitchGraph"; | |||
| true_g = simu_cond_map_[curr_switch_].cond_graph_map[true]; | |||
| bool curr_cond = simu_cond_map_[curr_switch_].curr_cond; | |||
| if (!curr_cond) { | |||
| if (simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) { | |||
| // has false branch | |||
| false_g = simu_cond_map_[curr_switch_].cond_graph_map[false]; | |||
| } | |||
| GraphId cond_g = kInvalidGraphId; | |||
| if (utils::isa<AnfNodePtr>(curr_switch_)) { | |||
| cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); | |||
| } | |||
| MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; | |||
| target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_)); | |||
| } | |||
| is_switch_call_ = false; | |||
| MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_; | |||
| } | |||
| } | |||
| // convert node from formal parameter to actual parameter, | |||
| // and actual parameter is graph user's formal parameter. | |||
| // get top while graph's parameter in recall while. | |||
| AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| std::unordered_map<AnfNodePtr, size_t> params_index; | |||
| auto result = node; | |||
| auto graph = result->func_graph(); | |||
| while (func_graph != graph) { | |||
| auto iter = graph_user_inputs_.find(graph); | |||
| if (iter == graph_user_inputs_.end()) { | |||
| break; | |||
| } | |||
| params_index.clear(); | |||
| auto ¶ms = graph->parameters(); | |||
| for (size_t i = 0; i < params.size(); ++i) { | |||
| params_index[params[i]] = i; | |||
| } | |||
| graph = iter->second.first; | |||
| auto &inputs = iter->second.second; | |||
| result = inputs[params_index[result]]; | |||
| } | |||
| return result; | |||
| } | |||
| void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user, | |||
| const AnfNodePtrList &inputs) { | |||
| if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) { | |||
| return; | |||
| } | |||
| graph_user_inputs_[func_graph] = {user, inputs}; | |||
| } | |||
| void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) { | |||
| std::unordered_map<AnfNodePtr, size_t> params_index; | |||
| auto ¶ms = func_graph->parameters(); | |||
| for (size_t i = 0; i < params.size(); ++i) { | |||
| params_index[params[i]] = i; | |||
| } | |||
| // recall all child graphs in this while | |||
| auto &graph_inputs = graph_inputs_[c]; | |||
| for (auto &iter : graph_inputs) { | |||
| auto &graph = iter.first; | |||
| auto &old_args = iter.second; | |||
| auto &result = graph_id_map_[graph]; | |||
| auto &inputs = result.inputs; | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| auto input = ConvertGraphInput(func_graph, inputs[i]); | |||
| auto it = params_index.find(input); | |||
| if (it != params_index.end()) { | |||
| old_args[i] = args[it->second]; | |||
| } | |||
| } | |||
| target_sess_->SetChildGraphInput(graph, old_args); | |||
| } | |||
| graph_inputs_.erase(c); | |||
| } | |||
| // compile set input output | |||
| VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "set graph input:" << g; | |||
| // switch maybe twice | |||
| target_sess_->SetChildGraphInput(g, args); | |||
| if (is_switch_call_) { | |||
| if (!curr_switch_.is_null()) { | |||
| // push this {g, args} to all user while graph_inputs for nest while, | |||
| // when current condition recall over delete this cond in graph_inputs. | |||
| for (auto &iter : graph_inputs_) { | |||
| iter.second.push_back({g, args}); | |||
| } | |||
| if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) { | |||
| graph_inputs_[curr_switch_].push_back({g, args}); | |||
| } | |||
| } | |||
| bool curr_cond = simu_cond_map_[curr_switch_].curr_cond; | |||
| MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g; | |||
| simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g; | |||
| SetSwitchGraph(); | |||
| } | |||
| std::vector<BaseRef> outputs; | |||
| (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs), | |||
| [](const AnfNodePtr &v) { return v; }); | |||
| @@ -290,36 +136,6 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s | |||
| return outputs; | |||
| } | |||
| SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) { | |||
| MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size(); | |||
| CondGraph cond_graph; | |||
| cond_graph.curr_cond = value; | |||
| if (simu_cond_map_.find(c) == simu_cond_map_.end()) { | |||
| simu_cond_map_[c] = cond_graph; | |||
| } | |||
| if (simu_cond_map_[c].cond_graph_map.count(value)) { | |||
| return kCondAlreadyRun; | |||
| } | |||
| simu_cond_map_[c].curr_cond = value; | |||
| MS_LOG(DEBUG) << "end set cond "; | |||
| return kCondOk; | |||
| } | |||
| void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) { | |||
| MS_LOG(DEBUG) << "Simulate run,root:" << root->ToString() << ", " << root->parameters().size(); | |||
| std::vector<BaseRef> args; | |||
| auto parameters = root->parameters(); | |||
| (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args), | |||
| [](const AnfNodePtr &v) { return v; }); | |||
| MS_LOG(DEBUG) << "Simulate start"; | |||
| (void)target_sess_->SetFinalGraphInput(parameters); | |||
| BaseRef output = rt->Eval(VectorRef(args)); | |||
| target_sess_->SetFinalGraphOutput(output); | |||
| MS_LOG(DEBUG) << "Simulate Eval end"; | |||
| } | |||
| void MsBackend::Link(GraphId graph_id) { | |||
| if (graph_id == kInvalidGraphId) { | |||
| graph_id = target_sess_->GetFinalRunGraph(); | |||
| @@ -330,9 +146,6 @@ void MsBackend::Link(GraphId graph_id) { | |||
| Backend::Backend(const std::string &name) : name_(name) { | |||
| MS_LOG(DEBUG) << "select backend:" << name; | |||
| convert_fn_ = backends[name_]; | |||
| is_switch_call_ = false; | |||
| is_multi_graph_sink_ = false; | |||
| simu_flag_ = false; | |||
| } | |||
| MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { | |||
| @@ -43,50 +43,19 @@ class Backend { | |||
| LinkFuncType convert_fn() { return convert_fn_; } | |||
| std::string name() { return name_; } | |||
| virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {} | |||
| virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; } | |||
| virtual bool GetCond(const BaseRef &c, bool *value); | |||
| virtual bool GetIndex(const BaseRef &c, int *value); | |||
| virtual void SetSwitchGraph() {} | |||
| virtual void SetSwitchActive(const BaseRef &, bool) {} | |||
| virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} | |||
| virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {} | |||
| virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; } | |||
| void set_curr_switch(const BaseRef &value) { | |||
| curr_switch_ = value; | |||
| is_switch_call_ = true; | |||
| } | |||
| BaseRef curr_switch() { return curr_switch_; } | |||
| virtual void Link(GraphId) {} | |||
| virtual LinConvertResult GetMultiGraphRun(const FuncGraphPtr &) { return LinConvertResult(); } | |||
| virtual void SetDebugger() {} | |||
| LinConvertResult multi_result() { return multi_result_; } | |||
| void set_multi_result(const LinConvertResult &value) { multi_result_ = value; } | |||
| AnfNodePtr final_output() const { return final_output_; } | |||
| bool is_multi_graph_sink() const { return is_multi_graph_sink_; } | |||
| void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; } | |||
| bool simu_flag() const { return simu_flag_; } | |||
| bool is_switch_call() const { return is_switch_call_; } | |||
| void set_simu_flag(bool simu) { simu_flag_ = simu; } | |||
| virtual void SetDebugger() {} | |||
| protected: | |||
| std::string name_; | |||
| LinkFuncType convert_fn_; | |||
| BaseRef curr_switch_; // curr switch node | |||
| bool is_multi_graph_sink_; | |||
| bool is_switch_call_; | |||
| bool simu_flag_; | |||
| LinConvertResult multi_result_; | |||
| AnfNodePtr final_output_; | |||
| std::unordered_map<FuncGraphPtr, std::pair<FuncGraphPtr, AnfNodePtrList>> graph_user_inputs_; | |||
| }; | |||
| struct CondGraph { | |||
| bool curr_cond; | |||
| std::unordered_map<bool, GraphId> cond_graph_map; | |||
| }; | |||
| class MsBackend : public Backend { | |||
| @@ -98,16 +67,7 @@ class MsBackend : public Backend { | |||
| VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); | |||
| VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args); | |||
| void SimulateRun(FinalVMPtr rt, FuncGraphPtr root) override; | |||
| SwitchCondStatus SetSimuCond(const BaseRef &c, bool value) override; | |||
| void SetSwitchGraph() override; | |||
| void SetSwitchActive(const BaseRef &c, bool cond) override; | |||
| void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) override; | |||
| void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) override; | |||
| void Link(GraphId) override; | |||
| AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &); | |||
| LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override; | |||
| GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override; | |||
| VectorRef RunGraph(GraphId graph_id, const VectorRef &args); | |||
| void CreateOtherSession(const std::string &target); | |||
| @@ -121,9 +81,7 @@ class MsBackend : public Backend { | |||
| session::SessionPtr other_sess_; | |||
| std::string target_device_; | |||
| std::string other_device_; | |||
| std::unordered_map<BaseRef, CondGraph, BaseRefHash> simu_cond_map_; | |||
| std::unordered_map<GraphId, LinConvertResult> graph_id_map_; | |||
| std::unordered_map<BaseRef, std::list<std::pair<GraphId, VectorRef>>, BaseRefHash> graph_inputs_; | |||
| }; | |||
| } // namespace compile | |||
| } // namespace mindspore | |||
| @@ -515,11 +515,7 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no | |||
| MS_LOG(DEBUG) << "LinConvert start"; | |||
| LinConvertResult result; | |||
| if (backend_->simu_flag()) { | |||
| result = backend_->GetMultiGraphRun(graph); | |||
| } else { | |||
| result = lin_convert_(node_list, target); | |||
| } | |||
| result = lin_convert_(node_list, target); | |||
| if (result.run == nullptr) { | |||
| MS_LOG(ERROR) << "LinConvert failed"; | |||
| @@ -546,27 +542,6 @@ int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &no | |||
| return RET_SUCCESS; | |||
| } | |||
| void CompileGraph::AddSinkSwitch(const CNodePtr &node) { | |||
| MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString(); | |||
| if (backend_->is_multi_graph_sink()) { | |||
| VectorRef args; | |||
| args.emplace_back(-1); | |||
| MS_LOG(DEBUG) << "call::" << height_; | |||
| AddInst(Instruction::kCall, args); | |||
| args.clear(); | |||
| args.emplace_back(node->input(1)); | |||
| AddInst(Instruction::kSwitchReturn, args); | |||
| args.clear(); | |||
| args.emplace_back(false); | |||
| args.emplace_back(Ref(node->input(1))); | |||
| args.emplace_back(Ref(node->input(2))); | |||
| args.emplace_back(Ref(node->input(3))); | |||
| AddInst(Instruction::kSwitch, args); | |||
| } | |||
| } | |||
| int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true); | |||
| @@ -589,7 +564,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) | |||
| AddPartial(node); | |||
| } else if (IsPrimitive(fn, prim::kPrimSwitch)) { | |||
| AddSwitch(node); | |||
| AddSinkSwitch(node); | |||
| } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) { | |||
| AddSwitchLayer(node); | |||
| } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) { | |||
| @@ -607,14 +581,6 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) | |||
| return RET_SUCCESS; | |||
| } | |||
| void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) { | |||
| auto ret = LinConvert(graph, {}); | |||
| if (ret == RET_FAILED) { | |||
| MS_LOG(EXCEPTION) << "MultiGraphRun failed."; | |||
| } | |||
| AddReturn(nullptr); | |||
| } | |||
| bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { | |||
| MS_LOG(DEBUG) << "Start split graph"; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -659,11 +625,6 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { | |||
| return true; | |||
| } | |||
| InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) { | |||
| InstSet inst = Run(graph); | |||
| return inst; | |||
| } | |||
| InstSet CompileGraph::Run(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -672,12 +633,8 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) { | |||
| int param_height = height_; | |||
| MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true); | |||
| if (backend_->simu_flag()) { | |||
| GenMultiGraphsRun(graph); | |||
| } else { | |||
| if (!SplitGraph(graph)) { | |||
| return inst_; | |||
| } | |||
| if (!SplitGraph(graph)) { | |||
| return inst_; | |||
| } | |||
| AddPadStack(param_height); | |||
| @@ -712,12 +669,6 @@ void CompileGraph::AddPartial(const CNodePtr &node) { | |||
| if (!IsValueNode<FuncGraph>(fn)) { | |||
| MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph"; | |||
| } | |||
| if (backend_->is_multi_graph_sink()) { | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(fn); | |||
| args.emplace_back(func_graph); | |||
| AnfNodePtrList outs(inputs.begin() + 2, inputs.end()); | |||
| backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs); | |||
| } | |||
| for (size_t i = 1; i < inputs.size(); i++) { | |||
| args.emplace_back(Ref(inputs[i])); | |||
| } | |||
| @@ -739,9 +690,6 @@ void CompileGraph::AddSwitch(const CNodePtr &node) { | |||
| MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4"; | |||
| } | |||
| VectorRef args; | |||
| if (backend_->is_multi_graph_sink()) { | |||
| args.emplace_back(true); | |||
| } | |||
| args.emplace_back(Ref(inputs[1])); | |||
| args.emplace_back(Ref(inputs[2])); | |||
| args.emplace_back(Ref(inputs[3])); | |||
| @@ -761,11 +709,7 @@ void CompileGraph::AddSwitchLayer(const CNodePtr &node) { | |||
| void CompileGraph::AddReturn(const CNodePtr &node) { | |||
| VectorRef args; | |||
| if (backend_->simu_flag()) { | |||
| args.emplace_back(Ref(backend_->final_output())); | |||
| } else { | |||
| args.emplace_back(Ref(node->input(1))); | |||
| } | |||
| args.emplace_back(Ref(node->input(1))); | |||
| args.emplace_back(height_); | |||
| AddInst(Instruction::kReturn, args); | |||
| } | |||
| @@ -783,11 +727,6 @@ void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) | |||
| int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) { | |||
| auto inputs = node->inputs(); | |||
| AnfNodePtr fn = inputs[0]; | |||
| if (backend_->is_multi_graph_sink() && IsValueNode<FuncGraph>(fn)) { | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(fn); | |||
| AnfNodePtrList outs(inputs.begin() + 1, inputs.end()); | |||
| backend_->SetGraphUserInputs(func_graph, node->func_graph(), outs); | |||
| } | |||
| (void)Ref(fn); | |||
| size_t size = inputs.size(); | |||
| for (size_t i = size - 1; i > 0; i--) { | |||
| @@ -929,17 +868,6 @@ FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) { | |||
| } | |||
| FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_); | |||
| if (backend_->is_multi_graph_sink()) { | |||
| backend_->set_simu_flag(true); | |||
| MS_LOG(DEBUG) << "Start simulate"; | |||
| backend_->SimulateRun(rt, graph); | |||
| MS_LOG(DEBUG) << "Link graphs"; | |||
| insts_ = transform_->GenMultiGraphsSinkInst(graph); | |||
| rt->set_insts(insts_); | |||
| backend_->set_simu_flag(false); | |||
| MS_LOG(DEBUG) << "End start simulate"; | |||
| backend_->Link(kInvalidGraphId); | |||
| } | |||
| MS_LOG(DEBUG) << "End"; | |||
| return rt; | |||
| } | |||
| @@ -54,12 +54,10 @@ class CompileGraph { | |||
| ~CompileGraph() = default; | |||
| InstSet Run(const FuncGraphPtr &func_graph); | |||
| InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph); | |||
| bool IsCut(const AnfNodePtr &node); | |||
| void Push(const AnfNodePtr &node); | |||
| void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } | |||
| void Ret(int nargs); | |||
| void GenMultiGraphsRun(const FuncGraphPtr &graph); | |||
| int Ref(const AnfNodePtr &node); | |||
| VectorRef SplitNodes(const FuncGraphPtr &func_graph); | |||
| @@ -84,7 +82,6 @@ class CompileGraph { | |||
| int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); | |||
| int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); | |||
| int AddCall(const FuncGraphPtr &graph, const CNodePtr &node); | |||
| void AddSinkSwitch(const CNodePtr &node); | |||
| void AddPadStack(int param_height); | |||
| void AddTailCall(const AnfNodePtr &fn, size_t size); | |||
| void AddPartial(const CNodePtr &node); | |||
| @@ -17,12 +17,9 @@ | |||
| */ | |||
| #include "vm/vm.h" | |||
| #include <algorithm> | |||
| #include "vm/vmimpl.h" | |||
| #include "vm/backend.h" | |||
| #include "vm/transform.h" | |||
| #include "pipeline/jit/parse/data_converter.h" | |||
| #include "utils/base_ref_extends.h" | |||
| @@ -142,33 +139,10 @@ void FinalVM::Popsp() { | |||
| } | |||
| } | |||
| void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); } | |||
| bool FinalVM::PopStatus() { | |||
| if (ret_status_.empty()) { | |||
| return false; | |||
| } | |||
| bool status = ret_status_.top(); | |||
| ret_status_.pop(); | |||
| return status; | |||
| } | |||
| void FinalVM::DoJmp(const BaseRef &jmp_orig) { | |||
| MS_LOG(DEBUG) << "Start"; | |||
| BaseRef jmp = jmp_orig; | |||
| if (backend_->simu_flag()) { | |||
| bool is_switch_call = false; | |||
| if (utils::isa<StructSimuSwitch>(jmp)) { // need to inherit from Base | |||
| MS_LOG(DEBUG) << "Start jump StructSwitch"; | |||
| auto simu_value = utils::cast<std::shared_ptr<StructSimuSwitch>>(jmp); | |||
| jmp = simu_value->fn_; | |||
| backend_->set_curr_switch(simu_value->value_); | |||
| is_switch_call = true; | |||
| } | |||
| PushStatus(is_switch_call); | |||
| } | |||
| if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base | |||
| MS_LOG(DEBUG) << "Start jump StructPartial"; | |||
| auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp); | |||
| @@ -270,13 +244,6 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) { | |||
| MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << "."; | |||
| return; | |||
| } | |||
| auto rv = Ref(-1); | |||
| if (utils::isa<AnfNodePtr>(rv) || utils::isa<VectorRef>(rv)) { | |||
| auto &c = args[0]; | |||
| cond_out_[c] = rv; | |||
| } | |||
| Pop(1); | |||
| Popsp(); | |||
| } | |||
| @@ -294,51 +261,12 @@ void FinalVM::InstReturn(const VectorRef &args) { | |||
| int height = utils::cast<int>(args[1]); | |||
| auto rv = Ref(rpos); | |||
| if (backend_->simu_flag()) { | |||
| auto c = backend_->curr_switch(); | |||
| auto status = PopStatus(); | |||
| if (status) { | |||
| auto iter = cond_out_.find(c); | |||
| if (iter != cond_out_.end()) { | |||
| rv = MergeArgs(rv, iter->second); | |||
| cond_out_.erase(iter); | |||
| } | |||
| } | |||
| if (backend_->is_switch_call()) { | |||
| backend_->SetSwitchGraph(); | |||
| } | |||
| } | |||
| Pop(height); | |||
| Push(rv); | |||
| Popp(); | |||
| MS_LOG(DEBUG) << "End"; | |||
| } | |||
| void FinalVM::InstSimuPartial(const VectorRef &args) { | |||
| const size_t args_size = 2; | |||
| if (args.size() < args_size) { | |||
| MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is " | |||
| << args.size() << "."; | |||
| return; | |||
| } | |||
| auto &node = args[0]; | |||
| if (!utils::isa<FuncGraphPtr>(node)) { | |||
| MS_LOG(ERROR) << "The type of 1st input of node must be FuncGraph"; | |||
| return; | |||
| } | |||
| auto fg = utils::cast<FuncGraphPtr>(node); | |||
| int fn_ = utils::cast<int>(args[1]); | |||
| auto fn = utils::cast<int>(Ref(fn_)); | |||
| MS_LOG(DEBUG) << "Partial argssize:" << args.size(); | |||
| std::vector<BaseRef> outs(args.size() - 2); | |||
| (void)std::transform(args.begin() + 2, args.end(), outs.begin(), | |||
| [&, this](const BaseRef &a) { return Ref(utils::cast<int>(a)); }); | |||
| Push(std::make_shared<StructPartial>(fn, VectorRef(outs), fg)); | |||
| } | |||
| void FinalVM::InstRealPartial(const VectorRef &args) { | |||
| const size_t args_size = 1; | |||
| if (args.size() < args_size) { | |||
| @@ -358,91 +286,10 @@ void FinalVM::InstRealPartial(const VectorRef &args) { | |||
| void FinalVM::InstPartial(const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "Start"; | |||
| if (backend_->is_multi_graph_sink()) { | |||
| InstSimuPartial(args); | |||
| } else { | |||
| InstRealPartial(args); | |||
| } | |||
| InstRealPartial(args); | |||
| MS_LOG(DEBUG) << "End"; | |||
| } | |||
| void FinalVM::InstSimuSwitch(const VectorRef &args) { | |||
| const size_t args_size = 4; | |||
| if (args.size() != args_size) { | |||
| MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size() | |||
| << "."; | |||
| return; | |||
| } | |||
| bool cond = utils::cast<bool>(args[0]); | |||
| int cond_node = utils::cast<int>(args[1]); | |||
| int vtrue = utils::cast<int>(args[2]); | |||
| int vfalse = utils::cast<int>(args[3]); | |||
| MS_LOG(DEBUG) << "Simu switch cond:" << cond; | |||
| BaseRef c = Ref(cond_node); | |||
| bool bool_value = cond; | |||
| SwitchCondStatus cond_stat = backend_->SetSimuCond(c, bool_value); | |||
| if (cond_stat == kCondAlreadyRun) { | |||
| MS_LOG(DEBUG) << "switch alreay run bool while true jmp"; | |||
| BaseRef jmp = Ref(vtrue); | |||
| if (utils::isa<StructPartial>(jmp)) { | |||
| auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp); | |||
| backend_->RecallGraphInput(new_jmp->fg_, new_jmp->args_, c); | |||
| } | |||
| cond_jmp_[c] = Ref(vfalse); | |||
| Push(static_cast<int>(cond_stat)); | |||
| Popp(); | |||
| backend_->SetSwitchActive(c, bool_value); | |||
| return; | |||
| } | |||
| if (bool_value) { | |||
| Push(std::make_shared<StructSimuSwitch>(Ref(vtrue), c)); | |||
| Pushsp(); | |||
| } else { | |||
| MergeJmpArgs(Ref(vfalse), c); | |||
| Push(std::make_shared<StructSimuSwitch>(Ref(vfalse), c)); | |||
| } | |||
| } | |||
| void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) { | |||
| auto iter = cond_jmp_.find(c); | |||
| if (iter == cond_jmp_.end()) { | |||
| return; | |||
| } | |||
| auto old_jmp = utils::cast<std::shared_ptr<StructPartial>>(iter->second); | |||
| auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp); | |||
| auto &old_args = old_jmp->args_; | |||
| auto &new_args = new_jmp->args_; | |||
| for (size_t i = 0; i < new_args.size(); ++i) { | |||
| auto &old_arg = old_args[i]; | |||
| auto &new_arg = new_args[i]; | |||
| new_arg = MergeArgs(old_arg, new_arg); | |||
| } | |||
| } | |||
| BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) { | |||
| MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString(); | |||
| if (utils::isa<VectorRef>(first)) { | |||
| auto old_vec_ref = utils::cast<VectorRef>(first); | |||
| if (utils::isa<VectorRef>(second)) { | |||
| auto new_vec_ref = utils::cast<VectorRef>(second); | |||
| std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref)); | |||
| } else { | |||
| old_vec_ref.push_back(second); | |||
| } | |||
| return old_vec_ref; | |||
| } | |||
| if (utils::isa<VectorRef>(second)) { | |||
| auto new_vec_ref = utils::cast<VectorRef>(second); | |||
| new_vec_ref.push_back(first); | |||
| return new_vec_ref; | |||
| } | |||
| return VectorRef({first, second}); | |||
| } | |||
| void FinalVM::InstRealSwitch(const VectorRef &args) { | |||
| const size_t args_size = 3; | |||
| if (args.size() != args_size) { | |||
| @@ -472,11 +319,7 @@ void FinalVM::InstRealSwitch(const VectorRef &args) { | |||
| void FinalVM::InstSwitch(const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "Start"; | |||
| if (backend_->is_multi_graph_sink()) { | |||
| InstSimuSwitch(args); | |||
| } else { | |||
| InstRealSwitch(args); | |||
| } | |||
| InstRealSwitch(args); | |||
| MS_LOG(DEBUG) << "End"; | |||
| } | |||
| @@ -580,14 +423,6 @@ void FinalVM::InstExternal(const VectorRef &args) { | |||
| VectorRef tuple; | |||
| RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]); | |||
| compile::RunFuncPtr fn = run_ref.func_; | |||
| if (backend_->simu_flag()) { | |||
| MS_LOG(DEBUG) << "Simu run"; | |||
| if (args.size() == 1) { | |||
| MS_LOG(EXCEPTION) << "The number of args should be greater than 1, but got 1"; | |||
| } | |||
| auto simu_run_ref = utils::cast<RunFunctionRef>(args[1]); | |||
| fn = simu_run_ref.func_; | |||
| } | |||
| for (size_t i = 2; i < args.size(); ++i) { | |||
| auto index = utils::cast<int>(args[i]); | |||
| tuple.push_back(Ref(index)); | |||
| @@ -96,7 +96,6 @@ class FinalVM { | |||
| public: | |||
| // Create a VM with the specified instructions and backend. | |||
| explicit FinalVM(const InstSet &insts, const BackendPtr &backend); | |||
| virtual ~FinalVM() = default; | |||
| BaseRef Eval(const VectorRef &args); | |||
| @@ -104,10 +103,8 @@ class FinalVM { | |||
| void InstTailCall(const VectorRef &args); | |||
| void InstReturn(const VectorRef &args); | |||
| void InstPartial(const VectorRef &args); | |||
| void InstSimuPartial(const VectorRef &args); | |||
| void InstRealPartial(const VectorRef &args); | |||
| void InstSwitch(const VectorRef &args); | |||
| void InstSimuSwitch(const VectorRef &args); | |||
| void InstRealSwitch(const VectorRef &args); | |||
| void InstTuple(const VectorRef &args); | |||
| void InstPush(const VectorRef &args); | |||
| @@ -129,23 +126,16 @@ class FinalVM { | |||
| void Popp(); | |||
| void Pushsp(); | |||
| void Popsp(); | |||
| void PushStatus(bool is_switch_call); | |||
| bool PopStatus(); | |||
| void DoJmp(const BaseRef &jmp); | |||
| void SyncData(const py::object &args); | |||
| void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c); | |||
| BaseRef MergeArgs(const BaseRef &first, const BaseRef &second); | |||
| private: | |||
| InstSet insts_; | |||
| std::deque<BaseRef> insts_stack_; | |||
| std::stack<int> retp_; | |||
| std::stack<int> retsp_; | |||
| std::stack<bool> ret_status_; | |||
| int pc_; | |||
| int sp_; | |||
| std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_jmp_; | |||
| std::unordered_map<BaseRef, BaseRef, BaseRefHash> cond_out_; | |||
| BackendPtr backend_; | |||
| const InstFunctionMap inst_function_map = { | |||
| {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, | |||