Browse Source

Fix multi dynamic top cell

Signed-off-by: zjun <zhangjun0@huawei.com>
tags/v1.2.0-rc1
zjun 4 years ago
parent
commit
19bab9713d
2 changed files with 11 additions and 3 deletions
  1. +10
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +1
    -1
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h

+ 10
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -2051,6 +2051,7 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
top_cell_info->top_cell_index = cell_graph_list_.size();
top_cell_index_ = top_cell_info->top_cell_index;
} else {
MS_LOG(DEBUG) << "Get dynamic top cell";
auto top_cell = GetTopCell(cell_id, true);
MS_EXCEPTION_IF_NULL(top_cell);
top_cell_info->top_cell_index = top_cell->top_cell_index;
@@ -2487,7 +2488,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
const auto &params_changed = CheckGradParamsChanged(cell_id, weights, sens);
if (!params_changed && !IsFirstGradStep(cell_id) && !CheckRealDynamicCell(cell_id)) {
UpdateTopCellInfo(cell_id, false);
ClearDynamicTopRes(cell_id);
ClearDynamicTopRes(cell_id, nullptr);
MS_LOG(INFO) << "Gradgraph already compiled";
return;
}
@@ -2531,10 +2532,11 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
ClearUselessRes(df_builder, cell, cell_id);
UpdateCellGraph(cell, curr_g_, cell_id, false, true);
UpdateTopCellInfo(cell_id, true);
ClearDynamicTopRes(cell_id, df_builder);
resource->Clean();
}

void PynativeExecutor::ClearDynamicTopRes(const std::string &cell_id) {
void PynativeExecutor::ClearDynamicTopRes(const std::string &cell_id, const FuncGraphPtr &df_builder) {
if (IsTopestGraph(cell_id)) {
op_index_map_.clear();
}
@@ -2544,6 +2546,12 @@ void PynativeExecutor::ClearDynamicTopRes(const std::string &cell_id) {
}
int same_top_cell_count = 0;
for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
// High order should exclude
if (graph_stack_.empty() && df_builder != nullptr && (*it)->df_builder.get() != df_builder.get()) {
MS_LOG(DEBUG) << "Delete cell id " << (*it)->cell_id;
it = top_cell_list_.erase(it);
continue;
}
if ((*it)->cell_id == cell_id) {
++same_top_cell_count;
if (same_top_cell_count > 1) {


+ 1
- 1
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -270,7 +270,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
const std::string &cell_id, const py::args &args);
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
py::object *sens = nullptr);
void ClearDynamicTopRes(const std::string &cell_id);
void ClearDynamicTopRes(const std::string &cell_id, const FuncGraphPtr &df_builder);
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args);
std::string GetCellId(const py::object &obj, const py::args &args);


Loading…
Cancel
Save