From 073177ac5520c19b2c976e0cb3c71aa9f9dc6ed2 Mon Sep 17 00:00:00 2001 From: zjun Date: Thu, 21 Jan 2021 09:53:47 +0800 Subject: [PATCH] Fix dynamitc top cell Signed-off-by: zjun --- .../pipeline/pynative/pynative_execute.cc | 153 ++++++++++++------ .../pipeline/pynative/pynative_execute.h | 3 + 2 files changed, 108 insertions(+), 48 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index e2e522a554..2058719b8e 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1505,9 +1505,10 @@ std::string PynativeExecutor::GetTensorCellId(const std::string &cell_id) { } value.emplace_back(str.substr(pre_pos)); }; - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&key](const CellInfoPtr &value) { - return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos; - }); + auto it = + std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&key](const CellInfoPtr &value) { + return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos; + }); if (it != cell_graph_list_.end()) { std::vector pre_cell_id; std::vector cur_cell_id; @@ -1581,6 +1582,7 @@ void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_com if (it != top_cell_list_.end()) { (*it)->do_vm_compiled = vm_compiled; (*it)->forward_already_run = false; + (*it)->need_grad = true; if ((*it)->is_topest) { in_grad_process_ = false; } @@ -1588,9 +1590,10 @@ void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_com } bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) { - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { - return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos; - }); + return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), + [&cell_id](const CellInfoPtr &value) { + return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos; + }); } bool PynativeExecutor::IsFirstGradStep(const std::string &cell_id) { return !CheckCellGraph(cell_id, true); } @@ -1602,20 +1605,21 @@ void PynativeExecutor::SubNestedGradOrder() { } bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) { - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfoPtr &value) { - return value->cell_id == cell_id && (!is_grad || value->is_grad); - }); + return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), + [&cell_id, is_grad](const CellInfoPtr &value) { + return value->cell_id == cell_id && (!is_grad || value->is_grad); + }); } bool PynativeExecutor::CheckDynamicCell(const std::string &cell_id) { - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), + return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_dynamic; }); } bool PynativeExecutor::CheckRealDynamicCell(const std::string &cell_id) { - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { - return value->cell_id == cell_id && value->is_real_dynamic; - }); + return std::any_of( + cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_real_dynamic; }); } void PynativeExecutor::ClearResidualRes(const std::string &cell_id) { @@ -1889,6 +1893,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg auto top_cell = GetTopCell(cell_id); MS_EXCEPTION_IF_NULL(top_cell); top_cell_id_ = top_cell->cell_id; + top_cell_index_ = top_cell->top_cell_index; top_cell->forward_already_run = true; MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; } @@ -1899,6 +1904,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg auto top_cell = GetTopCell(cell_id, true); MS_EXCEPTION_IF_NULL(top_cell); top_cell_id_ = top_cell->cell_id; + top_cell_index_ = top_cell->top_cell_index; top_cell->forward_already_run = true; MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; } @@ -1930,6 +1936,14 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg if (!has_dynamic_cell_) { has_dynamic_cell_ = IsDynamicCell(cell); MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << has_dynamic_cell_; + if (has_dynamic_cell_ && IsBpropGraph(cell_id)) { + auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + [this](const CellInfoPtr &value) { return value->cell_id == top_cell_id_; }); + while (it != cell_graph_list_.end()) { + (*it)->is_dynamic = true; + ++it; + } + } } } @@ -1967,6 +1981,15 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar resource->results()[pipeline::kPynativeGraphId] = graph_id_++; auto top_cell_info = std::make_shared(true, resource, df_builder, cell_id); top_cell_info->forward_already_run = true; + if (!IsTopestGraph(cell_id)) { + top_cell_info->top_cell_index = cell_graph_list_.size(); + top_cell_index_ = top_cell_info->top_cell_index; + } else { + auto top_cell = GetTopCell(cell_id, true); + MS_EXCEPTION_IF_NULL(top_cell); + top_cell_info->top_cell_index = top_cell->top_cell_index; + top_cell_index_ = top_cell_info->top_cell_index; + } top_cell_list_.emplace_back(top_cell_info); MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get(); } @@ -2077,11 +2100,22 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string resource->manager()->AddFuncGraph(curr_g_); UpdateCellGraph(cell, curr_g_, cell_id, true, false); FuncGraphPtr newfg = nullptr; + auto top_cell = GetTopCell(top_cell_id_); + MS_EXCEPTION_IF_NULL(top_cell); // Cell no Change if (CheckDynamicCell(cell_id) && !CheckCellChanged(cell_id)) { MS_LOG(DEBUG) << "Cell is not dynamic, No need make ad grad"; + top_cell->need_grad = false; + std::unordered_set node_set; + ClearCnodeRes(curr_g_->output(), &node_set); + node_set.clear(); } else { MS_LOG(DEBUG) << "Need make ad grad"; + if (!top_cell->need_grad) { + std::unordered_set node_set; + ClearCnodeRes(curr_g_->output(), &node_set); + node_set.clear(); + } newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); } @@ -2137,7 +2171,7 @@ bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) { MS_LOG(DEBUG) << "Cell op info is empty"; return true; } - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (it == cell_graph_list_.end() || IsFirstGradStep(top_cell_id_)) { return true; @@ -2153,12 +2187,12 @@ bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) { } void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) { - for (auto &it : cell_graph_list_) { - if (it->cell_id != cell_id) { - it->is_real_dynamic = true; + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { + if ((*it)->cell_id != cell_id) { + (*it)->is_real_dynamic = true; continue; } - it->is_real_dynamic = true; + (*it)->is_real_dynamic = true; break; } } @@ -2168,7 +2202,7 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt auto update_in_endgraph = need_cloned && !is_grad; if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { // Bprop just save backward graph - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (it != cell_graph_list_.end()) { (*it)->is_grad = is_grad; @@ -2187,7 +2221,7 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt << " cell ops info " << GetCellOpInfo(); auto cell_info = std::make_shared(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id); cell_info->cell_ops_info.emplace_back(GetCellOpInfo()); - cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); + cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info); } return; } @@ -2214,9 +2248,9 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt MS_LOG(DEBUG) << "Add new cell with cloned graph " << cell_id << " cell ops info " << GetCellOpInfo(); auto cell_info = std::make_shared(true, has_dynamic_cell_, tmp, cell_id, ""); cell_info->cell_ops_info.emplace_back(GetCellOpInfo()); - cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); + cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info); } else { - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (it != cell_graph_list_.end()) { (*it)->cell_ops_info.emplace_back(GetCellOpInfo()); @@ -2226,26 +2260,26 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt return; } - for (auto &it : cell_graph_list_) { - if (it->cell_id != cell_id) { + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { + if ((*it)->cell_id != cell_id) { continue; } if (IsFirstGradStep(cell_id)) { // no compute grad - it->is_grad = is_grad; + (*it)->is_grad = is_grad; } if (need_cloned) { clone_fn(); - if (it->fg != nullptr) { - graph_info_map_.erase(it->fg); + if ((*it)->fg != nullptr) { + graph_info_map_.erase((*it)->fg); } - MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with cloned new " << tmp.get(); - it->fg = tmp; + MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with cloned new " << tmp.get(); + (*it)->fg = tmp; } if (!need_cloned && !is_grad) { - graph_info_map_.erase(it->fg); - MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with new " << tmp.get(); - it->fg = tmp; + graph_info_map_.erase((*it)->fg); + MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with new " << tmp.get(); + (*it)->fg = tmp; } break; } @@ -2502,9 +2536,10 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args auto graph_info = std::make_shared(cell_id); graph_info_map_[df_builder] = graph_info; auto top_cell_info = std::make_shared(false, resource, df_builder, cell_id); + top_cell_info->top_cell_index = top_cell_index_; top_cell_list_.emplace_back(top_cell_info); FuncGraphPtr forward_graph = nullptr; - auto ib = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto ib = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (ib != cell_graph_list_.end()) { forward_graph = (*ib)->fg; @@ -2531,17 +2566,17 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const const std::string &cell_id) { std::vector graph_before{}; bool index_find = false; - for (const auto &it : cell_graph_list_) { - if (IsBpropGraph(it->cell_id) || it->fg == nullptr) { + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { + if (IsBpropGraph((*it)->cell_id) || (*it)->fg == nullptr) { continue; } if (index_find) { - graph_before.emplace_back(it->fg); + graph_before.emplace_back((*it)->fg); continue; } - if (it->cell_id == cell_id) { + if ((*it)->cell_id == cell_id) { index_find = true; - graph_before.emplace_back(it->fg); + graph_before.emplace_back((*it)->fg); } } @@ -2570,6 +2605,7 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const SetNodeMapInGraphInfoMap(df_builder, it.first, new_param); } } + graph_before.clear(); } void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) { @@ -2659,7 +2695,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector &weights, size_t arg_size, const std::string &cell_id) { FuncGraphPtr top_g = nullptr; - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (it != cell_graph_list_.end()) { top_g = (*it)->fg; @@ -2696,7 +2732,7 @@ void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py: graph_info_map_.erase(df_builder); bool has_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); bool is_dynamic_top_fist_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(cell_id); - bool is_topmost = IsTopestGraph(cell_id) && top_cell_list_.front()->cell_id == cell_id; + bool is_topmost = IsTopestGraph(cell_id); if (has_custom_bprop || is_dynamic_top_fist_grad || !is_topmost) { return; } @@ -2705,16 +2741,34 @@ void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py: // Clear graph_info_map_ std::vector l{}; bool index_find = false; - for (auto &it : cell_graph_list_) { + auto it_end = cell_graph_list_.end(); + for (size_t i = 0; i < top_cell_list_.size(); ++i) { + if (top_cell_list_[i]->cell_id == cell_id) { + index_find = true; + continue; + } + if (index_find) { + it_end = cell_graph_list_.begin() + top_cell_list_[i]->top_cell_index; + break; + } + } + index_find = false; + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != it_end; ++it) { + // Clear cnode memory + if ((*it)->fg != nullptr) { + std::unordered_set node_set; + ClearCnodeRes((*it)->fg->output(), &node_set); + node_set.clear(); + (*it)->fg = nullptr; + } + if (index_find) { - it->fg = nullptr; - l.emplace_back(it->cell_id); + l.emplace_back((*it)->cell_id); continue; } - if (it->cell_id == cell_id) { + if ((*it)->cell_id == cell_id) { index_find = true; - it->fg = nullptr; - l.emplace_back(it->cell_id); + l.emplace_back((*it)->cell_id); } } for (const auto &it : l) { @@ -2738,7 +2792,7 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & const auto &cell_id = GetCellId(cell, args); std::string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size())); MS_LOG(DEBUG) << "Key is " << key; - for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) { + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id; if (key != (*it)->cell_id.substr(0, std::min(PTR_LEN, (*it)->cell_id.size()))) { continue; @@ -2758,6 +2812,9 @@ py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::a bool forward_run = false; if (top_cell != nullptr) { forward_run = top_cell->forward_already_run; + if (forward_run) { + top_cell_index_ = top_cell->top_cell_index; + } } MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ " << top_cell_id_; @@ -2868,7 +2925,7 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg void PynativeExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector *inputs) { FuncGraphPtr forward_graph = nullptr; - auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto ic = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (ic != cell_graph_list_.end()) { forward_graph = (*ic)->fg; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 600ee54993..e5ef9e13be 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -96,9 +96,11 @@ class TopCellInfo { TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid) : is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {} + bool need_grad{true}; bool is_topest{false}; bool do_vm_compiled{false}; bool forward_already_run{false}; + size_t top_cell_index{0}; ResourcePtr resource{nullptr}; FuncGraphPtr df_builder{nullptr}; FuncGraphPtr bg{nullptr}; // Backward graph @@ -306,6 +308,7 @@ class PynativeExecutor : public std::enable_shared_from_this { static std::mutex instance_lock_; static int64_t graph_id_; size_t grad_order_{0}; + size_t top_cell_index_{0}; std::string top_cell_id_; bool grad_flag_{false}; bool in_grad_process_{false};