diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index ad3e6d3207..90df5df537 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -640,7 +640,7 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { auto op_name = py::cast(args[PY_NAME]); op_exec_info->op_name = op_name; if (grad_flag()) { - op_exec_info->op_index = op_name + std::to_string(op_index_map_[op_name]); + op_exec_info->op_index = op_name + "_" + std::to_string(op_index_map_[op_name]); if (!cell_op_info_stack_.empty()) { std::string &cell_op_info = cell_op_info_stack_.top(); cell_op_info += op_exec_info->op_index; @@ -1514,9 +1514,10 @@ std::string PynativeExecutor::GetTensorCellId(const std::string &cell_id) { } value.emplace_back(str.substr(pre_pos)); }; - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&key](const CellInfoPtr &value) { - return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos; - }); + auto it = + std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&key](const CellInfoPtr &value) { + return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos; + }); if (it != cell_graph_list_.end()) { std::vector pre_cell_id; std::vector cur_cell_id; @@ -1590,16 +1591,19 @@ void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_com if (it != top_cell_list_.end()) { (*it)->do_vm_compiled = vm_compiled; (*it)->forward_already_run = false; + (*it)->need_grad = true; if ((*it)->is_topest) { in_grad_process_ = false; + top_cell_index_ = 0; } } } bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) { - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { - return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos; - }); + return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), + [&cell_id](const CellInfoPtr &value) { + return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos; + }); } bool PynativeExecutor::IsFirstGradStep(const std::string &cell_id) { return !CheckCellGraph(cell_id, true); } @@ -1611,20 +1615,21 @@ void PynativeExecutor::SubNestedGradOrder() { } bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) { - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfoPtr &value) { - return value->cell_id == cell_id && (!is_grad || value->is_grad); - }); + return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), + [&cell_id, is_grad](const CellInfoPtr &value) { + return value->cell_id == cell_id && (!is_grad || value->is_grad); + }); } bool PynativeExecutor::CheckDynamicCell(const std::string &cell_id) { - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), + return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_dynamic; }); } bool PynativeExecutor::CheckRealDynamicCell(const std::string &cell_id) { - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { - return value->cell_id == cell_id && value->is_real_dynamic; - }); + return std::any_of( + cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_real_dynamic; }); } void PynativeExecutor::ClearResidualRes(const std::string &cell_id) { @@ -1891,25 +1896,23 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg // check whether cell needed to construct grad graph if (graph_stack_.empty() && !top_cell_list_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { // Clear previous step resource - if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) { + auto init_fn = [&](bool flag) { CleanPreMemoryInValueNode(); op_index_map_.clear(); in_grad_process_ = true; - auto top_cell = GetTopCell(cell_id); + in_bprop_process_ = false; + auto top_cell = GetTopCell(cell_id, flag); MS_EXCEPTION_IF_NULL(top_cell); top_cell_id_ = top_cell->cell_id; + top_cell_index_ = top_cell->top_cell_index; top_cell->forward_already_run = true; MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; + }; + if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) { + init_fn(false); } if (!in_grad_process_ && cell_op_info_stack_.empty()) { - CleanPreMemoryInValueNode(); - op_index_map_.clear(); - in_grad_process_ = true; - auto top_cell = GetTopCell(cell_id, true); - MS_EXCEPTION_IF_NULL(top_cell); - top_cell_id_ = top_cell->cell_id; - top_cell->forward_already_run = true; - MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; + init_fn(true); } PushCurrentCellOpInfoToStack(); MS_LOG(INFO) << "NewGraph already compiled"; @@ -1918,8 +1921,12 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg // Init resource for constructing forward graph and grad graph curr_g_ = std::make_shared(); ClearResidualRes(cell_id); - if (graph_stack_.empty() && !IsBpropGraph(cell_id)) { - MakeNewTopGraph(cell_id, args); + if (graph_stack_.empty()) { + if (IsBpropGraph(cell_id)) { + in_bprop_process_ = true; + } else { + MakeNewTopGraph(cell_id, args); + } } PushCurrentGraphToStack(); PushCurrentCellOpInfoToStack(); @@ -1939,6 +1946,14 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg if (!has_dynamic_cell_) { has_dynamic_cell_ = IsDynamicCell(cell); MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << has_dynamic_cell_; + if (has_dynamic_cell_ && IsBpropGraph(cell_id)) { + auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), + [this](const CellInfoPtr &value) { return value->cell_id == top_cell_id_; }); + while (it != cell_graph_list_.end()) { + (*it)->is_dynamic = true; + ++it; + } + } } } @@ -1976,6 +1991,15 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar resource->results()[pipeline::kPynativeGraphId] = graph_id_++; auto top_cell_info = std::make_shared(true, resource, df_builder, cell_id); top_cell_info->forward_already_run = true; + if (!IsTopestGraph(cell_id)) { + top_cell_info->top_cell_index = cell_graph_list_.size(); + top_cell_index_ = top_cell_info->top_cell_index; + } else { + auto top_cell = GetTopCell(cell_id, true); + MS_EXCEPTION_IF_NULL(top_cell); + top_cell_info->top_cell_index = top_cell->top_cell_index; + top_cell_index_ = top_cell_info->top_cell_index; + } top_cell_list_.emplace_back(top_cell_info); MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get(); } @@ -2086,11 +2110,22 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string resource->manager()->AddFuncGraph(curr_g_); UpdateCellGraph(cell, curr_g_, cell_id, true, false); FuncGraphPtr newfg = nullptr; + auto top_cell = GetTopCell(top_cell_id_); + MS_EXCEPTION_IF_NULL(top_cell); // Cell no Change if (CheckDynamicCell(cell_id) && !CheckCellChanged(cell_id)) { MS_LOG(DEBUG) << "Cell is not dynamic, No need make ad grad"; + top_cell->need_grad = false; + std::unordered_set node_set; + ClearCnodeRes(curr_g_->output(), &node_set); + node_set.clear(); } else { MS_LOG(DEBUG) << "Need make ad grad"; + if (!top_cell->need_grad) { + std::unordered_set node_set; + ClearCnodeRes(curr_g_->output(), &node_set); + node_set.clear(); + } newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); } @@ -2146,7 +2181,7 @@ bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) { MS_LOG(DEBUG) << "Cell op info is empty"; return true; } - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (it == cell_graph_list_.end() || IsFirstGradStep(top_cell_id_)) { return true; @@ -2162,22 +2197,22 @@ bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) { } void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) { - for (auto &it : cell_graph_list_) { - if (it->cell_id != cell_id) { - it->is_real_dynamic = true; + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { + if ((*it)->cell_id != cell_id) { + (*it)->is_real_dynamic = true; continue; } - it->is_real_dynamic = true; + (*it)->is_real_dynamic = true; break; } } -void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, +bool PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned, bool is_grad) { auto update_in_endgraph = need_cloned && !is_grad; if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { // Bprop just save backward graph - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (it != cell_graph_list_.end()) { (*it)->is_grad = is_grad; @@ -2196,16 +2231,23 @@ void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGr << " cell ops info " << GetCellOpInfo(); auto cell_info = std::make_shared(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id); cell_info->cell_ops_info.emplace_back(GetCellOpInfo()); - cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); + if (in_bprop_process_) { + cell_graph_list_.emplace_back(cell_info); + } else { + cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info); + } } - return; + return true; } + return false; } void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned, bool is_grad) { auto update_in_endgraph = need_cloned && !is_grad; - UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad); + if (UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad)) { + return; + } FuncGraphPtr tmp = g; if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) { MS_LOG(DEBUG) << "No need cloned"; @@ -2228,9 +2270,13 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt MS_LOG(DEBUG) << "Add new cell with cloned graph " << cell_id << " cell ops info " << GetCellOpInfo(); auto cell_info = std::make_shared(true, has_dynamic_cell_, tmp, cell_id, ""); cell_info->cell_ops_info.emplace_back(GetCellOpInfo()); - cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); + if (in_bprop_process_) { + cell_graph_list_.emplace_back(cell_info); + } else { + cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info); + } } else { - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (it != cell_graph_list_.end()) { (*it)->cell_ops_info.emplace_back(GetCellOpInfo()); @@ -2240,26 +2286,26 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt return; } - for (auto &it : cell_graph_list_) { - if (it->cell_id != cell_id) { + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { + if ((*it)->cell_id != cell_id) { continue; } if (IsFirstGradStep(cell_id)) { // no compute grad - it->is_grad = is_grad; + (*it)->is_grad = is_grad; } if (need_cloned) { clone_fn(); - if (it->fg != nullptr) { - graph_info_map_.erase(it->fg); + if ((*it)->fg != nullptr) { + graph_info_map_.erase((*it)->fg); } - MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with cloned new " << tmp.get(); - it->fg = tmp; + MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with cloned new " << tmp.get(); + (*it)->fg = tmp; } if (!need_cloned && !is_grad) { - graph_info_map_.erase(it->fg); - MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with new " << tmp.get(); - it->fg = tmp; + graph_info_map_.erase((*it)->fg); + MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with new " << tmp.get(); + (*it)->fg = tmp; } break; } @@ -2517,9 +2563,10 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args auto graph_info = std::make_shared(cell_id); graph_info_map_[df_builder] = graph_info; auto top_cell_info = std::make_shared(false, resource, df_builder, cell_id); + top_cell_info->top_cell_index = top_cell_index_; top_cell_list_.emplace_back(top_cell_info); FuncGraphPtr forward_graph = nullptr; - auto ib = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto ib = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (ib != cell_graph_list_.end()) { forward_graph = (*ib)->fg; @@ -2546,17 +2593,17 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const const std::string &cell_id) { std::vector graph_before{}; bool index_find = false; - for (const auto &it : cell_graph_list_) { - if (IsBpropGraph(it->cell_id) || it->fg == nullptr) { + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { + if (IsBpropGraph((*it)->cell_id) || (*it)->fg == nullptr) { continue; } if (index_find) { - graph_before.emplace_back(it->fg); + graph_before.emplace_back((*it)->fg); continue; } - if (it->cell_id == cell_id) { + if ((*it)->cell_id == cell_id) { index_find = true; - graph_before.emplace_back(it->fg); + graph_before.emplace_back((*it)->fg); } } @@ -2585,6 +2632,7 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const SetNodeMapInGraphInfoMap(df_builder, it.first, new_param); } } + graph_before.clear(); } void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) { @@ -2674,7 +2722,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector &weights, size_t arg_size, const std::string &cell_id) { FuncGraphPtr top_g = nullptr; - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (it != cell_graph_list_.end()) { top_g = (*it)->fg; @@ -2711,7 +2759,7 @@ void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py: graph_info_map_.erase(df_builder); bool has_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); bool is_dynamic_top_fist_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(cell_id); - bool is_topmost = IsTopestGraph(cell_id) && top_cell_list_.front()->cell_id == cell_id; + bool is_topmost = IsTopestGraph(cell_id); if (has_custom_bprop || is_dynamic_top_fist_grad || !is_topmost) { return; } @@ -2720,16 +2768,32 @@ void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py: // Clear graph_info_map_ std::vector l{}; bool index_find = false; - for (auto &it : cell_graph_list_) { + auto it_end = cell_graph_list_.end(); + for (size_t i = 0; i < top_cell_list_.size(); ++i) { + if (top_cell_list_[i]->cell_id == cell_id) { + index_find = true; + continue; + } + if (index_find) { + it_end = cell_graph_list_.begin() + top_cell_list_[i]->top_cell_index; + break; + } + } + index_find = false; + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != it_end; ++it) { + if ((*it)->fg != nullptr) { + std::unordered_set node_set; + ClearCnodeRes((*it)->fg->output(), &node_set); + node_set.clear(); + (*it)->fg = nullptr; + } if (index_find) { - it->fg = nullptr; - l.emplace_back(it->cell_id); + l.emplace_back((*it)->cell_id); continue; } - if (it->cell_id == cell_id) { + if ((*it)->cell_id == cell_id) { index_find = true; - it->fg = nullptr; - l.emplace_back(it->cell_id); + l.emplace_back((*it)->cell_id); } } for (const auto &it : l) { @@ -2753,7 +2817,7 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & const auto &cell_id = GetCellId(cell, args); std::string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size())); MS_LOG(DEBUG) << "Key is " << key; - for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) { + for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id; if (key != (*it)->cell_id.substr(0, std::min(PTR_LEN, (*it)->cell_id.size()))) { continue; @@ -2773,13 +2837,18 @@ py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::a bool forward_run = false; if (top_cell != nullptr) { forward_run = top_cell->forward_already_run; + if (forward_run) { + top_cell_index_ = top_cell->top_cell_index; + } } MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ " << top_cell_id_; return BaseRefToPyData(forward_run); } -py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) { +void PynativeExecutor::RunInner(const py::object &cell, const py::tuple &args, const py::object &phase, + py::object *ret) { + MS_EXCEPTION_IF_NULL(ret); auto cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "Run start cell id " << cell_id; bool has_sens = false; @@ -2814,17 +2883,16 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, BaseRef value = (*run)(arg_list); set_grad_runing(false); MS_LOG(DEBUG) << "Eval run end " << value.ToString(); - auto out = BaseRefToPyData(value); + *ret = BaseRefToPyData(value); auto do_vm_compiled = std::any_of(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->do_vm_compiled; }); if (do_vm_compiled) { - if (MakeBpropNestedCnode(cell, out, cell_id)) { - return out; + if (MakeBpropNestedCnode(cell, *ret, cell_id)) { + return; } - MakeNestedCnode(cell_id, args, resource, out, has_sens); + MakeNestedCnode(cell_id, args, resource, *ret, has_sens); } - return out; } bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) { @@ -2883,7 +2951,7 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg void PynativeExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector *inputs) { FuncGraphPtr forward_graph = nullptr; - auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + auto ic = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (ic != cell_graph_list_.end()) { forward_graph = (*ic)->fg; @@ -2950,20 +3018,26 @@ void PynativeExecutor::ClearRes() { graph_id_ = 0; grad_order_ = 0; grad_flag_ = false; + in_grad_process_ = false; + in_bprop_process_ = false; has_dynamic_cell_ = false; grad_is_running_ = false; need_replace_forward_ = true; curr_g_ = nullptr; + top_cell_id_.clear(); graph_info_map_.clear(); replace_weights_map_.clear(); cell_graph_list_.clear(); top_cell_list_.clear(); + cell_input_args_.clear(); op_index_map_.clear(); cell_op_index_with_tensor_id_.clear(); cell_tensor_id_with_tensor_.clear(); prim_abs_list_.clear(); + all_value_node_tensors_.clear(); std::stack().swap(graph_stack_); + std::stack().swap(cell_op_info_stack_); } void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { @@ -2984,6 +3058,12 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args); } +py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) { + py::object ret; + PynativeExecutorTry(this, &PynativeExecutor::RunInner, cell, args, phase, &ret); + return ret; +} + void PynativeExecutor::Sync() { if (session == nullptr) { MS_EXCEPTION(NotExistsError) << "No session has been created!"; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index b087c8a31f..912ee64df7 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -96,9 +96,11 @@ class TopCellInfo { TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid) : is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {} + bool need_grad{true}; bool is_topest{false}; bool do_vm_compiled{false}; bool forward_already_run{false}; + size_t top_cell_index{0}; ResourcePtr resource{nullptr}; FuncGraphPtr df_builder{nullptr}; FuncGraphPtr bg{nullptr}; // Backward graph @@ -134,6 +136,7 @@ class PynativeExecutor : public std::enable_shared_from_this { OpExecInfoPtr GenerateOpExecInfo(const py::args &args); void NewGraph(const py::object &cell, const py::args &args); py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase); + void RunInner(const py::object &cell, const py::tuple &args, const py::object &phase, py::object *ret); py::object CheckGraph(const py::object &cell, const py::args &args); py::object CheckAlreadyRun(const py::object &cell, const py::args &args); void EndGraph(const py::object &cell, const py::object &out, const py::args &args); @@ -241,7 +244,7 @@ class PynativeExecutor : public std::enable_shared_from_this { bool CheckCellGraph(const std::string &cell_id, bool is_grad = false); bool CheckDynamicCell(const std::string &cell_id); bool CheckRealDynamicCell(const std::string &cell_id); - void UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned, + bool UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned, bool is_grad); void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned = false, bool is_grad = false); @@ -308,8 +311,10 @@ class PynativeExecutor : public std::enable_shared_from_this { static std::mutex instance_lock_; static int64_t graph_id_; size_t grad_order_{0}; + size_t top_cell_index_{0}; std::string top_cell_id_; bool grad_flag_{false}; + bool in_bprop_process_{false}; bool in_grad_process_{false}; bool has_dynamic_cell_{false}; bool grad_is_running_{false};