|
|
|
@@ -624,7 +624,11 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { |
|
|
|
auto resource = GetResource(); |
|
|
|
MS_EXCEPTION_IF_NULL(resource); |
|
|
|
MS_LOG(DEBUG) << "Get resource ptr " << resource.get(); |
|
|
|
int64_t graph_id = resource->results()[pipeline::kPynativeGraphId].cast<int64_t>(); |
|
|
|
int64_t graph_id = graph_id_; |
|
|
|
auto it = resource->results().find(pipeline::kPynativeGraphId); |
|
|
|
if (it != resource->results().end()) { |
|
|
|
graph_id = it->second.cast<int64_t>(); |
|
|
|
} |
|
|
|
op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]); |
|
|
|
op_index_map_[op_name]++; |
|
|
|
} |
|
|
|
@@ -943,8 +947,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { |
|
|
|
auto param_name = py::cast<std::string>(name_attr); |
|
|
|
auto df_builder = GetDfbuilder(); |
|
|
|
MS_EXCEPTION_IF_NULL(df_builder); |
|
|
|
if (graph_info_map_.at(df_builder).second.params.find(obj_id) == |
|
|
|
graph_info_map_.at(df_builder).second.params.end()) { |
|
|
|
if (graph_info_map_.at(df_builder).params.find(obj_id) == graph_info_map_.at(df_builder).params.end()) { |
|
|
|
auto free_param = df_builder->add_parameter(); |
|
|
|
free_param->set_name(param_name); |
|
|
|
free_param->debug_info()->set_name(param_name); |
|
|
|
@@ -957,12 +960,12 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { |
|
|
|
SetNodeMapInGraphInfoMap(curr_g_, obj_id, free_param); |
|
|
|
return free_param; |
|
|
|
} |
|
|
|
node = graph_info_map_.at(df_builder).second.node_map[obj_id].first; |
|
|
|
MS_LOG(DEBUG) << "Get input node " << node->ToString() << obj_id; |
|
|
|
node = graph_info_map_.at(df_builder).node_map[obj_id].first; |
|
|
|
MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id; |
|
|
|
return node; |
|
|
|
} |
|
|
|
|
|
|
|
if (graph_info_map_.at(curr_g_).second.node_map.find(obj_id) != graph_info_map_.at(curr_g_).second.node_map.end()) { |
|
|
|
if (graph_info_map_.at(curr_g_).node_map.find(obj_id) != graph_info_map_.at(curr_g_).node_map.end()) { |
|
|
|
// op(x, y) |
|
|
|
// out = op(op1(x, y)) |
|
|
|
// out = op(cell1(x, y)) |
|
|
|
@@ -989,7 +992,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { |
|
|
|
node = MakeValueNode(obj, obj_id); |
|
|
|
} |
|
|
|
node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr" |
|
|
|
: MS_LOG(DEBUG) << "Get input node " << node->ToString() << obj_id; |
|
|
|
: MS_LOG(DEBUG) << "Get input node " << node->ToString() << " " << obj_id; |
|
|
|
return node; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1077,14 +1080,14 @@ void PynativeExecutor::CleanTensorsInValueNode() { |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { |
|
|
|
auto &out = graph_info_map_.at(curr_g_).second.node_map[obj_id]; |
|
|
|
auto &out = graph_info_map_.at(curr_g_).node_map[obj_id]; |
|
|
|
if (out.second.size() == 1 && out.second[0] == -1) { |
|
|
|
return out.first; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Output size " << out.second.size(); |
|
|
|
|
|
|
|
// Params node |
|
|
|
if (graph_info_map_.at(curr_g_).second.params.find(obj_id) != graph_info_map_.at(curr_g_).second.params.end()) { |
|
|
|
if (graph_info_map_.at(curr_g_).params.find(obj_id) != graph_info_map_.at(curr_g_).params.end()) { |
|
|
|
auto para_node = out.first; |
|
|
|
for (auto &idx : out.second) { |
|
|
|
std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, |
|
|
|
@@ -1441,35 +1444,50 @@ bool PynativeExecutor::IsNotNestedGrad() const { |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { |
|
|
|
return std::any_of( |
|
|
|
top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>> &value) { |
|
|
|
return value.first == cell_id; |
|
|
|
}); |
|
|
|
return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; }); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::SubNestedGradCount() { |
|
|
|
bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) { |
|
|
|
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { |
|
|
|
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos; |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::SubNestedGradOrder() { |
|
|
|
if (grad_order_ > 0) { |
|
|
|
--grad_order_; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) { |
|
|
|
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
[&cell_id, is_grad](const std::pair<std::string, std::pair<FuncGraphPtr, bool>> &value) { |
|
|
|
return value.first == cell_id && (!is_grad || value.second.second); |
|
|
|
}); |
|
|
|
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfo &value) { |
|
|
|
return value.cell_id == cell_id && (!is_grad || value.is_grad); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) { |
|
|
|
// Cell is empty, get nearest dfbuilder |
|
|
|
if (cell_id.empty() && !top_cell_list_.empty()) { |
|
|
|
return top_cell_list_.back().second.second.first; |
|
|
|
if (top_cell_list_.size() == 1) { |
|
|
|
return top_cell_list_.begin()->df_builder; |
|
|
|
} |
|
|
|
if (grad_order_ == 0 || grad_order_ == 1) { |
|
|
|
return top_cell_list_.back().df_builder; |
|
|
|
} |
|
|
|
if (top_cell_list_.size() < grad_order_) { |
|
|
|
MS_LOG(EXCEPTION) << "Get wrong grad order"; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size(); |
|
|
|
// Grad order greater than 2 |
|
|
|
auto it = top_cell_list_.end(); |
|
|
|
std::advance(it, -2); |
|
|
|
return it->df_builder; |
|
|
|
} |
|
|
|
// If top graph hold |
|
|
|
for (const auto &it : top_cell_list_) { |
|
|
|
if (cell_id.find(it.first) != std::string::npos) { |
|
|
|
return it.second.second.first; |
|
|
|
if (cell_id.find(it.cell_id) != std::string::npos) { |
|
|
|
return it.df_builder; |
|
|
|
} |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
@@ -1478,15 +1496,31 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) { |
|
|
|
ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) { |
|
|
|
// Cell is empty, get nearest resource |
|
|
|
if (cell_id.empty() && !top_cell_list_.empty()) { |
|
|
|
return top_cell_list_.back().second.first; |
|
|
|
if (top_cell_list_.size() == 1) { |
|
|
|
return top_cell_list_.begin()->resource; |
|
|
|
} |
|
|
|
if (grad_order_ == 0 || grad_order_ == 1) { |
|
|
|
return top_cell_list_.back().resource; |
|
|
|
} |
|
|
|
if (top_cell_list_.size() < grad_order_) { |
|
|
|
MS_LOG(EXCEPTION) << "Get wrong grad order"; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size(); |
|
|
|
// Grad order greater than 2 |
|
|
|
auto it = top_cell_list_.end(); |
|
|
|
std::advance(it, -2); |
|
|
|
return it->resource; |
|
|
|
} |
|
|
|
for (const auto &it : top_cell_list_) { |
|
|
|
if (cell_id.find(it.first) != std::string::npos) { |
|
|
|
return it.second.first; |
|
|
|
if (cell_id.find(it.cell_id) != std::string::npos) { |
|
|
|
return it.resource; |
|
|
|
} |
|
|
|
} |
|
|
|
// Current cell is not top graph, get first top cell |
|
|
|
return top_cell_list_.front().second.first; |
|
|
|
if (!top_cell_list_.empty()) { |
|
|
|
return top_cell_list_.front().resource; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node, |
|
|
|
@@ -1674,7 +1708,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
} |
|
|
|
PushCurrentGraphToStack(); |
|
|
|
if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) { |
|
|
|
graph_info_map_.emplace(curr_g_, std::make_pair(cell_id, GraphInfo())); |
|
|
|
GraphInfo graph_info = GraphInfo(cell_id); |
|
|
|
graph_info_map_.emplace(curr_g_, graph_info); |
|
|
|
} |
|
|
|
for (size_t i = 0; i < args.size(); ++i) { |
|
|
|
auto param = args[i]; |
|
|
|
@@ -1682,13 +1717,23 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
std::string param_id = GetId(param); |
|
|
|
SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true); |
|
|
|
SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param); |
|
|
|
SetParamNodeMapInGraphInfoMap(curr_g_, param_id, nullptr); |
|
|
|
SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param); |
|
|
|
} |
|
|
|
// check whether the construct of cell will be changed |
|
|
|
if (!dynamic_cell_) { |
|
|
|
dynamic_cell_ = IsDynamicCell(cell); |
|
|
|
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_; |
|
|
|
} |
|
|
|
// Make bprop graph |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { |
|
|
|
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos; |
|
|
|
}); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
MS_LOG(INFO) << "Make bprop graph"; |
|
|
|
it->custom_bprop_graph = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) { |
|
|
|
@@ -1701,11 +1746,8 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar |
|
|
|
} |
|
|
|
} |
|
|
|
// Clear runop pre |
|
|
|
auto it = std::find_if( |
|
|
|
top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>> &value) { |
|
|
|
return value.first == cell_id; |
|
|
|
}); |
|
|
|
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; }); |
|
|
|
if (it != top_cell_list_.end()) { |
|
|
|
top_cell_list_.erase(it); |
|
|
|
} |
|
|
|
@@ -1714,10 +1756,12 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar |
|
|
|
op_index_with_tensor_id_.clear(); |
|
|
|
|
|
|
|
auto df_builder = std::make_shared<FuncGraph>(); |
|
|
|
graph_info_map_.emplace(df_builder, std::make_pair(cell_id, GraphInfo())); |
|
|
|
GraphInfo graph_info = GraphInfo(cell_id); |
|
|
|
graph_info_map_.emplace(df_builder, graph_info); |
|
|
|
auto resource = std::make_shared<pipeline::Resource>(); |
|
|
|
resource->results()[pipeline::kPynativeGraphId] = graph_id_++; |
|
|
|
top_cell_list_.emplace_back(std::make_pair(cell_id, std::make_pair(resource, std::make_pair(df_builder, nullptr)))); |
|
|
|
auto top_cell_info = TopCellInfo(resource, df_builder, nullptr, cell_id); |
|
|
|
top_cell_list_.emplace_back(top_cell_info); |
|
|
|
MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1730,8 +1774,10 @@ void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const p |
|
|
|
auto tuple_size = static_cast<int64_t>(tuple.size()); |
|
|
|
for (int64_t i = 0; i < tuple_size; ++i) { |
|
|
|
auto id = GetId(tuple[i]); |
|
|
|
if (is_param) { |
|
|
|
SetParamNodeMapInGraphInfoMap(g, id, nullptr); |
|
|
|
if (is_param && node->isa<Parameter>()) { |
|
|
|
auto param = node->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(param); |
|
|
|
SetParamNodeMapInGraphInfoMap(g, id, param); |
|
|
|
} |
|
|
|
SetNodeMapInGraphInfoMap(g, id, node, i); |
|
|
|
SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param); |
|
|
|
@@ -1750,8 +1796,10 @@ void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, con |
|
|
|
std::vector<int64_t> tmp = index_sequence; |
|
|
|
tmp.emplace_back(i); |
|
|
|
auto id = GetId(tuple[i]); |
|
|
|
if (is_param) { |
|
|
|
SetParamNodeMapInGraphInfoMap(g, id, nullptr); |
|
|
|
if (is_param && node->isa<Parameter>()) { |
|
|
|
auto param = node->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(param); |
|
|
|
SetParamNodeMapInGraphInfoMap(g, id, param); |
|
|
|
} |
|
|
|
SetNodeMapInGraphInfoMap(g, id, node, tmp); |
|
|
|
SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param); |
|
|
|
@@ -1767,7 +1815,7 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o |
|
|
|
} |
|
|
|
auto out_id = GetId(out); |
|
|
|
// x =op1, y =op2, return (x, y) |
|
|
|
if (graph_info_map_.at(curr_g_).second.node_map.find(out_id) == graph_info_map_.at(curr_g_).second.node_map.end()) { |
|
|
|
if (graph_info_map_.at(curr_g_).node_map.find(out_id) == graph_info_map_.at(curr_g_).node_map.end()) { |
|
|
|
if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) { |
|
|
|
auto tuple = out.cast<py::tuple>(); |
|
|
|
auto tuple_size = static_cast<int64_t>(tuple.size()); |
|
|
|
@@ -1795,27 +1843,33 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string |
|
|
|
MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString(); |
|
|
|
auto resource = GetResource(cell_id); |
|
|
|
MS_EXCEPTION_IF_NULL(resource); |
|
|
|
resource->manager()->AddFuncGraph(curr_g_); |
|
|
|
UpdateCellGraph(cell_id, true, false); |
|
|
|
|
|
|
|
set_need_replace_forward(IsNotNestedGrad()); |
|
|
|
auto newfg = MakeGradGraph(cell, args, curr_g_, resource, IsTopGraph(cell_id)); |
|
|
|
auto is_bprop_graph = IsBpropGraph(cell_id); |
|
|
|
auto is_bprop_cell = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); |
|
|
|
if (!is_bprop_cell || !is_bprop_graph) { |
|
|
|
resource->manager()->AddFuncGraph(curr_g_); |
|
|
|
} |
|
|
|
if (!is_bprop_cell) { |
|
|
|
UpdateCellGraph(cell, curr_g_, cell_id, true, false); |
|
|
|
} |
|
|
|
auto newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); |
|
|
|
|
|
|
|
if (graph_stack_.size() > 1) { |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.emplace_back(NewValueNode(curr_g_)); |
|
|
|
|
|
|
|
PopGraphStack(); |
|
|
|
// connect the previous graph to the inside graph |
|
|
|
auto graph_prev = graph_stack_.top(); |
|
|
|
for (size_t i = 0; i < args.size(); i++) { |
|
|
|
auto input = GetInput(args[i], false); |
|
|
|
inputs.emplace_back(input); |
|
|
|
if (!is_bprop_cell || !is_bprop_graph) { |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.emplace_back(NewValueNode(curr_g_)); |
|
|
|
|
|
|
|
PopGraphStack(); |
|
|
|
// connect the previous graph to the inside graph |
|
|
|
auto graph_prev = graph_stack_.top(); |
|
|
|
for (size_t i = 0; i < args.size(); i++) { |
|
|
|
auto input = GetInput(args[i], false); |
|
|
|
inputs.emplace_back(input); |
|
|
|
} |
|
|
|
auto out_cnode = graph_prev->NewCNode(inputs); |
|
|
|
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args)); |
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); |
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); |
|
|
|
} |
|
|
|
auto out_cnode = graph_prev->NewCNode(inputs); |
|
|
|
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args)); |
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); |
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); |
|
|
|
} else { |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("before_resolve.ir", newfg); |
|
|
|
@@ -1829,40 +1883,54 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) { |
|
|
|
FuncGraphPtr tmp = curr_g_; |
|
|
|
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
|
bool need_cloned, bool is_grad) { |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
// Bprop just save backward graph |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfo &value) { return value.cell_id == cell_id; }); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
it->is_grad = is_grad; |
|
|
|
it->fg = g; |
|
|
|
MS_LOG(DEBUG) << "Update bprop bg"; |
|
|
|
} else { |
|
|
|
auto cell_info = CellInfo(false, true, false, g, cell_id); |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
FuncGraphPtr tmp = g; |
|
|
|
if (need_cloned && !IsNotNestedGrad()) { |
|
|
|
auto cloned_curr_g = BasicClone(curr_g_); |
|
|
|
graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_); |
|
|
|
auto cloned_curr_g = BasicClone(g); |
|
|
|
graph_info_map_[cloned_curr_g] = graph_info_map_.at(g); |
|
|
|
tmp = cloned_curr_g; |
|
|
|
MS_LOG(DEBUG) << "Replace cur graph " << curr_g_.get() << " with cloned new " << cloned_curr_g.get(); |
|
|
|
MS_LOG(DEBUG) << "Replace cur graph " << g.get() << " with cloned new " << cloned_curr_g.get(); |
|
|
|
} |
|
|
|
for (auto &it : cell_graph_list_) { |
|
|
|
if (it.first != cell_id) { |
|
|
|
if (it.cell_id != cell_id) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
it.second.second = is_grad; |
|
|
|
it.is_grad = is_grad; |
|
|
|
if (need_cloned) { |
|
|
|
it.second.first = tmp; |
|
|
|
it.fg = tmp; |
|
|
|
} |
|
|
|
if (!need_cloned && !is_grad) { |
|
|
|
graph_info_map_[curr_g_] = graph_info_map_.at(it.second.first); |
|
|
|
graph_info_map_.erase(it.second.first); |
|
|
|
it.second.first = curr_g_; |
|
|
|
MS_LOG(DEBUG) << "Replace cur graph " << it.second.first.get() << " with new " << curr_g_.get(); |
|
|
|
graph_info_map_[g] = graph_info_map_.at(it.fg); |
|
|
|
graph_info_map_.erase(it.fg); |
|
|
|
it.fg = g; |
|
|
|
MS_LOG(DEBUG) << "Replace cur graph " << it.fg.get() << " with new " << g.get(); |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Add new cell graph " << cell_id; |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), std::make_pair(cell_id, std::make_pair(tmp, false))); |
|
|
|
auto cell_info = CellInfo(false, true, false, tmp, cell_id); |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const py::args &args, const FuncGraphPtr &g, |
|
|
|
const ResourcePtr &r, bool is_top) { |
|
|
|
// custom bprop debug |
|
|
|
bool need_replace_param = false; |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
need_replace_param = true; |
|
|
|
FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, |
|
|
|
const string &cell_id, const py::args &args) { |
|
|
|
bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !IsBpropGraph(cell_id); |
|
|
|
if (is_custom_bprop) { |
|
|
|
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size(); |
|
|
|
if (par_number > 0) { |
|
|
|
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number |
|
|
|
@@ -1875,14 +1943,18 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const py::a |
|
|
|
(void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g))); |
|
|
|
} |
|
|
|
} |
|
|
|
// Obtain grad graph |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("fg.ir", g); |
|
|
|
FuncGraphPtr newfg = nullptr; |
|
|
|
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) || is_custom_bprop) { |
|
|
|
// Obtain grad graph |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("fg.ir", g); |
|
|
|
} |
|
|
|
auto is_top = IsTopGraph(cell_id); |
|
|
|
MS_LOG(DEBUG) << "Grad top cell " << is_top; |
|
|
|
set_need_replace_forward(IsNotNestedGrad()); |
|
|
|
newfg = ad::Grad(g, r, is_top); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Grad top cell " << is_top; |
|
|
|
auto newfg = ad::Grad(g, r, is_top); |
|
|
|
|
|
|
|
if (need_replace_param) { |
|
|
|
if (is_custom_bprop) { |
|
|
|
auto params = newfg->parameters(); |
|
|
|
auto manager = Manage({newfg}, false); |
|
|
|
if (args.size() > params.size()) { |
|
|
|
@@ -1894,6 +1966,7 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const py::a |
|
|
|
auto v_node = NewValueNode(value); |
|
|
|
manager->Replace(params[i], v_node); |
|
|
|
} |
|
|
|
UpdateCellGraph(cell, newfg, cell_id, false, false); |
|
|
|
} |
|
|
|
return newfg; |
|
|
|
} |
|
|
|
@@ -1965,12 +2038,12 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje |
|
|
|
resource->manager()->KeepRoots({df_builder}); |
|
|
|
resource->results()[pipeline::kBackend] = compile::CreateBackend(); |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Start opt"; |
|
|
|
MS_LOG(INFO) << "Start opt"; |
|
|
|
PynativeOptimizeAction(resource); |
|
|
|
SaveTensorsInValueNode(resource); |
|
|
|
TaskEmitAction(resource); |
|
|
|
ExecuteAction(resource); |
|
|
|
UpdateCellGraph(cell_id, false, true); |
|
|
|
UpdateCellGraph(cell, curr_g_, cell_id, false, true); |
|
|
|
UpdateGraphInfoMap(cell_id); |
|
|
|
resource->Clean(); |
|
|
|
} |
|
|
|
@@ -2018,13 +2091,10 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args |
|
|
|
return; |
|
|
|
} |
|
|
|
ResourcePtr resource = nullptr; |
|
|
|
auto ia = std::find_if( |
|
|
|
top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>> &value) { |
|
|
|
return value.first == cell_id; |
|
|
|
}); |
|
|
|
auto ia = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; }); |
|
|
|
if (ia != top_cell_list_.end()) { |
|
|
|
resource = GetResource(ia->first); |
|
|
|
resource = GetResource(ia->cell_id); |
|
|
|
MS_EXCEPTION_IF_NULL(resource); |
|
|
|
MS_LOG(DEBUG) << "Find old resource " << resource.get(); |
|
|
|
} |
|
|
|
@@ -2035,22 +2105,33 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(resource); |
|
|
|
FuncGraphPtr df_builder = std::make_shared<FuncGraph>(); |
|
|
|
graph_info_map_.emplace(df_builder, std::make_pair(cell_id, GraphInfo())); |
|
|
|
top_cell_list_.emplace_back(std::make_pair(cell_id, std::make_pair(resource, std::make_pair(df_builder, nullptr)))); |
|
|
|
GraphInfo graph_info = GraphInfo(cell_id); |
|
|
|
graph_info_map_.emplace(df_builder, graph_info); |
|
|
|
auto top_cell_info = TopCellInfo(resource, df_builder, nullptr, cell_id); |
|
|
|
top_cell_list_.emplace_back(top_cell_info); |
|
|
|
FuncGraphPtr forward_graph = nullptr; |
|
|
|
auto ib = std::find_if( |
|
|
|
cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
[&cell_id](const std::pair<std::string, std::pair<FuncGraphPtr, bool>> &value) { return value.first == cell_id; }); |
|
|
|
auto ib = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfo &value) { return value.cell_id == cell_id; }); |
|
|
|
if (ib != cell_graph_list_.end()) { |
|
|
|
forward_graph = ib->second.first; |
|
|
|
forward_graph = ib->fg; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(forward_graph); |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("nested_bprop.ir", forward_graph); |
|
|
|
} |
|
|
|
// Custom bprop get backward graph(before opt), which use like other forward graph |
|
|
|
curr_g_ = forward_graph; |
|
|
|
resource->set_func_graph(forward_graph); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
// Copy weights |
|
|
|
std::vector<AnfNodePtr> weights_params{}; |
|
|
|
for (const auto &it : graph_info_map_.at(forward_graph).second.params) { |
|
|
|
if (it.second != nullptr) { |
|
|
|
for (const auto &it : graph_info_map_.at(forward_graph).params) { |
|
|
|
if (it.second->has_default()) { |
|
|
|
weights_params.emplace_back(it.second); |
|
|
|
graph_info_map_.at(df_builder).second.params.emplace(it.first, it.second); |
|
|
|
graph_info_map_.at(df_builder).params.emplace(it.first, it.second); |
|
|
|
SetNodeMapInGraphInfoMap(df_builder, it.first, it.second); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -2061,7 +2142,7 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args |
|
|
|
DumpIR("nested_fg.ir", forward_graph); |
|
|
|
} |
|
|
|
set_need_replace_forward(false); |
|
|
|
auto newfg = MakeGradGraph(cell, args, forward_graph, resource, IsTopGraph(cell_id)); |
|
|
|
auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args); |
|
|
|
resource->set_func_graph(newfg); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -2091,11 +2172,9 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh |
|
|
|
auto param = tuple[it]; |
|
|
|
auto param_id = GetId(param); |
|
|
|
AnfNodePtr para_node = nullptr; |
|
|
|
if (graph_info_map_.at(df_builder).second.params.find(param_id) != |
|
|
|
graph_info_map_.at(df_builder).second.params.end() && |
|
|
|
graph_info_map_.at(df_builder).second.node_map.find(param_id) != |
|
|
|
graph_info_map_.at(df_builder).second.node_map.end()) { |
|
|
|
para_node = graph_info_map_.at(df_builder).second.node_map[param_id].first; |
|
|
|
if (graph_info_map_.at(df_builder).params.find(param_id) != graph_info_map_.at(df_builder).params.end() && |
|
|
|
graph_info_map_.at(df_builder).node_map.find(param_id) != graph_info_map_.at(df_builder).node_map.end()) { |
|
|
|
para_node = graph_info_map_.at(df_builder).node_map[param_id].first; |
|
|
|
} else { |
|
|
|
auto name_attr = parse::python_adapter::GetPyObjAttr(param, "name"); |
|
|
|
if (py::isinstance<py::none>(name_attr)) { |
|
|
|
@@ -2117,6 +2196,10 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh |
|
|
|
abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder) { |
|
|
|
abstract::AbstractBasePtrList args_spec; |
|
|
|
std::size_t size = args.size(); |
|
|
|
auto df_params = df_builder->parameters(); |
|
|
|
if (df_params.size() < size) { |
|
|
|
MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << size; |
|
|
|
} |
|
|
|
// input params |
|
|
|
for (std::size_t i = 0; i < size; i++) { |
|
|
|
ValuePtr converted = nullptr; |
|
|
|
@@ -2127,11 +2210,11 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args |
|
|
|
bool broaden = true; |
|
|
|
auto abs = abstract::FromValue(converted, broaden); |
|
|
|
args_spec.emplace_back(abs); |
|
|
|
auto param_node = std::static_pointer_cast<Parameter>(df_builder->parameters()[i]); |
|
|
|
auto param_node = std::static_pointer_cast<Parameter>(df_params[i]); |
|
|
|
param_node->set_abstract(abs); |
|
|
|
} |
|
|
|
// weights params |
|
|
|
for (const auto ¶m : df_builder->parameters()) { |
|
|
|
for (const auto ¶m : df_params) { |
|
|
|
auto param_node = std::static_pointer_cast<Parameter>(param); |
|
|
|
if (param_node->has_default()) { |
|
|
|
ValuePtr value = param_node->default_param(); |
|
|
|
@@ -2148,24 +2231,18 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args |
|
|
|
bool PynativeExecutor::CloneDfbuiler(const std::string &cell_id, const FuncGraphPtr &df_builder, |
|
|
|
const ResourcePtr &resource) { |
|
|
|
bool is_cloned = false; |
|
|
|
std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>> r( |
|
|
|
std::make_pair(nullptr, std::make_pair(nullptr, nullptr))); |
|
|
|
auto it = std::find_if( |
|
|
|
top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>> &value) { |
|
|
|
return value.first == cell_id; |
|
|
|
}); |
|
|
|
if (it != top_cell_list_.end()) { |
|
|
|
r = it->second; |
|
|
|
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; }); |
|
|
|
if (it == top_cell_list_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Get top cell failed"; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(r.first); |
|
|
|
if (r.second.second == nullptr) { |
|
|
|
if (it->bg == nullptr) { |
|
|
|
auto cloned_df_newfg = BasicClone(resource->func_graph()); |
|
|
|
r.second = std::make_pair(df_builder, cloned_df_newfg); |
|
|
|
it->bg = cloned_df_newfg; |
|
|
|
MS_LOG(DEBUG) << "Cloned df newfg"; |
|
|
|
is_cloned = false; |
|
|
|
} else { |
|
|
|
resource->set_func_graph(r.second.second); |
|
|
|
resource->set_func_graph(it->bg); |
|
|
|
MS_LOG(DEBUG) << "Used cloned df newfg"; |
|
|
|
} |
|
|
|
return is_cloned; |
|
|
|
@@ -2174,11 +2251,10 @@ bool PynativeExecutor::CloneDfbuiler(const std::string &cell_id, const FuncGraph |
|
|
|
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, |
|
|
|
const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) { |
|
|
|
FuncGraphPtr top_g = nullptr; |
|
|
|
auto it = std::find_if( |
|
|
|
cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
[&cell_id](const std::pair<std::string, std::pair<FuncGraphPtr, bool>> &value) { return value.first == cell_id; }); |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfo &value) { return value.cell_id == cell_id; }); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
top_g = it->second.first; |
|
|
|
top_g = it->fg; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(top_g); |
|
|
|
auto nparam = top_g->parameters().size(); |
|
|
|
@@ -2194,8 +2270,12 @@ void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr & |
|
|
|
|
|
|
|
auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g->parameters(), weights); |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(df)}; |
|
|
|
auto df_params = df_builder->parameters(); |
|
|
|
if (df_params.size() < arg_size) { |
|
|
|
MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << arg_size; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < arg_size; ++i) { |
|
|
|
inputs.emplace_back(df_builder->parameters()[i]); |
|
|
|
inputs.emplace_back(df_params[i]); |
|
|
|
} |
|
|
|
auto out = df_builder->NewCNode(inputs); |
|
|
|
df_builder->set_output(out); |
|
|
|
@@ -2208,17 +2288,17 @@ void PynativeExecutor::UpdateGraphInfoMap(const std::string &cell_id) { |
|
|
|
bool index_find = false; |
|
|
|
for (const auto &it : cell_graph_list_) { |
|
|
|
if (index_find) { |
|
|
|
l.emplace_back(it.first); |
|
|
|
l.emplace_back(it.cell_id); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (it.first == cell_id) { |
|
|
|
if (it.cell_id == cell_id) { |
|
|
|
index_find = true; |
|
|
|
l.emplace_back(it.first); |
|
|
|
l.emplace_back(it.cell_id); |
|
|
|
} |
|
|
|
} |
|
|
|
for (const auto &it : l) { |
|
|
|
for (auto ic = graph_info_map_.begin(); ic != graph_info_map_.end();) { |
|
|
|
if (ic->second.first.find(it) != std::string::npos) { |
|
|
|
if (ic->second.cell_id.find(it) != std::string::npos) { |
|
|
|
ic = graph_info_map_.erase(ic); |
|
|
|
} else { |
|
|
|
++ic; |
|
|
|
@@ -2229,7 +2309,7 @@ void PynativeExecutor::UpdateGraphInfoMap(const std::string &cell_id) { |
|
|
|
|
|
|
|
py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) { |
|
|
|
BaseRef ret = false; |
|
|
|
AddNestedGradCount(); |
|
|
|
AddNestedGradOrder(); |
|
|
|
if (!grad_running()) { |
|
|
|
MS_LOG(DEBUG) << "Grad not running yet"; |
|
|
|
return BaseRefToPyData(ret); |
|
|
|
@@ -2238,8 +2318,8 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & |
|
|
|
string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size())); |
|
|
|
MS_LOG(DEBUG) << "Key is " << key; |
|
|
|
for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) { |
|
|
|
MS_LOG(DEBUG) << "Cur cell id " << it->first; |
|
|
|
if (key != it->first.substr(0, std::min(PTR_LEN, it->first.size()))) { |
|
|
|
MS_LOG(DEBUG) << "Cur cell id " << it->cell_id; |
|
|
|
if (key != it->cell_id.substr(0, std::min(PTR_LEN, it->cell_id.size()))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Delete cellid from cell graph list"; |
|
|
|
@@ -2255,7 +2335,7 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, |
|
|
|
MS_LOG(DEBUG) << "Run start cell id " << cell_id; |
|
|
|
bool has_sens = false; |
|
|
|
for (const auto &it : top_cell_list_) { |
|
|
|
if (cell_id.find(it.first) != std::string::npos && cell_id != it.first) { |
|
|
|
if (cell_id.find(it.cell_id) != std::string::npos && cell_id != it.cell_id) { |
|
|
|
has_sens = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
@@ -2285,12 +2365,42 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, |
|
|
|
BaseRef value = (*run)(arg_list); |
|
|
|
CleanTensorsInValueNode(); |
|
|
|
set_grad_runing(false); |
|
|
|
MS_LOG(DEBUG) << "Run end " << value.ToString(); |
|
|
|
MS_LOG(DEBUG) << "Eval run end " << value.ToString(); |
|
|
|
auto out = BaseRefToPyData(value); |
|
|
|
if (MakeBpropNestedCnode(cell, out, cell_id)) { |
|
|
|
return out; |
|
|
|
} |
|
|
|
MakeNestedCnode(cell_id, args, resource, out, has_sens); |
|
|
|
return out; |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) { |
|
|
|
if (graph_stack_.empty() || !py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
MS_LOG(DEBUG) << "No nested bprop grad find"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { |
|
|
|
return value.custom_bprop_graph && value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos; |
|
|
|
}); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
MS_LOG(DEBUG) << "Make bprop graph end"; |
|
|
|
} |
|
|
|
auto out_id = GetId(out); |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.emplace_back(NewValueNode(curr_g_)); |
|
|
|
PopGraphStack(); |
|
|
|
for (const auto &ig : graph_info_map_.at(curr_g_).params) { |
|
|
|
if (!ig.second->has_default()) { |
|
|
|
inputs.emplace_back(ig.second); |
|
|
|
} |
|
|
|
} |
|
|
|
auto cnode = curr_g_->NewCNode(inputs); |
|
|
|
SetTupleArgsToGraphInfoMap(curr_g_, out, cnode); |
|
|
|
SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode); |
|
|
|
MS_LOG(DEBUG) << "Custom bprop make nested node is " << cnode->DebugString(4); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, |
|
|
|
const py::object &out, bool has_sens) { |
|
|
|
if (graph_stack_.empty()) { |
|
|
|
@@ -2313,15 +2423,15 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg |
|
|
|
} |
|
|
|
auto out_id = GetId(out); |
|
|
|
auto cnode = graph_prev->NewCNode(inputs); |
|
|
|
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4); |
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, cnode); |
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, out_id, cnode); |
|
|
|
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void MapClear(T *map, const std::string &flag) { |
|
|
|
void MapClear(T *map, const std::string &cell_id) { |
|
|
|
for (auto it = map->begin(); it != map->end();) { |
|
|
|
if (it->first.find(flag) != std::string::npos) { |
|
|
|
if (it->first.find(cell_id) != std::string::npos) { |
|
|
|
it = map->erase(it); |
|
|
|
} else { |
|
|
|
it++; |
|
|
|
@@ -2329,6 +2439,17 @@ void MapClear(T *map, const std::string &flag) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void VectorClear(T *vec, const std::string &cell_id) { |
|
|
|
for (auto it = vec->begin(); it != vec->end();) { |
|
|
|
if (it->cell_id.find(cell_id) != std::string::npos) { |
|
|
|
it = vec->erase(it); |
|
|
|
} else { |
|
|
|
it++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::Clear(const std::string &cell_id) { |
|
|
|
if (cell_id.empty()) { |
|
|
|
Clean(); |
|
|
|
@@ -2337,7 +2458,7 @@ void PynativeExecutor::Clear(const std::string &cell_id) { |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id; |
|
|
|
for (auto it = graph_info_map_.begin(); it != graph_info_map_.end();) { |
|
|
|
if (it->second.first.find(cell_id) != std::string::npos) { |
|
|
|
if (it->second.cell_id.find(cell_id) != std::string::npos) { |
|
|
|
it = graph_info_map_.erase(it); |
|
|
|
} else { |
|
|
|
++it; |
|
|
|
@@ -2351,14 +2472,14 @@ void PynativeExecutor::Clear(const std::string &cell_id) { |
|
|
|
ConfigManager::GetInstance().ResetIterNum(); |
|
|
|
MapClear<std::unordered_map<std::string, bool>>(&cell_dynamic_map_, cell_id); |
|
|
|
MapClear<std::unordered_map<std::string, std::pair<std::string, std::string>>>(&cell_sw_map_, cell_id); |
|
|
|
MapClear<std::vector<std::pair<std::string, std::pair<FuncGraphPtr, bool>>>>(&cell_graph_list_, cell_id); |
|
|
|
MapClear<std::vector<std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>>>>( |
|
|
|
&top_cell_list_, cell_id); |
|
|
|
VectorClear<std::vector<CellInfo>>(&cell_graph_list_, cell_id); |
|
|
|
VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id); |
|
|
|
node_abs_map_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::Clean() { |
|
|
|
MS_LOG(DEBUG) << "Clean"; |
|
|
|
SubNestedGradCount(); |
|
|
|
SubNestedGradOrder(); |
|
|
|
node_abs_map_.clear(); |
|
|
|
obj_to_forward_id_.clear(); |
|
|
|
ad::CleanRes(); |
|
|
|
|