diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index f96ee8dfa4..ec123bfc46 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1341,7 +1341,7 @@ py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_poli break; } case kMsBackendMsPrior: { - // use Ms fisrt,use others when ms failed + // use Ms first,use others when ms failed MS_LOG(INFO) << "RunOp use Ms first backend"; result = RunOpInMs(op_exec_info, status); if (*status != PYNATIVE_SUCCESS) { @@ -1557,23 +1557,28 @@ bool PynativeExecutor::IsTopestGraph(const std::string &cell_id) { [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->is_topest; }); } -std::string PynativeExecutor::GetTopCell(const string &cell_id) { - if (IsTopestGraph(cell_id)) { - return cell_id; - } - std::string top_cell_id; - for (const auto &it : cell_graph_list_) { - if (IsTopestGraph(it->cell_id)) { - top_cell_id = it->cell_id; +TopCellInfoPtr PynativeExecutor::GetTopCell(const string &cell_id, bool find_nearest) { + auto find_top_cell = [&](const string &cell_id) -> TopCellInfoPtr { + auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &top_cell) { + return cell_id == top_cell->cell_id && top_cell->is_topest; + }); + if (iter != top_cell_list_.end()) { + return *iter; } - if (it->cell_id == cell_id) { - break; + return nullptr; + }; + TopCellInfoPtr top_cell = find_top_cell(cell_id); + // find nearest top cell + if (top_cell == nullptr && find_nearest) { + for (const auto &cell_info : cell_graph_list_) { + MS_EXCEPTION_IF_NULL(cell_info); + top_cell = find_top_cell(cell_info->cell_id); + if (cell_id == cell_info->cell_id) { + break; + } } } - if (top_cell_id.empty()) { - MS_LOG(EXCEPTION) << "Get top cell null"; - } - return top_cell_id; + return top_cell; } void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled) { @@ -1581,6 +1586,7 @@ void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_com [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); if (it != top_cell_list_.end()) { (*it)->do_vm_compiled = vm_compiled; + (*it)->forward_already_run = false; if ((*it)->is_topest) { in_grad_process_ = false; } @@ -1704,7 +1710,7 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptrcell_id; + top_cell->forward_already_run = true; + MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; } if (!in_grad_process_ && cell_op_info_stack_.empty()) { CleanPreMemoryInValueNode(); op_index_map_.clear(); - top_cell_id_ = GetTopCell(cell_id); in_grad_process_ = true; + auto top_cell = GetTopCell(cell_id, true); + MS_EXCEPTION_IF_NULL(top_cell); + top_cell_id_ = top_cell->cell_id; + top_cell->forward_already_run = true; MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; } PushCurrentCellOpInfoToStack(); @@ -1948,12 +1961,18 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar op_index_map_.clear(); top_cell_id_ = cell_id; in_grad_process_ = true; + // update forward already run flag with previous top cell + auto pre_top_cell = GetTopCell(cell_id); + if (pre_top_cell != nullptr) { + pre_top_cell->forward_already_run = true; + } auto df_builder = std::make_shared(); auto graph_info = std::make_shared(cell_id); graph_info_map_[df_builder] = graph_info; auto resource = std::make_shared(); resource->results()[pipeline::kPynativeGraphId] = graph_id_++; auto top_cell_info = std::make_shared(true, resource, df_builder, cell_id); + top_cell_info->forward_already_run = true; top_cell_list_.emplace_back(top_cell_info); MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get(); } @@ -2742,7 +2761,11 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) { const auto &cell_id = GetCellId(cell, args); - bool forward_run = CheckCellGraph(cell_id) && top_cell_id_ == cell_id; + auto top_cell = GetTopCell(cell_id); + bool forward_run = false; + if (top_cell != nullptr) { + forward_run = top_cell->forward_already_run; + } MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ " << top_cell_id_; return BaseRefToPyData(forward_run); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 8f1be511fa..8507d1c8b6 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -62,7 +62,7 @@ void ClearPyNativeSession(); struct GraphInfo { std::string cell_id; AnfNodePtr output; - OrderedMap params; // hold input parameters and cell weigths + OrderedMap params; // hold input parameters and cell weights std::unordered_map>> node_map; std::vector objects; GraphInfo() = default; @@ -98,6 +98,7 @@ class TopCellInfo { bool is_topest{false}; bool do_vm_compiled{false}; + bool forward_already_run{false}; ResourcePtr resource{nullptr}; FuncGraphPtr df_builder{nullptr}; FuncGraphPtr bg{nullptr}; // Backward graph @@ -250,7 +251,7 @@ class PynativeExecutor : public std::enable_shared_from_this { void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph); void NewGraphInner(const py::object &cell, const py::args &args); void MakeNewTopGraph(const string &cell_id, const py::args &args); - std::string GetTopCell(const string &cell_id); + TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false); void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, const std::string &out_id, const py::args &args); diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 9191eed8e5..e6147b135a 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -321,10 +321,12 @@ class Cell(Cell_): for item in inputs: if isinstance(item, numpy.ndarray): raise TypeError("cell inputs should not be numpy array.") + origin_grad = [] if self.requires_grad is True: _pynative_exec.set_grad_flag(True) _pynative_exec.new_graph(self, *inputs, **kwargs) for cell in self.cells(): + origin_grad.append(cell.requires_grad) cell.set_grad(True) else: _pynative_exec.set_grad_flag(False) @@ -348,6 +350,8 @@ class Cell(Cell_): output = output.data if self.requires_grad is True: _pynative_exec.end_graph(self, output, *inputs, **kwargs) + for i, cell in enumerate(self.cells()): + cell.set_grad(origin_grad[i]) return output def _add_attr(self, name, value):