|
|
@@ -1557,11 +1557,33 @@ bool PynativeExecutor::IsTopestGraph(const std::string &cell_id) { |
|
|
[&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->is_topest; }); |
|
|
[&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->is_topest; }); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void PynativeExecutor::UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled) { |
|
|
|
|
|
|
|
|
std::string PynativeExecutor::GetTopCell(const string &cell_id) { |
|
|
|
|
|
if (IsTopestGraph(cell_id)) { |
|
|
|
|
|
return cell_id; |
|
|
|
|
|
} |
|
|
|
|
|
std::string top_cell_id; |
|
|
|
|
|
for (const auto &it : cell_graph_list_) { |
|
|
|
|
|
if (IsTopestGraph(it->cell_id)) { |
|
|
|
|
|
top_cell_id = it->cell_id; |
|
|
|
|
|
} |
|
|
|
|
|
if (it->cell_id == cell_id) { |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
if (top_cell_id.empty()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Get top cell null"; |
|
|
|
|
|
} |
|
|
|
|
|
return top_cell_id; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled) { |
|
|
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), |
|
|
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), |
|
|
[&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); |
|
|
[&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); |
|
|
if (it != top_cell_list_.end()) { |
|
|
if (it != top_cell_list_.end()) { |
|
|
(*it)->do_vm_compiled = vm_compiled; |
|
|
(*it)->do_vm_compiled = vm_compiled; |
|
|
|
|
|
if ((*it)->is_topest) { |
|
|
|
|
|
in_grad_process_ = false; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -1858,15 +1880,20 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
auto cell_id = GetCellId(cell, args); |
|
|
auto cell_id = GetCellId(cell, args); |
|
|
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id; |
|
|
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id; |
|
|
// check whether cell needed to construct grad graph |
|
|
// check whether cell needed to construct grad graph |
|
|
if (graph_stack_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { |
|
|
|
|
|
if (top_cell_list_.empty()) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Top cell list is empty"; |
|
|
|
|
|
} |
|
|
|
|
|
if (IsTopestGraph(cell_id)) { |
|
|
|
|
|
// Clear previous step resource |
|
|
|
|
|
|
|
|
if (graph_stack_.empty() && !top_cell_list_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { |
|
|
|
|
|
// Clear previous step resource |
|
|
|
|
|
if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) { |
|
|
CleanPreMemoryInValueNode(); |
|
|
CleanPreMemoryInValueNode(); |
|
|
op_index_map_.clear(); |
|
|
op_index_map_.clear(); |
|
|
top_cell_id_ = cell_id; |
|
|
top_cell_id_ = cell_id; |
|
|
|
|
|
in_grad_process_ = true; |
|
|
|
|
|
} |
|
|
|
|
|
if (!in_grad_process_ && cell_op_info_stack_.empty()) { |
|
|
|
|
|
CleanPreMemoryInValueNode(); |
|
|
|
|
|
op_index_map_.clear(); |
|
|
|
|
|
top_cell_id_ = GetTopCell(cell_id); |
|
|
|
|
|
in_grad_process_ = true; |
|
|
|
|
|
MS_LOG(DEBUG) << "Top cell id " << top_cell_id_; |
|
|
} |
|
|
} |
|
|
PushCurrentCellOpInfoToStack(); |
|
|
PushCurrentCellOpInfoToStack(); |
|
|
MS_LOG(INFO) << "NewGraph already compiled"; |
|
|
MS_LOG(INFO) << "NewGraph already compiled"; |
|
|
@@ -1920,6 +1947,7 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar |
|
|
} |
|
|
} |
|
|
op_index_map_.clear(); |
|
|
op_index_map_.clear(); |
|
|
top_cell_id_ = cell_id; |
|
|
top_cell_id_ = cell_id; |
|
|
|
|
|
in_grad_process_ = true; |
|
|
auto df_builder = std::make_shared<FuncGraph>(); |
|
|
auto df_builder = std::make_shared<FuncGraph>(); |
|
|
auto graph_info = std::make_shared<GraphInfo>(cell_id); |
|
|
auto graph_info = std::make_shared<GraphInfo>(cell_id); |
|
|
graph_info_map_[df_builder] = graph_info; |
|
|
graph_info_map_[df_builder] = graph_info; |
|
|
@@ -2162,7 +2190,9 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt |
|
|
} |
|
|
} |
|
|
tmp = BasicClone(g); |
|
|
tmp = BasicClone(g); |
|
|
graph_info_map_.update(g, tmp); |
|
|
graph_info_map_.update(g, tmp); |
|
|
ClearCnodeRes(tmp->output()); |
|
|
|
|
|
|
|
|
std::unordered_set<AnfNodePtr> node_set; |
|
|
|
|
|
ClearCnodeRes(tmp->output(), &node_set); |
|
|
|
|
|
node_set.clear(); |
|
|
}; |
|
|
}; |
|
|
// First call or cell id not exist |
|
|
// First call or cell id not exist |
|
|
if (update_in_endgraph && (IsFirstGradStep(top_cell_id_) || !CheckCellGraph(cell_id))) { |
|
|
if (update_in_endgraph && (IsFirstGradStep(top_cell_id_) || !CheckCellGraph(cell_id))) { |
|
|
@@ -2208,17 +2238,19 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void PynativeExecutor::ClearCnodeRes(const AnfNodePtr &node) { |
|
|
|
|
|
|
|
|
void PynativeExecutor::ClearCnodeRes(const AnfNodePtr &node, std::unordered_set<AnfNodePtr> *node_set) { |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
if (!node->isa<CNode>()) { |
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_set); |
|
|
|
|
|
if (!node->isa<CNode>() || (*node_set).find(node) != (*node_set).end()) { |
|
|
return; |
|
|
return; |
|
|
} |
|
|
} |
|
|
|
|
|
(*node_set).insert(node); |
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
cnode->clear_inputs_value(); |
|
|
cnode->clear_inputs_value(); |
|
|
|
|
|
cnode->set_forward(nullptr, ""); |
|
|
for (size_t i = 0; i < cnode->size(); ++i) { |
|
|
for (size_t i = 0; i < cnode->size(); ++i) { |
|
|
auto n = cnode->input(i); |
|
|
auto n = cnode->input(i); |
|
|
cnode->set_forward(nullptr, ""); |
|
|
|
|
|
ClearCnodeRes(n); |
|
|
|
|
|
|
|
|
ClearCnodeRes(n, node_set); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -2323,7 +2355,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje |
|
|
MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id; |
|
|
MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id; |
|
|
const auto ¶ms_changed = CheckGradParamsChanged(cell_id, weights, sens); |
|
|
const auto ¶ms_changed = CheckGradParamsChanged(cell_id, weights, sens); |
|
|
if (!params_changed && !IsFirstGradStep(cell_id) && !CheckRealDynamicCell(cell_id)) { |
|
|
if (!params_changed && !IsFirstGradStep(cell_id) && !CheckRealDynamicCell(cell_id)) { |
|
|
UpdateTopCellCompileInfo(cell_id, false); |
|
|
|
|
|
|
|
|
UpdateTopCellInfo(cell_id, false); |
|
|
ClearDynamicTopRes(cell_id); |
|
|
ClearDynamicTopRes(cell_id); |
|
|
MS_LOG(INFO) << "Gradgraph already compiled"; |
|
|
MS_LOG(INFO) << "Gradgraph already compiled"; |
|
|
return; |
|
|
return; |
|
|
@@ -2367,7 +2399,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje |
|
|
ExecuteAction(resource); |
|
|
ExecuteAction(resource); |
|
|
ClearUselessRes(df_builder, cell, cell_id); |
|
|
ClearUselessRes(df_builder, cell, cell_id); |
|
|
UpdateCellGraph(cell, curr_g_, cell_id, false, true); |
|
|
UpdateCellGraph(cell, curr_g_, cell_id, false, true); |
|
|
UpdateTopCellCompileInfo(cell_id, true); |
|
|
|
|
|
|
|
|
UpdateTopCellInfo(cell_id, true); |
|
|
resource->Clean(); |
|
|
resource->Clean(); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -2710,9 +2742,10 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & |
|
|
|
|
|
|
|
|
py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) { |
|
|
py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) { |
|
|
const auto &cell_id = GetCellId(cell, args); |
|
|
const auto &cell_id = GetCellId(cell, args); |
|
|
bool already_run = CheckCellGraph(cell_id); |
|
|
|
|
|
MS_LOG(DEBUG) << "Graph have already run " << already_run << " cell id " << cell_id; |
|
|
|
|
|
return BaseRefToPyData(already_run); |
|
|
|
|
|
|
|
|
bool forward_run = CheckCellGraph(cell_id) && top_cell_id_ == cell_id; |
|
|
|
|
|
MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ " |
|
|
|
|
|
<< top_cell_id_; |
|
|
|
|
|
return BaseRefToPyData(forward_run); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) { |
|
|
py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) { |
|
|
|