From 020bd96976576499a6b90f36261e26efcd94bf7f Mon Sep 17 00:00:00 2001 From: zjun Date: Tue, 22 Dec 2020 09:47:28 +0800 Subject: [PATCH] Fix pynative parameters second derivative Signed-off-by: zjun --- .../pipeline/pynative/pynative_execute.cc | 112 ++++++++++++------ .../pipeline/pynative/pynative_execute.h | 13 +- 2 files changed, 85 insertions(+), 40 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 8573cdc586..ab24aeafd8 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1439,6 +1439,12 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & return cell_id; } +void PynativeExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) { + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { + DumpIR(filename, graph); + } +} + bool PynativeExecutor::IsNotNestedGrad() const { MS_LOG(DEBUG) << "Grad nested count is " << grad_order_; return grad_order_ <= 1; @@ -1851,6 +1857,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string curr_g_->set_output(output_node); MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString(); if (EndBpropGraph(cell_id)) { + MS_LOG(DEBUG) << "Get bprop function cell"; return; } auto resource = GetResource(cell_id); @@ -1875,13 +1882,9 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); } else { - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("before_resolve.ir", newfg); - } + DumpGraphIR("before_resolve.ir", newfg); parse::ResolveFuncGraph(newfg, resource); - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("after_resolve.ir", newfg); - } + 
DumpGraphIR("after_resolve.ir", newfg); resource->set_func_graph(newfg); PopGraphStack(); } @@ -1907,10 +1910,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt if (it != cell_graph_list_.end()) { it->is_grad = is_grad; it->fg = g; - MS_LOG(DEBUG) << "Update bprop bg"; + MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id; } else { py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME); - auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func)); + auto bprop_func_cell_id = GetId(bprop_func); + MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id; + auto cell_info = CellInfo(false, true, g, cell_id, bprop_func_cell_id); cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); } return; @@ -1959,13 +1964,11 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG (void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g))); } } - // Obtain grad graph - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("fg.ir", g); - } + DumpGraphIR("fg.ir", g); auto is_top = IsTopGraph(cell_id); MS_LOG(DEBUG) << "Grad top cell " << is_top; set_need_replace_forward(IsNotNestedGrad()); + // Obtain grad graph auto newfg = ad::Grad(g, r, is_top); if (is_custom_bprop) { @@ -2039,11 +2042,9 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje auto args_spec = GetArgsSpec(args, df_builder); resource->set_args_spec(args_spec); // Get real grad graph + DumpGraphIR("before_grad.ir", resource->func_graph()); GradGraph(resource->func_graph(), grad, w_args, size, cell_id); - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("before_grad.ir", resource->func_graph()); - DumpIR("after_grad.ir", df_builder); - } + DumpGraphIR("after_grad.ir", df_builder); resource->set_func_graph(df_builder); resource->manager()->KeepRoots({df_builder}); 
resource->results()[pipeline::kBackend] = compile::CreateBackend(); @@ -2127,30 +2128,35 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args } MS_EXCEPTION_IF_NULL(forward_graph); if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("nested_bprop.ir", forward_graph); - } + DumpGraphIR("nested_bprop.ir", forward_graph); // Custom bprop get backward graph(before opt), which use like other forward graph curr_g_ = forward_graph; resource->set_func_graph(forward_graph); return; } - // Copy weights - std::vector weights_params{}; + // Copy weights parameters + resource->manager()->AddFuncGraph(forward_graph); + auto manager = Manage({forward_graph}, false); for (const auto &it : graph_info_map_.at(forward_graph).params) { - if (it.second->has_default()) { - weights_params.emplace_back(it.second); - graph_info_map_.at(df_builder).params.emplace(it.first, it.second); - SetNodeMapInGraphInfoMap(df_builder, it.first, it.second); + if (!it.second->has_default()) { + continue; } - } - MS_LOG(DEBUG) << "Get weights params size " << weights_params.size(); - df_builder->set_parameters(weights_params); - resource->manager()->AddFuncGraph(forward_graph); - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("nested_fg.ir", forward_graph); - } + auto new_param = df_builder->add_parameter(); + new_param->set_abstract(it.second->abstract()); + new_param->set_name(it.second->name()); + new_param->set_default_param(it.second->default_param()); + ScopePtr scope = (it.second->scope() != kDefaultScope) ? 
it.second->scope() : kDefaultScope; + new_param->set_scope(scope); + manager->Replace(it.second, new_param); + replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param)); + MS_LOG(DEBUG) << "Old param ptr " << it.second.get() << " name " << it.second->name(); + + graph_info_map_.at(df_builder).params[it.first] = new_param; + SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param); + SetNodeMapInGraphInfoMap(df_builder, it.first, new_param); + } + DumpGraphIR("nested_fg.ir", forward_graph); set_need_replace_forward(false); auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args); resource->set_func_graph(newfg); @@ -2396,15 +2402,18 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get(); auto newfg = resource->func_graph(); MS_EXCEPTION_IF_NULL(newfg); - auto size = args.size(); + auto inputs_size = args.size(); if (has_sens) { - size -= 1; + inputs_size -= 1; } std::vector inputs; inputs.emplace_back(NewValueNode(newfg)); - for (size_t i = 0; i < size; ++i) { + for (size_t i = 0; i < inputs_size; ++i) { inputs.emplace_back(GetInput(args[i], false)); } + if (newfg->parameters().size() > inputs_size) { + SetNestedWeigthsParam(newfg, cell_id, &inputs); + } auto out_id = GetId(out); auto cnode = graph_prev->NewCNode(inputs); SetTupleArgsToGraphInfoMap(graph_prev, out, cnode); @@ -2412,6 +2421,38 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4); } +void PynativeExecutor::SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id, + std::vector *inputs) { + FuncGraphPtr forward_graph = nullptr; + auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), + [&cell_id](const CellInfo &value) { return value.cell_id == cell_id; }); + if (ic != cell_graph_list_.end()) { + forward_graph = ic->fg; + } + 
MS_EXCEPTION_IF_NULL(forward_graph); + auto params = newfg->parameters(); + auto manage = Manage({newfg}, false); + for (const auto &it : params) { + auto param = it->cast(); + if (!param->has_default()) { + continue; + } + auto ir = replace_weights_map_.find(forward_graph); + if (ir == replace_weights_map_.end()) { + MS_LOG(EXCEPTION) << "Not find forward_graph in repalce weigths map"; + } + for (const auto &ip : ir->second) { + MS_LOG(DEBUG) << "Get param name " << param->name() << " cache name " << ip.second->name(); + if (ip.second->name() == param->name()) { + manage->Replace(param, ip.first); + inputs->emplace_back(ip.first); + break; + } + } + } + replace_weights_map_.erase(forward_graph); +} + void PynativeExecutor::Clear(const std::string &cell_id) { if (cell_id.empty()) { Clean(); @@ -2461,6 +2502,7 @@ void PynativeExecutor::ClearRes() { graph_info_map_.clear(); cell_sw_map_.clear(); + replace_weights_map_.clear(); cell_graph_list_.clear(); top_cell_list_.clear(); op_index_map_.clear(); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 559ed9c47f..12d854429a 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -60,7 +60,7 @@ void ClearPyNativeSession(); struct GraphInfo { std::string cell_id; AnfNodePtr output; - std::unordered_map params; // hold input parameters and cell weigths + OrderedMap params; // hold input parameters and cell weigths std::unordered_map>> node_map; std::vector objects; GraphInfo() = default; @@ -210,6 +210,7 @@ class PynativeExecutor : public std::enable_shared_from_this { void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned = false, bool is_grad = false); void ClearResidualRes(const std::string &cell_id); + void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph); void NewGraphInner(const py::object &cell, 
const py::args &args); void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g); void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); @@ -233,6 +234,7 @@ class PynativeExecutor : public std::enable_shared_from_this { void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id); void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, const py::object &out, bool has_sens); + void SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector *inputs); bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id); // Hold graph(forward and grad) info @@ -242,7 +244,7 @@ class PynativeExecutor : public std::enable_shared_from_this { void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, bool is_param = false); void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr ¶m) { - graph_info_map_[g].params.emplace(std::make_pair(id, param)); + graph_info_map_[g].params[id] = param; } void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, int64_t index = -1) { @@ -269,15 +271,16 @@ class PynativeExecutor : public std::enable_shared_from_this { // Records forwrad graph, the bottom is top graph std::stack graph_stack_; + // Use vector for keep order + std::vector cell_graph_list_; + std::vector top_cell_list_; std::unordered_set cell_input_args_; std::unordered_map cell_dynamic_map_; // Record all info for all cells std::unordered_map graph_info_map_; - // Use vector for keep order - std::vector cell_graph_list_; - std::vector top_cell_list_; // key: cell_id, value: (send_id, weighs_id), cache for sens and weight change std::unordered_map> cell_sw_map_; + std::unordered_map>> replace_weights_map_; // Used for runop and replace forward 
result of grad graph std::unordered_map op_index_map_;