From e0c3853af2fd3b176f1d17e73e49719dea462afa Mon Sep 17 00:00:00 2001 From: zjun Date: Mon, 11 Jan 2021 16:42:49 +0800 Subject: [PATCH] Fix pynative graph have already run compiled bug in second derivate order --- .../pipeline/pynative/pynative_execute.cc | 67 ++++++++++++++----- .../pipeline/pynative/pynative_execute.h | 6 +- 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 70eb833fc5..f96ee8dfa4 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1557,11 +1557,33 @@ bool PynativeExecutor::IsTopestGraph(const std::string &cell_id) { [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->is_topest; }); } -void PynativeExecutor::UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled) { +std::string PynativeExecutor::GetTopCell(const string &cell_id) { + if (IsTopestGraph(cell_id)) { + return cell_id; + } + std::string top_cell_id; + for (const auto &it : cell_graph_list_) { + if (IsTopestGraph(it->cell_id)) { + top_cell_id = it->cell_id; + } + if (it->cell_id == cell_id) { + break; + } + } + if (top_cell_id.empty()) { + MS_LOG(EXCEPTION) << "Get top cell null"; + } + return top_cell_id; +} + +void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled) { auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); if (it != top_cell_list_.end()) { (*it)->do_vm_compiled = vm_compiled; + if ((*it)->is_topest) { + in_grad_process_ = false; + } } } @@ -1858,15 +1880,20 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg auto cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id; // check whether cell needed to construct grad graph - if (graph_stack_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { - if (top_cell_list_.empty()) { - MS_LOG(EXCEPTION) << "Top cell list is empty"; - } - if (IsTopestGraph(cell_id)) { - // Clear previous step resource + if (graph_stack_.empty() && !top_cell_list_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { + // Clear previous step resource + if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) { CleanPreMemoryInValueNode(); op_index_map_.clear(); top_cell_id_ = cell_id; + in_grad_process_ = true; + } + if (!in_grad_process_ && cell_op_info_stack_.empty()) { + CleanPreMemoryInValueNode(); + op_index_map_.clear(); + top_cell_id_ = GetTopCell(cell_id); + in_grad_process_ = true; + MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; } PushCurrentCellOpInfoToStack(); MS_LOG(INFO) << "NewGraph already compiled"; @@ -1920,6 +1947,7 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar } op_index_map_.clear(); top_cell_id_ = cell_id; + in_grad_process_ = true; auto df_builder = std::make_shared(); auto graph_info = std::make_shared(cell_id); graph_info_map_[df_builder] = graph_info; @@ -2162,7 +2190,9 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt } tmp = BasicClone(g); graph_info_map_.update(g, tmp); - ClearCnodeRes(tmp->output()); + std::unordered_set node_set; + ClearCnodeRes(tmp->output(), &node_set); + node_set.clear(); }; // First call or cell id not exist if (update_in_endgraph && (IsFirstGradStep(top_cell_id_) || !CheckCellGraph(cell_id))) { @@ -2208,17 +2238,19 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt } } -void PynativeExecutor::ClearCnodeRes(const AnfNodePtr &node) { +void PynativeExecutor::ClearCnodeRes(const AnfNodePtr &node, std::unordered_set *node_set) { MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { + MS_EXCEPTION_IF_NULL(node_set); + if (!node->isa() || (*node_set).find(node) != (*node_set).end()) { return; } + (*node_set).insert(node); auto cnode = node->cast(); cnode->clear_inputs_value(); + cnode->set_forward(nullptr, ""); for (size_t i = 0; i < cnode->size(); ++i) { auto n = cnode->input(i); - cnode->set_forward(nullptr, ""); - ClearCnodeRes(n); + ClearCnodeRes(n, node_set); } } @@ -2323,7 +2355,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id; const auto ¶ms_changed = CheckGradParamsChanged(cell_id, weights, sens); if (!params_changed && !IsFirstGradStep(cell_id) && !CheckRealDynamicCell(cell_id)) { - UpdateTopCellCompileInfo(cell_id, false); + UpdateTopCellInfo(cell_id, false); ClearDynamicTopRes(cell_id); MS_LOG(INFO) << "Gradgraph already compiled"; return; @@ -2367,7 +2399,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje ExecuteAction(resource); ClearUselessRes(df_builder, cell, cell_id); UpdateCellGraph(cell, curr_g_, cell_id, false, true); - UpdateTopCellCompileInfo(cell_id, true); + UpdateTopCellInfo(cell_id, true); resource->Clean(); } @@ -2710,9 +2742,10 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) { const auto &cell_id = GetCellId(cell, args); - bool already_run = CheckCellGraph(cell_id); - MS_LOG(DEBUG) << "Graph have already run " << already_run << " cell id " << cell_id; - return BaseRefToPyData(already_run); + bool forward_run = CheckCellGraph(cell_id) && top_cell_id_ == cell_id; + MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ " + << top_cell_id_; + return BaseRefToPyData(forward_run); } py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) { diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 4bbb057594..8f1be511fa 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -242,14 +242,15 @@ class PynativeExecutor : public std::enable_shared_from_this { bool CheckRealDynamicCell(const std::string &cell_id); void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned = false, bool is_grad = false); - void ClearCnodeRes(const AnfNodePtr &node); + void ClearCnodeRes(const AnfNodePtr &node, std::unordered_set *node_set); void UpdateCellDynamic(const std::string &cell_id); bool CheckCellChanged(const std::string &cell_id); - void UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled); + void UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled); void ClearResidualRes(const std::string &cell_id); void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph); void NewGraphInner(const py::object &cell, const py::args &args); void MakeNewTopGraph(const string &cell_id, const py::args &args); + std::string GetTopCell(const string &cell_id); void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, const std::string &out_id, const py::args &args); @@ -306,6 +307,7 @@ class PynativeExecutor : public std::enable_shared_from_this { size_t grad_order_{0}; std::string top_cell_id_; bool grad_flag_{false}; + bool in_grad_process_{false}; bool has_dynamic_cell_{false}; bool grad_is_running_{false}; bool need_replace_forward_{true};