|
|
|
@@ -621,13 +621,14 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { |
|
|
|
auto op_name = py::cast<std::string>(args[PY_NAME]); |
|
|
|
op_exec_info->op_name = op_name; |
|
|
|
if (grad_flag()) { |
|
|
|
auto resource = GetResource(); |
|
|
|
MS_EXCEPTION_IF_NULL(resource); |
|
|
|
MS_LOG(DEBUG) << "Get resource ptr " << resource.get(); |
|
|
|
int64_t graph_id = graph_id_; |
|
|
|
auto it = resource->results().find(pipeline::kPynativeGraphId); |
|
|
|
if (it != resource->results().end()) { |
|
|
|
graph_id = it->second.cast<int64_t>(); |
|
|
|
auto resource = GetResource(); |
|
|
|
if (resource != nullptr) { |
|
|
|
MS_LOG(DEBUG) << "Get resource ptr " << resource.get(); |
|
|
|
auto it = resource->results().find(pipeline::kPynativeGraphId); |
|
|
|
if (it != resource->results().end()) { |
|
|
|
graph_id = it->second.cast<int64_t>(); |
|
|
|
} |
|
|
|
} |
|
|
|
op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]); |
|
|
|
op_index_map_[op_name]++; |
|
|
|
@@ -686,7 +687,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
|
|
|
|
if (need_construct_graph()) { |
|
|
|
AnfNodePtr input_node = nullptr; |
|
|
|
if (!graph_info_map_.empty()) { |
|
|
|
if (!graph_info_map_.empty() && !top_cell_list_.empty()) { |
|
|
|
input_node = GetInput(obj, op_mask); |
|
|
|
} |
|
|
|
// update abstract |
|
|
|
@@ -1450,7 +1451,7 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { |
|
|
|
|
|
|
|
bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) { |
|
|
|
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { |
|
|
|
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos; |
|
|
|
return !value.bprop_cell_id.empty() && cell_id.find(value.bprop_cell_id) != std::string::npos; |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1466,6 +1467,16 @@ bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::ClearResidualRes() { |
|
|
|
if (top_cell_list_.empty() && !graph_stack_.empty()) { |
|
|
|
graph_id_ = 0; |
|
|
|
graph_info_map_.clear(); |
|
|
|
cell_sw_map_.clear(); |
|
|
|
cell_graph_list_.clear(); |
|
|
|
top_cell_list_.clear(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) { |
|
|
|
// Cell is empty, get nearest dfbuilder |
|
|
|
if (cell_id.empty() && !top_cell_list_.empty()) { |
|
|
|
@@ -1490,6 +1501,10 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) { |
|
|
|
return it.df_builder; |
|
|
|
} |
|
|
|
} |
|
|
|
// Current cell is not top graph, get first top cell |
|
|
|
if (!top_cell_list_.empty()) { |
|
|
|
return top_cell_list_.front().df_builder; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1703,7 +1718,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
// init resource for constructing forward graph and grad graph |
|
|
|
auto g = std::make_shared<FuncGraph>(); |
|
|
|
curr_g_ = g; |
|
|
|
if (graph_stack_.empty()) { |
|
|
|
ClearResidualRes(); |
|
|
|
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) { |
|
|
|
MakeNewTopGraph(cell_id, args, g); |
|
|
|
} |
|
|
|
PushCurrentGraphToStack(); |
|
|
|
@@ -1724,16 +1740,6 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
dynamic_cell_ = IsDynamicCell(cell); |
|
|
|
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_; |
|
|
|
} |
|
|
|
// Make bprop graph |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { |
|
|
|
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos; |
|
|
|
}); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
MS_LOG(INFO) << "Make bprop graph"; |
|
|
|
it->custom_bprop_graph = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) { |
|
|
|
@@ -1807,7 +1813,7 @@ void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, con |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { |
|
|
|
auto cell_id = GetCellId(cell, args); |
|
|
|
const auto &cell_id = GetCellId(cell, args); |
|
|
|
MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id; |
|
|
|
if (!dynamic_cell_ && graph_stack_.empty() && CheckCellGraph(cell_id)) { |
|
|
|
MS_LOG(INFO) << "Endgraph already compiled"; |
|
|
|
@@ -1841,35 +1847,30 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string |
|
|
|
AnfNodePtr output_node = GetObjNode(out, out_id); |
|
|
|
curr_g_->set_output(output_node); |
|
|
|
MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString(); |
|
|
|
if (EndBpropGraph(cell_id)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto resource = GetResource(cell_id); |
|
|
|
MS_EXCEPTION_IF_NULL(resource); |
|
|
|
auto is_bprop_graph = IsBpropGraph(cell_id); |
|
|
|
auto is_bprop_cell = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); |
|
|
|
if (!is_bprop_cell || !is_bprop_graph) { |
|
|
|
resource->manager()->AddFuncGraph(curr_g_); |
|
|
|
} |
|
|
|
if (!is_bprop_cell) { |
|
|
|
UpdateCellGraph(cell, curr_g_, cell_id, true, false); |
|
|
|
} |
|
|
|
resource->manager()->AddFuncGraph(curr_g_); |
|
|
|
UpdateCellGraph(cell, curr_g_, cell_id, true, false); |
|
|
|
auto newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); |
|
|
|
|
|
|
|
if (graph_stack_.size() > 1) { |
|
|
|
if (!is_bprop_cell || !is_bprop_graph) { |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.emplace_back(NewValueNode(curr_g_)); |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.emplace_back(NewValueNode(curr_g_)); |
|
|
|
|
|
|
|
PopGraphStack(); |
|
|
|
// connect the previous graph to the inside graph |
|
|
|
auto graph_prev = graph_stack_.top(); |
|
|
|
for (size_t i = 0; i < args.size(); i++) { |
|
|
|
auto input = GetInput(args[i], false); |
|
|
|
inputs.emplace_back(input); |
|
|
|
} |
|
|
|
auto out_cnode = graph_prev->NewCNode(inputs); |
|
|
|
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args)); |
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); |
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); |
|
|
|
PopGraphStack(); |
|
|
|
// connect the previous graph to the inside graph |
|
|
|
auto graph_prev = graph_stack_.top(); |
|
|
|
for (size_t i = 0; i < args.size(); i++) { |
|
|
|
auto input = GetInput(args[i], false); |
|
|
|
inputs.emplace_back(input); |
|
|
|
} |
|
|
|
auto out_cnode = graph_prev->NewCNode(inputs); |
|
|
|
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args)); |
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); |
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); |
|
|
|
} else { |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("before_resolve.ir", newfg); |
|
|
|
@@ -1883,6 +1884,17 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::EndBpropGraph(const string &cell_id) { |
|
|
|
auto is_bprop_graph = IsBpropGraph(cell_id); |
|
|
|
if (is_bprop_graph) { |
|
|
|
if (IsNotNestedGrad()) { |
|
|
|
PopGraphStack(); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
|
bool need_cloned, bool is_grad) { |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
@@ -1894,7 +1906,8 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt |
|
|
|
it->fg = g; |
|
|
|
MS_LOG(DEBUG) << "Update bprop bg"; |
|
|
|
} else { |
|
|
|
auto cell_info = CellInfo(false, true, false, g, cell_id); |
|
|
|
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME); |
|
|
|
auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func)); |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); |
|
|
|
} |
|
|
|
return; |
|
|
|
@@ -1923,13 +1936,13 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Add new cell graph " << cell_id; |
|
|
|
auto cell_info = CellInfo(false, true, false, tmp, cell_id); |
|
|
|
auto cell_info = CellInfo(false, true, tmp, cell_id, ""); |
|
|
|
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info); |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, |
|
|
|
const string &cell_id, const py::args &args) { |
|
|
|
bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !IsBpropGraph(cell_id); |
|
|
|
const std::string &cell_id, const py::args &args) { |
|
|
|
bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME); |
|
|
|
if (is_custom_bprop) { |
|
|
|
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size(); |
|
|
|
if (par_number > 0) { |
|
|
|
@@ -1943,17 +1956,15 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG |
|
|
|
(void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g))); |
|
|
|
} |
|
|
|
} |
|
|
|
FuncGraphPtr newfg = nullptr; |
|
|
|
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) || is_custom_bprop) { |
|
|
|
// Obtain grad graph |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("fg.ir", g); |
|
|
|
} |
|
|
|
auto is_top = IsTopGraph(cell_id); |
|
|
|
MS_LOG(DEBUG) << "Grad top cell " << is_top; |
|
|
|
set_need_replace_forward(IsNotNestedGrad()); |
|
|
|
newfg = ad::Grad(g, r, is_top); |
|
|
|
// Obtain grad graph |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("fg.ir", g); |
|
|
|
} |
|
|
|
auto is_top = IsTopGraph(cell_id); |
|
|
|
MS_LOG(DEBUG) << "Grad top cell " << is_top; |
|
|
|
set_need_replace_forward(IsNotNestedGrad()); |
|
|
|
auto newfg = ad::Grad(g, r, is_top); |
|
|
|
|
|
|
|
if (is_custom_bprop) { |
|
|
|
auto params = newfg->parameters(); |
|
|
|
auto manager = Manage({newfg}, false); |
|
|
|
@@ -2135,7 +2146,7 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args |
|
|
|
SetNodeMapInGraphInfoMap(df_builder, it.first, it.second); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Get wights params size " << weights_params.size(); |
|
|
|
MS_LOG(DEBUG) << "Get weights params size " << weights_params.size(); |
|
|
|
df_builder->set_parameters(weights_params); |
|
|
|
resource->manager()->AddFuncGraph(forward_graph); |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
@@ -2314,7 +2325,7 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & |
|
|
|
MS_LOG(DEBUG) << "Grad not running yet"; |
|
|
|
return BaseRefToPyData(ret); |
|
|
|
} |
|
|
|
auto cell_id = GetCellId(cell, args); |
|
|
|
const auto &cell_id = GetCellId(cell, args); |
|
|
|
string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size())); |
|
|
|
MS_LOG(DEBUG) << "Key is " << key; |
|
|
|
for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) { |
|
|
|
@@ -2379,12 +2390,6 @@ bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::ob |
|
|
|
MS_LOG(DEBUG) << "No nested bprop grad find"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { |
|
|
|
return value.custom_bprop_graph && value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos; |
|
|
|
}); |
|
|
|
if (it != cell_graph_list_.end()) { |
|
|
|
MS_LOG(DEBUG) << "Make bprop graph end"; |
|
|
|
} |
|
|
|
auto out_id = GetId(out); |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.emplace_back(NewValueNode(curr_g_)); |
|
|
|
@@ -2489,6 +2494,7 @@ void PynativeExecutor::Clean() { |
|
|
|
void PynativeExecutor::ClearRes() { |
|
|
|
MS_LOG(DEBUG) << "Clear all res"; |
|
|
|
Clean(); |
|
|
|
graph_id_ = 0; |
|
|
|
grad_order_ = 0; |
|
|
|
grad_flag_ = false; |
|
|
|
dynamic_cell_ = false; |
|
|
|
|