|
|
|
@@ -640,7 +640,7 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { |
|
|
|
auto op_name = py::cast<std::string>(args[PY_NAME]); |
|
|
|
op_exec_info->op_name = op_name; |
|
|
|
if (grad_flag()) { |
|
|
|
op_exec_info->op_index = op_name + std::to_string(op_index_map_[op_name]); |
|
|
|
op_exec_info->op_index = op_name + "_" + std::to_string(op_index_map_[op_name]); |
|
|
|
if (!cell_op_info_stack_.empty()) { |
|
|
|
std::string &cell_op_info = cell_op_info_stack_.top(); |
|
|
|
cell_op_info += op_exec_info->op_index; |
|
|
|
@@ -1514,9 +1514,10 @@ std::string PynativeExecutor::GetTensorCellId(const std::string &cell_id) { |
|
|
|
} |
|
|
|
value.emplace_back(str.substr(pre_pos)); |
|
|
|
}; |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&key](const CellInfoPtr &value) { |
|
|
|
return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos; |
|
|
|
}); |
|
|
|
auto it = |
|
|
|
std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&key](const CellInfoPtr &value) { |
|
|
|
return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos; |
|
|
|
}); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
std::vector<std::string> pre_cell_id; |
|
|
|
std::vector<std::string> cur_cell_id; |
|
|
|
@@ -1590,16 +1591,19 @@ void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_com |
|
|
|
if (it != top_cell_list_.end()) { |
|
|
|
(*it)->do_vm_compiled = vm_compiled; |
|
|
|
(*it)->forward_already_run = false; |
|
|
|
(*it)->need_grad = true; |
|
|
|
if ((*it)->is_topest) { |
|
|
|
in_grad_process_ = false; |
|
|
|
top_cell_index_ = 0; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) { |
|
|
|
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { |
|
|
|
return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos; |
|
|
|
}); |
|
|
|
return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfoPtr &value) { |
|
|
|
return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos; |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::IsFirstGradStep(const std::string &cell_id) { return !CheckCellGraph(cell_id, true); } |
|
|
|
@@ -1611,20 +1615,21 @@ void PynativeExecutor::SubNestedGradOrder() { |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) { |
|
|
|
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfoPtr &value) { |
|
|
|
return value->cell_id == cell_id && (!is_grad || value->is_grad); |
|
|
|
}); |
|
|
|
return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id, is_grad](const CellInfoPtr &value) { |
|
|
|
return value->cell_id == cell_id && (!is_grad || value->is_grad); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::CheckDynamicCell(const std::string &cell_id) { |
|
|
|
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_dynamic; }); |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::CheckRealDynamicCell(const std::string &cell_id) { |
|
|
|
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { |
|
|
|
return value->cell_id == cell_id && value->is_real_dynamic; |
|
|
|
}); |
|
|
|
return std::any_of( |
|
|
|
cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_real_dynamic; }); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::ClearResidualRes(const std::string &cell_id) { |
|
|
|
@@ -1891,25 +1896,23 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
// check whether cell needed to construct grad graph |
|
|
|
if (graph_stack_.empty() && !top_cell_list_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { |
|
|
|
// Clear previous step resource |
|
|
|
if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) { |
|
|
|
auto init_fn = [&](bool flag) { |
|
|
|
CleanPreMemoryInValueNode(); |
|
|
|
op_index_map_.clear(); |
|
|
|
in_grad_process_ = true; |
|
|
|
auto top_cell = GetTopCell(cell_id); |
|
|
|
in_bprop_process_ = false; |
|
|
|
auto top_cell = GetTopCell(cell_id, flag); |
|
|
|
MS_EXCEPTION_IF_NULL(top_cell); |
|
|
|
top_cell_id_ = top_cell->cell_id; |
|
|
|
top_cell_index_ = top_cell->top_cell_index; |
|
|
|
top_cell->forward_already_run = true; |
|
|
|
MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; |
|
|
|
}; |
|
|
|
if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) { |
|
|
|
init_fn(false); |
|
|
|
} |
|
|
|
if (!in_grad_process_ && cell_op_info_stack_.empty()) { |
|
|
|
CleanPreMemoryInValueNode(); |
|
|
|
op_index_map_.clear(); |
|
|
|
in_grad_process_ = true; |
|
|
|
auto top_cell = GetTopCell(cell_id, true); |
|
|
|
MS_EXCEPTION_IF_NULL(top_cell); |
|
|
|
top_cell_id_ = top_cell->cell_id; |
|
|
|
top_cell->forward_already_run = true; |
|
|
|
MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; |
|
|
|
init_fn(true); |
|
|
|
} |
|
|
|
PushCurrentCellOpInfoToStack(); |
|
|
|
MS_LOG(INFO) << "NewGraph already compiled"; |
|
|
|
@@ -1918,8 +1921,12 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
// Init resource for constructing forward graph and grad graph |
|
|
|
curr_g_ = std::make_shared<FuncGraph>(); |
|
|
|
ClearResidualRes(cell_id); |
|
|
|
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) { |
|
|
|
MakeNewTopGraph(cell_id, args); |
|
|
|
if (graph_stack_.empty()) { |
|
|
|
if (IsBpropGraph(cell_id)) { |
|
|
|
in_bprop_process_ = true; |
|
|
|
} else { |
|
|
|
MakeNewTopGraph(cell_id, args); |
|
|
|
} |
|
|
|
} |
|
|
|
PushCurrentGraphToStack(); |
|
|
|
PushCurrentCellOpInfoToStack(); |
|
|
|
@@ -1939,6 +1946,14 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
if (!has_dynamic_cell_) { |
|
|
|
has_dynamic_cell_ = IsDynamicCell(cell); |
|
|
|
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << has_dynamic_cell_; |
|
|
|
if (has_dynamic_cell_ && IsBpropGraph(cell_id)) { |
|
|
|
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[this](const CellInfoPtr &value) { return value->cell_id == top_cell_id_; }); |
|
|
|
while (it != cell_graph_list_.end()) { |
|
|
|
(*it)->is_dynamic = true; |
|
|
|
++it; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1976,6 +1991,15 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar |
|
|
|
resource->results()[pipeline::kPynativeGraphId] = graph_id_++; |
|
|
|
auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id); |
|
|
|
top_cell_info->forward_already_run = true; |
|
|
|
if (!IsTopestGraph(cell_id)) { |
|
|
|
top_cell_info->top_cell_index = cell_graph_list_.size(); |
|
|
|
top_cell_index_ = top_cell_info->top_cell_index; |
|
|
|
} else { |
|
|
|
auto top_cell = GetTopCell(cell_id, true); |
|
|
|
MS_EXCEPTION_IF_NULL(top_cell); |
|
|
|
top_cell_info->top_cell_index = top_cell->top_cell_index; |
|
|
|
top_cell_index_ = top_cell_info->top_cell_index; |
|
|
|
} |
|
|
|
top_cell_list_.emplace_back(top_cell_info); |
|
|
|
MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get(); |
|
|
|
} |
|
|
|
@@ -2086,11 +2110,22 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string |
|
|
|
resource->manager()->AddFuncGraph(curr_g_); |
|
|
|
UpdateCellGraph(cell, curr_g_, cell_id, true, false); |
|
|
|
FuncGraphPtr newfg = nullptr; |
|
|
|
auto top_cell = GetTopCell(top_cell_id_); |
|
|
|
MS_EXCEPTION_IF_NULL(top_cell); |
|
|
|
// Cell no Change |
|
|
|
if (CheckDynamicCell(cell_id) && !CheckCellChanged(cell_id)) { |
|
|
|
MS_LOG(DEBUG) << "Cell is not dynamic, No need make ad grad"; |
|
|
|
top_cell->need_grad = false; |
|
|
|
std::unordered_set<AnfNodePtr> node_set; |
|
|
|
ClearCnodeRes(curr_g_->output(), &node_set); |
|
|
|
node_set.clear(); |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "Need make ad grad"; |
|
|
|
if (!top_cell->need_grad) { |
|
|
|
std::unordered_set<AnfNodePtr> node_set; |
|
|
|
ClearCnodeRes(curr_g_->output(), &node_set); |
|
|
|
node_set.clear(); |
|
|
|
} |
|
|
|
newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -2146,7 +2181,7 @@ bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) { |
|
|
|
MS_LOG(DEBUG) << "Cell op info is empty"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); |
|
|
|
if (it == cell_graph_list_.end() || IsFirstGradStep(top_cell_id_)) { |
|
|
|
return true; |
|
|
|
@@ -2162,22 +2197,22 @@ bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) { |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) { |
|
|
|
for (auto &it : cell_graph_list_) { |
|
|
|
if (it->cell_id != cell_id) { |
|
|
|
it->is_real_dynamic = true; |
|
|
|
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { |
|
|
|
if ((*it)->cell_id != cell_id) { |
|
|
|
(*it)->is_real_dynamic = true; |
|
|
|
continue; |
|
|
|
} |
|
|
|
it->is_real_dynamic = true; |
|
|
|
(*it)->is_real_dynamic = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
|
bool PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
|
bool need_cloned, bool is_grad) { |
|
|
|
auto update_in_endgraph = need_cloned && !is_grad; |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
// Bprop just save backward graph |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
(*it)->is_grad = is_grad; |
|
|
|
@@ -2196,16 +2231,23 @@ void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGr |
|
|
|
<< " cell ops info " << GetCellOpInfo(); |
|
|
|
auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id); |
|
|
|
cell_info->cell_ops_info.emplace_back(GetCellOpInfo()); |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); |
|
|
|
if (in_bprop_process_) { |
|
|
|
cell_graph_list_.emplace_back(cell_info); |
|
|
|
} else { |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info); |
|
|
|
} |
|
|
|
} |
|
|
|
return; |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
|
bool need_cloned, bool is_grad) { |
|
|
|
auto update_in_endgraph = need_cloned && !is_grad; |
|
|
|
UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad); |
|
|
|
if (UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
FuncGraphPtr tmp = g; |
|
|
|
if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) { |
|
|
|
MS_LOG(DEBUG) << "No need cloned"; |
|
|
|
@@ -2228,9 +2270,13 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt |
|
|
|
MS_LOG(DEBUG) << "Add new cell with cloned graph " << cell_id << " cell ops info " << GetCellOpInfo(); |
|
|
|
auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, tmp, cell_id, ""); |
|
|
|
cell_info->cell_ops_info.emplace_back(GetCellOpInfo()); |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); |
|
|
|
if (in_bprop_process_) { |
|
|
|
cell_graph_list_.emplace_back(cell_info); |
|
|
|
} else { |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info); |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
(*it)->cell_ops_info.emplace_back(GetCellOpInfo()); |
|
|
|
@@ -2240,26 +2286,26 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &it : cell_graph_list_) { |
|
|
|
if (it->cell_id != cell_id) { |
|
|
|
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { |
|
|
|
if ((*it)->cell_id != cell_id) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsFirstGradStep(cell_id)) { |
|
|
|
// no compute grad |
|
|
|
it->is_grad = is_grad; |
|
|
|
(*it)->is_grad = is_grad; |
|
|
|
} |
|
|
|
if (need_cloned) { |
|
|
|
clone_fn(); |
|
|
|
if (it->fg != nullptr) { |
|
|
|
graph_info_map_.erase(it->fg); |
|
|
|
if ((*it)->fg != nullptr) { |
|
|
|
graph_info_map_.erase((*it)->fg); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with cloned new " << tmp.get(); |
|
|
|
it->fg = tmp; |
|
|
|
MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with cloned new " << tmp.get(); |
|
|
|
(*it)->fg = tmp; |
|
|
|
} |
|
|
|
if (!need_cloned && !is_grad) { |
|
|
|
graph_info_map_.erase(it->fg); |
|
|
|
MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with new " << tmp.get(); |
|
|
|
it->fg = tmp; |
|
|
|
graph_info_map_.erase((*it)->fg); |
|
|
|
MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with new " << tmp.get(); |
|
|
|
(*it)->fg = tmp; |
|
|
|
} |
|
|
|
break; |
|
|
|
} |
|
|
|
@@ -2517,9 +2563,10 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args |
|
|
|
auto graph_info = std::make_shared<GraphInfo>(cell_id); |
|
|
|
graph_info_map_[df_builder] = graph_info; |
|
|
|
auto top_cell_info = std::make_shared<TopCellInfo>(false, resource, df_builder, cell_id); |
|
|
|
top_cell_info->top_cell_index = top_cell_index_; |
|
|
|
top_cell_list_.emplace_back(top_cell_info); |
|
|
|
FuncGraphPtr forward_graph = nullptr; |
|
|
|
auto ib = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
auto ib = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); |
|
|
|
if (ib != cell_graph_list_.end()) { |
|
|
|
forward_graph = (*ib)->fg; |
|
|
|
@@ -2546,17 +2593,17 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const |
|
|
|
const std::string &cell_id) { |
|
|
|
std::vector<FuncGraphPtr> graph_before{}; |
|
|
|
bool index_find = false; |
|
|
|
for (const auto &it : cell_graph_list_) { |
|
|
|
if (IsBpropGraph(it->cell_id) || it->fg == nullptr) { |
|
|
|
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { |
|
|
|
if (IsBpropGraph((*it)->cell_id) || (*it)->fg == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (index_find) { |
|
|
|
graph_before.emplace_back(it->fg); |
|
|
|
graph_before.emplace_back((*it)->fg); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (it->cell_id == cell_id) { |
|
|
|
if ((*it)->cell_id == cell_id) { |
|
|
|
index_find = true; |
|
|
|
graph_before.emplace_back(it->fg); |
|
|
|
graph_before.emplace_back((*it)->fg); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -2585,6 +2632,7 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const |
|
|
|
SetNodeMapInGraphInfoMap(df_builder, it.first, new_param); |
|
|
|
} |
|
|
|
} |
|
|
|
graph_before.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) { |
|
|
|
@@ -2674,7 +2722,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args |
|
|
|
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, |
|
|
|
const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) { |
|
|
|
FuncGraphPtr top_g = nullptr; |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
top_g = (*it)->fg; |
|
|
|
@@ -2711,7 +2759,7 @@ void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py: |
|
|
|
graph_info_map_.erase(df_builder); |
|
|
|
bool has_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); |
|
|
|
bool is_dynamic_top_fist_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(cell_id); |
|
|
|
bool is_topmost = IsTopestGraph(cell_id) && top_cell_list_.front()->cell_id == cell_id; |
|
|
|
bool is_topmost = IsTopestGraph(cell_id); |
|
|
|
if (has_custom_bprop || is_dynamic_top_fist_grad || !is_topmost) { |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -2720,16 +2768,32 @@ void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py: |
|
|
|
// Clear graph_info_map_ |
|
|
|
std::vector<std::string> l{}; |
|
|
|
bool index_find = false; |
|
|
|
for (auto &it : cell_graph_list_) { |
|
|
|
auto it_end = cell_graph_list_.end(); |
|
|
|
for (size_t i = 0; i < top_cell_list_.size(); ++i) { |
|
|
|
if (top_cell_list_[i]->cell_id == cell_id) { |
|
|
|
index_find = true; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (index_find) { |
|
|
|
it_end = cell_graph_list_.begin() + top_cell_list_[i]->top_cell_index; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
index_find = false; |
|
|
|
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != it_end; ++it) { |
|
|
|
if ((*it)->fg != nullptr) { |
|
|
|
std::unordered_set<AnfNodePtr> node_set; |
|
|
|
ClearCnodeRes((*it)->fg->output(), &node_set); |
|
|
|
node_set.clear(); |
|
|
|
(*it)->fg = nullptr; |
|
|
|
} |
|
|
|
if (index_find) { |
|
|
|
it->fg = nullptr; |
|
|
|
l.emplace_back(it->cell_id); |
|
|
|
l.emplace_back((*it)->cell_id); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (it->cell_id == cell_id) { |
|
|
|
if ((*it)->cell_id == cell_id) { |
|
|
|
index_find = true; |
|
|
|
it->fg = nullptr; |
|
|
|
l.emplace_back(it->cell_id); |
|
|
|
l.emplace_back((*it)->cell_id); |
|
|
|
} |
|
|
|
} |
|
|
|
for (const auto &it : l) { |
|
|
|
@@ -2753,7 +2817,7 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & |
|
|
|
const auto &cell_id = GetCellId(cell, args); |
|
|
|
std::string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size())); |
|
|
|
MS_LOG(DEBUG) << "Key is " << key; |
|
|
|
for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) { |
|
|
|
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) { |
|
|
|
MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id; |
|
|
|
if (key != (*it)->cell_id.substr(0, std::min(PTR_LEN, (*it)->cell_id.size()))) { |
|
|
|
continue; |
|
|
|
@@ -2773,13 +2837,18 @@ py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::a |
|
|
|
bool forward_run = false; |
|
|
|
if (top_cell != nullptr) { |
|
|
|
forward_run = top_cell->forward_already_run; |
|
|
|
if (forward_run) { |
|
|
|
top_cell_index_ = top_cell->top_cell_index; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ " |
|
|
|
<< top_cell_id_; |
|
|
|
return BaseRefToPyData(forward_run); |
|
|
|
} |
|
|
|
|
|
|
|
py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) { |
|
|
|
void PynativeExecutor::RunInner(const py::object &cell, const py::tuple &args, const py::object &phase, |
|
|
|
py::object *ret) { |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
auto cell_id = GetCellId(cell, args); |
|
|
|
MS_LOG(DEBUG) << "Run start cell id " << cell_id; |
|
|
|
bool has_sens = false; |
|
|
|
@@ -2814,17 +2883,16 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, |
|
|
|
BaseRef value = (*run)(arg_list); |
|
|
|
set_grad_runing(false); |
|
|
|
MS_LOG(DEBUG) << "Eval run end " << value.ToString(); |
|
|
|
auto out = BaseRefToPyData(value); |
|
|
|
*ret = BaseRefToPyData(value); |
|
|
|
auto do_vm_compiled = |
|
|
|
std::any_of(top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->do_vm_compiled; }); |
|
|
|
if (do_vm_compiled) { |
|
|
|
if (MakeBpropNestedCnode(cell, out, cell_id)) { |
|
|
|
return out; |
|
|
|
if (MakeBpropNestedCnode(cell, *ret, cell_id)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
MakeNestedCnode(cell_id, args, resource, out, has_sens); |
|
|
|
MakeNestedCnode(cell_id, args, resource, *ret, has_sens); |
|
|
|
} |
|
|
|
return out; |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) { |
|
|
|
@@ -2883,7 +2951,7 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg |
|
|
|
void PynativeExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, |
|
|
|
std::vector<AnfNodePtr> *inputs) { |
|
|
|
FuncGraphPtr forward_graph = nullptr; |
|
|
|
auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
auto ic = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); |
|
|
|
if (ic != cell_graph_list_.end()) { |
|
|
|
forward_graph = (*ic)->fg; |
|
|
|
@@ -2950,20 +3018,26 @@ void PynativeExecutor::ClearRes() { |
|
|
|
graph_id_ = 0; |
|
|
|
grad_order_ = 0; |
|
|
|
grad_flag_ = false; |
|
|
|
in_grad_process_ = false; |
|
|
|
in_bprop_process_ = false; |
|
|
|
has_dynamic_cell_ = false; |
|
|
|
grad_is_running_ = false; |
|
|
|
need_replace_forward_ = true; |
|
|
|
curr_g_ = nullptr; |
|
|
|
|
|
|
|
top_cell_id_.clear(); |
|
|
|
graph_info_map_.clear(); |
|
|
|
replace_weights_map_.clear(); |
|
|
|
cell_graph_list_.clear(); |
|
|
|
top_cell_list_.clear(); |
|
|
|
cell_input_args_.clear(); |
|
|
|
op_index_map_.clear(); |
|
|
|
cell_op_index_with_tensor_id_.clear(); |
|
|
|
cell_tensor_id_with_tensor_.clear(); |
|
|
|
prim_abs_list_.clear(); |
|
|
|
all_value_node_tensors_.clear(); |
|
|
|
std::stack<FuncGraphPtr>().swap(graph_stack_); |
|
|
|
std::stack<std::string>().swap(cell_op_info_stack_); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { |
|
|
|
@@ -2984,6 +3058,12 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c |
|
|
|
PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args); |
|
|
|
} |
|
|
|
|
|
|
|
py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) { |
|
|
|
py::object ret; |
|
|
|
PynativeExecutorTry(this, &PynativeExecutor::RunInner, cell, args, phase, &ret); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::Sync() { |
|
|
|
if (session == nullptr) { |
|
|
|
MS_EXCEPTION(NotExistsError) << "No session has been created!"; |
|
|
|
|