From c0f02dfaf8cc64ec27b7cf94c726d9978c486f75 Mon Sep 17 00:00:00 2001 From: zjun Date: Fri, 18 Dec 2020 12:42:46 +0800 Subject: [PATCH] Optimize pynative bprop Signed-off-by: zjun --- .../pipeline/jit/parse/data_converter.cc | 2 +- .../pipeline/pynative/pynative_execute.cc | 134 +++++++++--------- .../pipeline/pynative/pynative_execute.h | 20 +-- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 17 +-- 4 files changed, 88 insertions(+), 85 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index 9df5422ee5..7607fdea19 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -46,7 +46,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) { auto bprop_graph = std::make_shared(); std::vector outputs; - auto fake_bprop = std::make_shared("bprop_cut", obj); + auto fake_bprop = std::make_shared("bprop_cut", py::object()); fake_bprop->set_hook(bprop_func); (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true)); outputs.push_back(NewValueNode(fake_bprop)); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 998b7fef37..2ba8ff0f01 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -621,13 +621,14 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { auto op_name = py::cast(args[PY_NAME]); op_exec_info->op_name = op_name; if (grad_flag()) { - auto resource = GetResource(); - MS_EXCEPTION_IF_NULL(resource); - MS_LOG(DEBUG) << "Get resource ptr " << resource.get(); int64_t graph_id = graph_id_; - auto it = resource->results().find(pipeline::kPynativeGraphId); - if (it != resource->results().end()) { - graph_id = it->second.cast(); + auto resource = GetResource(); + if (resource != nullptr) { + MS_LOG(DEBUG) << "Get resource ptr " << 
resource.get(); + auto it = resource->results().find(pipeline::kPynativeGraphId); + if (it != resource->results().end()) { + graph_id = it->second.cast(); + } } op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]); op_index_map_[op_name]++; @@ -686,7 +687,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v if (need_construct_graph()) { AnfNodePtr input_node = nullptr; - if (!graph_info_map_.empty()) { + if (!graph_info_map_.empty() && !top_cell_list_.empty()) { input_node = GetInput(obj, op_mask); } // update abstract @@ -1450,7 +1451,7 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) { return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { - return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos; + return !value.bprop_cell_id.empty() && cell_id.find(value.bprop_cell_id) != std::string::npos; }); } @@ -1466,6 +1467,16 @@ bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) }); } +void PynativeExecutor::ClearResidualRes() { + if (top_cell_list_.empty() && !graph_stack_.empty()) { + graph_id_ = 0; + graph_info_map_.clear(); + cell_sw_map_.clear(); + cell_graph_list_.clear(); + top_cell_list_.clear(); + } +} + FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) { // Cell is empty, get nearest dfbuilder if (cell_id.empty() && !top_cell_list_.empty()) { @@ -1490,6 +1501,10 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) { return it.df_builder; } } + // Current cell is not top graph, get first top cell + if (!top_cell_list_.empty()) { + return top_cell_list_.front().df_builder; + } return nullptr; } @@ -1703,7 +1718,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg // init resource for constructing forward graph and grad graph 
auto g = std::make_shared(); curr_g_ = g; - if (graph_stack_.empty()) { + ClearResidualRes(); + if (graph_stack_.empty() && !IsBpropGraph(cell_id)) { MakeNewTopGraph(cell_id, args, g); } PushCurrentGraphToStack(); @@ -1724,16 +1740,6 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg dynamic_cell_ = IsDynamicCell(cell); MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_; } - // Make bprop graph - if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { - return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos; - }); - if (it != cell_graph_list_.end()) { - MS_LOG(INFO) << "Make bprop graph"; - it->custom_bprop_graph = true; - } - } } void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) { @@ -1807,7 +1813,7 @@ void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, con } void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { - auto cell_id = GetCellId(cell, args); + const auto &cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id; if (!dynamic_cell_ && graph_stack_.empty() && CheckCellGraph(cell_id)) { MS_LOG(INFO) << "Endgraph already compiled"; @@ -1841,35 +1847,30 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string AnfNodePtr output_node = GetObjNode(out, out_id); curr_g_->set_output(output_node); MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString(); + if (EndBpropGraph(cell_id)) { + return; + } auto resource = GetResource(cell_id); MS_EXCEPTION_IF_NULL(resource); - auto is_bprop_graph = IsBpropGraph(cell_id); - auto is_bprop_cell = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); - if (!is_bprop_cell || !is_bprop_graph) { - 
resource->manager()->AddFuncGraph(curr_g_); - } - if (!is_bprop_cell) { - UpdateCellGraph(cell, curr_g_, cell_id, true, false); - } + resource->manager()->AddFuncGraph(curr_g_); + UpdateCellGraph(cell, curr_g_, cell_id, true, false); auto newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); if (graph_stack_.size() > 1) { - if (!is_bprop_cell || !is_bprop_graph) { - std::vector inputs; - inputs.emplace_back(NewValueNode(curr_g_)); + std::vector inputs; + inputs.emplace_back(NewValueNode(curr_g_)); - PopGraphStack(); - // connect the previous graph to the inside graph - auto graph_prev = graph_stack_.top(); - for (size_t i = 0; i < args.size(); i++) { - auto input = GetInput(args[i], false); - inputs.emplace_back(input); - } - auto out_cnode = graph_prev->NewCNode(inputs); - SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args)); - SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); - SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); + PopGraphStack(); + // connect the previous graph to the inside graph + auto graph_prev = graph_stack_.top(); + for (size_t i = 0; i < args.size(); i++) { + auto input = GetInput(args[i], false); + inputs.emplace_back(input); } + auto out_cnode = graph_prev->NewCNode(inputs); + SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args)); + SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); + SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); } else { if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { DumpIR("before_resolve.ir", newfg); @@ -1883,6 +1884,17 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string } } +bool PynativeExecutor::EndBpropGraph(const string &cell_id) { + auto is_bprop_graph = IsBpropGraph(cell_id); + if (is_bprop_graph) { + if (IsNotNestedGrad()) { + PopGraphStack(); + } + return true; + } + return false; +} + void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool 
need_cloned, bool is_grad) { if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { @@ -1894,7 +1906,8 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt it->fg = g; MS_LOG(DEBUG) << "Update bprop bg"; } else { - auto cell_info = CellInfo(false, true, false, g, cell_id); + py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME); + auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func)); cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); } return; @@ -1923,13 +1936,13 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt return; } MS_LOG(DEBUG) << "Add new cell graph " << cell_id; - auto cell_info = CellInfo(false, true, false, tmp, cell_id); + auto cell_info = CellInfo(false, true, tmp, cell_id, ""); cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); } FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, - const string &cell_id, const py::args &args) { - bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !IsBpropGraph(cell_id); + const std::string &cell_id, const py::args &args) { + bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); if (is_custom_bprop) { size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size(); if (par_number > 0) { @@ -1943,17 +1956,15 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG (void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g))); } } - FuncGraphPtr newfg = nullptr; - if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) || is_custom_bprop) { - // Obtain grad graph - if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - DumpIR("fg.ir", g); - } - auto is_top = IsTopGraph(cell_id); - MS_LOG(DEBUG) << "Grad top cell " << is_top; - set_need_replace_forward(IsNotNestedGrad()); - newfg = ad::Grad(g, r, is_top); + // 
Obtain grad graph + if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { + DumpIR("fg.ir", g); } + auto is_top = IsTopGraph(cell_id); + MS_LOG(DEBUG) << "Grad top cell " << is_top; + set_need_replace_forward(IsNotNestedGrad()); + auto newfg = ad::Grad(g, r, is_top); + if (is_custom_bprop) { auto params = newfg->parameters(); auto manager = Manage({newfg}, false); @@ -2135,7 +2146,7 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args SetNodeMapInGraphInfoMap(df_builder, it.first, it.second); } } - MS_LOG(DEBUG) << "Get wights params size " << weights_params.size(); + MS_LOG(DEBUG) << "Get weights params size " << weights_params.size(); df_builder->set_parameters(weights_params); resource->manager()->AddFuncGraph(forward_graph); if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { @@ -2314,7 +2325,7 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & MS_LOG(DEBUG) << "Grad not running yet"; return BaseRefToPyData(ret); } - auto cell_id = GetCellId(cell, args); + const auto &cell_id = GetCellId(cell, args); string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size())); MS_LOG(DEBUG) << "Key is " << key; for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) { @@ -2379,12 +2390,6 @@ bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::ob MS_LOG(DEBUG) << "No nested bprop grad find"; return false; } - auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { - return value.custom_bprop_graph && value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos; - }); - if (it != cell_graph_list_.end()) { - MS_LOG(DEBUG) << "Make bprop graph end"; - } auto out_id = GetId(out); std::vector inputs; inputs.emplace_back(NewValueNode(curr_g_)); @@ -2489,6 +2494,7 @@ void PynativeExecutor::Clean() { void PynativeExecutor::ClearRes() { MS_LOG(DEBUG) << "Clear all res"; Clean(); + 
graph_id_ = 0; grad_order_ = 0; grad_flag_ = false; dynamic_cell_ = false; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 7ed061617f..42b9a2c4e7 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -68,18 +68,18 @@ struct GraphInfo { }; struct CellInfo { - bool is_grad{false}; // Derivative is calculated - bool is_custom_bprop{false}; // Custom bprop - bool custom_bprop_graph{false}; // Custom bprop make forward graph - FuncGraphPtr fg; // Forward graph + bool is_grad{false}; // Derivative is calculated + bool is_custom_bprop{false}; // Custom bprop + FuncGraphPtr fg; // Forward graph std::string cell_id; + std::string bprop_cell_id; CellInfo() = default; - CellInfo(bool isgrad, bool custom_bprop, bool bprop_graph, FuncGraphPtr foward_graph, std::string cellid) + CellInfo(bool isgrad, bool custom_bprop, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id) : is_grad(isgrad), is_custom_bprop(custom_bprop), - custom_bprop_graph(bprop_graph), fg(std::move(foward_graph)), - cell_id(std::move(cellid)) {} + cell_id(std::move(cellid)), + bprop_cell_id(std::move(bprop_id)) {} }; struct TopCellInfo { @@ -187,13 +187,15 @@ class PynativeExecutor : public std::enable_shared_from_this { bool CheckCellGraph(const std::string &cell_id, bool is_grad = false); void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned = false, bool is_grad = false); + void ClearResidualRes(); void NewGraphInner(const py::object &cell, const py::args &args); void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g); void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, const std::string &out_id, const py::args &args); - 
FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, const string &cell_id, - const py::args &args); + bool EndBpropGraph(const string &cell_id); + FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, + const std::string &cell_id, const py::args &args); std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args, py::object *sens = nullptr); void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 35feeb0e4f..d4788ef4c9 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -182,21 +182,16 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { auto inst = pynative::PynativeExecutor::GetInstance(); MS_EXCEPTION_IF_NULL(inst); try { - inst->NewGraph(GetPyObj(), input_args.cast()); + MS_LOG(DEBUG) << "Run bprop function start"; + inst->NewGraph(hook_, input_args.cast()); py::object grads_obj = hook_(*convert_args); py::tuple grads = check_bprop_out(grads_obj, py_args); - inst->EndGraph(GetPyObj(), grads_obj, input_args.cast()); + inst->EndGraph(hook_, grads_obj, input_args.cast()); + MS_LOG(DEBUG) << "Run bprop function end"; return std::make_shared(grads); - } catch (const py::type_error &ex) { + } catch (std::exception &bt) { inst->ClearRes(); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - inst->ClearRes(); - throw py::value_error(ex); - } catch (...) { - inst->ClearRes(); - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred in run bprop. Exception name: " << exName; + std::rethrow_exception(std::current_exception()); } } SyncData(py_args[2]);