|
|
|
@@ -1439,6 +1439,12 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & |
|
|
|
return cell_id; |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) { |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR(filename, graph); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::IsNotNestedGrad() const { |
|
|
|
MS_LOG(DEBUG) << "Grad nested count is " << grad_order_; |
|
|
|
return grad_order_ <= 1; |
|
|
|
@@ -1851,6 +1857,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string |
|
|
|
curr_g_->set_output(output_node); |
|
|
|
MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString(); |
|
|
|
if (EndBpropGraph(cell_id)) { |
|
|
|
MS_LOG(DEBUG) << "Get bprop function cell"; |
|
|
|
return; |
|
|
|
} |
|
|
|
auto resource = GetResource(cell_id); |
|
|
|
@@ -1875,13 +1882,9 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string |
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); |
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); |
|
|
|
} else { |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("before_resolve.ir", newfg); |
|
|
|
} |
|
|
|
DumpGraphIR("before_resolve.ir", newfg); |
|
|
|
parse::ResolveFuncGraph(newfg, resource); |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("after_resolve.ir", newfg); |
|
|
|
} |
|
|
|
DumpGraphIR("after_resolve.ir", newfg); |
|
|
|
resource->set_func_graph(newfg); |
|
|
|
PopGraphStack(); |
|
|
|
} |
|
|
|
@@ -1907,10 +1910,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
it->is_grad = is_grad; |
|
|
|
it->fg = g; |
|
|
|
MS_LOG(DEBUG) << "Update bprop bg"; |
|
|
|
MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id; |
|
|
|
} else { |
|
|
|
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME); |
|
|
|
auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func)); |
|
|
|
auto bprop_func_cell_id = GetId(bprop_func); |
|
|
|
MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id; |
|
|
|
auto cell_info = CellInfo(false, true, g, cell_id, bprop_func_cell_id); |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); |
|
|
|
} |
|
|
|
return; |
|
|
|
@@ -1959,13 +1964,11 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG |
|
|
|
(void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g))); |
|
|
|
} |
|
|
|
} |
|
|
|
// Obtain grad graph |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("fg.ir", g); |
|
|
|
} |
|
|
|
DumpGraphIR("fg.ir", g); |
|
|
|
auto is_top = IsTopGraph(cell_id); |
|
|
|
MS_LOG(DEBUG) << "Grad top cell " << is_top; |
|
|
|
set_need_replace_forward(IsNotNestedGrad()); |
|
|
|
// Obtain grad graph |
|
|
|
auto newfg = ad::Grad(g, r, is_top); |
|
|
|
|
|
|
|
if (is_custom_bprop) { |
|
|
|
@@ -2039,11 +2042,9 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje |
|
|
|
auto args_spec = GetArgsSpec(args, df_builder); |
|
|
|
resource->set_args_spec(args_spec); |
|
|
|
// Get real grad graph |
|
|
|
DumpGraphIR("before_grad.ir", resource->func_graph()); |
|
|
|
GradGraph(resource->func_graph(), grad, w_args, size, cell_id); |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("before_grad.ir", resource->func_graph()); |
|
|
|
DumpIR("after_grad.ir", df_builder); |
|
|
|
} |
|
|
|
DumpGraphIR("after_grad.ir", df_builder); |
|
|
|
resource->set_func_graph(df_builder); |
|
|
|
resource->manager()->KeepRoots({df_builder}); |
|
|
|
resource->results()[pipeline::kBackend] = compile::CreateBackend(); |
|
|
|
@@ -2127,30 +2128,35 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(forward_graph); |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("nested_bprop.ir", forward_graph); |
|
|
|
} |
|
|
|
DumpGraphIR("nested_bprop.ir", forward_graph); |
|
|
|
// Custom bprop get backward graph(before opt), which use like other forward graph |
|
|
|
curr_g_ = forward_graph; |
|
|
|
resource->set_func_graph(forward_graph); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
// Copy weights |
|
|
|
std::vector<AnfNodePtr> weights_params{}; |
|
|
|
// Copy weights parameters |
|
|
|
resource->manager()->AddFuncGraph(forward_graph); |
|
|
|
auto manager = Manage({forward_graph}, false); |
|
|
|
for (const auto &it : graph_info_map_.at(forward_graph).params) { |
|
|
|
if (it.second->has_default()) { |
|
|
|
weights_params.emplace_back(it.second); |
|
|
|
graph_info_map_.at(df_builder).params.emplace(it.first, it.second); |
|
|
|
SetNodeMapInGraphInfoMap(df_builder, it.first, it.second); |
|
|
|
if (!it.second->has_default()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Get weights params size " << weights_params.size(); |
|
|
|
df_builder->set_parameters(weights_params); |
|
|
|
resource->manager()->AddFuncGraph(forward_graph); |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("nested_fg.ir", forward_graph); |
|
|
|
} |
|
|
|
auto new_param = df_builder->add_parameter(); |
|
|
|
new_param->set_abstract(it.second->abstract()); |
|
|
|
new_param->set_name(it.second->name()); |
|
|
|
new_param->set_default_param(it.second->default_param()); |
|
|
|
ScopePtr scope = (it.second->scope() != kDefaultScope) ? it.second->scope() : kDefaultScope; |
|
|
|
new_param->set_scope(scope); |
|
|
|
manager->Replace(it.second, new_param); |
|
|
|
replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param)); |
|
|
|
MS_LOG(DEBUG) << "Old param ptr " << it.second.get() << " name " << it.second->name(); |
|
|
|
|
|
|
|
graph_info_map_.at(df_builder).params[it.first] = new_param; |
|
|
|
SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param); |
|
|
|
SetNodeMapInGraphInfoMap(df_builder, it.first, new_param); |
|
|
|
} |
|
|
|
DumpGraphIR("nested_fg.ir", forward_graph); |
|
|
|
set_need_replace_forward(false); |
|
|
|
auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args); |
|
|
|
resource->set_func_graph(newfg); |
|
|
|
@@ -2396,15 +2402,18 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg |
|
|
|
MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get(); |
|
|
|
auto newfg = resource->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(newfg); |
|
|
|
auto size = args.size(); |
|
|
|
auto inputs_size = args.size(); |
|
|
|
if (has_sens) { |
|
|
|
size -= 1; |
|
|
|
inputs_size -= 1; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.emplace_back(NewValueNode(newfg)); |
|
|
|
for (size_t i = 0; i < size; ++i) { |
|
|
|
for (size_t i = 0; i < inputs_size; ++i) { |
|
|
|
inputs.emplace_back(GetInput(args[i], false)); |
|
|
|
} |
|
|
|
if (newfg->parameters().size() > inputs_size) { |
|
|
|
SetNestedWeigthsParam(newfg, cell_id, &inputs); |
|
|
|
} |
|
|
|
auto out_id = GetId(out); |
|
|
|
auto cnode = graph_prev->NewCNode(inputs); |
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, cnode); |
|
|
|
@@ -2412,6 +2421,38 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg |
|
|
|
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id, |
|
|
|
std::vector<AnfNodePtr> *inputs) { |
|
|
|
FuncGraphPtr forward_graph = nullptr; |
|
|
|
auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), |
|
|
|
[&cell_id](const CellInfo &value) { return value.cell_id == cell_id; }); |
|
|
|
if (ic != cell_graph_list_.end()) { |
|
|
|
forward_graph = ic->fg; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(forward_graph); |
|
|
|
auto params = newfg->parameters(); |
|
|
|
auto manage = Manage({newfg}, false); |
|
|
|
for (const auto &it : params) { |
|
|
|
auto param = it->cast<ParameterPtr>(); |
|
|
|
if (!param->has_default()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto ir = replace_weights_map_.find(forward_graph); |
|
|
|
if (ir == replace_weights_map_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Not find forward_graph in repalce weigths map"; |
|
|
|
} |
|
|
|
for (const auto &ip : ir->second) { |
|
|
|
MS_LOG(DEBUG) << "Get param name " << param->name() << " cache name " << ip.second->name(); |
|
|
|
if (ip.second->name() == param->name()) { |
|
|
|
manage->Replace(param, ip.first); |
|
|
|
inputs->emplace_back(ip.first); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
replace_weights_map_.erase(forward_graph); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::Clear(const std::string &cell_id) { |
|
|
|
if (cell_id.empty()) { |
|
|
|
Clean(); |
|
|
|
@@ -2461,6 +2502,7 @@ void PynativeExecutor::ClearRes() { |
|
|
|
|
|
|
|
graph_info_map_.clear(); |
|
|
|
cell_sw_map_.clear(); |
|
|
|
replace_weights_map_.clear(); |
|
|
|
cell_graph_list_.clear(); |
|
|
|
top_cell_list_.clear(); |
|
|
|
op_index_map_.clear(); |
|
|
|
|