|
|
|
@@ -1031,7 +1031,8 @@ PynativeExecutor::PynativeExecutor() { |
|
|
|
|
|
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { |
|
|
|
auto cell_id = GetCellId(cell, args); |
|
|
|
if (cell_graph_map_.count(cell_id) != 0) { |
|
|
|
// judge graph_context_.empty() to create sperate graphs except for the top |
|
|
|
if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) { |
|
|
|
if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) { |
|
|
|
resource_ = cell_resource_map_[cell_id]; |
|
|
|
} |
|
|
|
@@ -1040,21 +1041,24 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
} |
|
|
|
|
|
|
|
auto g = std::make_shared<FuncGraph>(); |
|
|
|
|
|
|
|
if (top_g_ == nullptr) { |
|
|
|
if (graph_context_.empty()) { |
|
|
|
// a df builder is built for every top function graph |
|
|
|
df_builder_ = std::make_shared<FuncGraph>(); |
|
|
|
df_builder_map_[cell_id] = df_builder_; |
|
|
|
top_g_ = curr_g_ = g; |
|
|
|
resource_ = std::make_shared<pipeline::Resource>(); |
|
|
|
resource_->results()[pipeline::kPynativeGraphId] = graph_id_++; |
|
|
|
cell_resource_map_[cell_id] = resource_; |
|
|
|
df_builder_ = std::make_shared<FuncGraph>(); |
|
|
|
MS_LOG(DEBUG) << "First new graph" << top_g_.get(); |
|
|
|
first_grad_step_ = true; |
|
|
|
top_graph_cells_.insert(cell_id); |
|
|
|
Pushp(); |
|
|
|
} else { |
|
|
|
Pushp(); |
|
|
|
if (df_builder_ == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "In NewGraphInner, got df builder is nullptr"; |
|
|
|
} |
|
|
|
curr_g_ = g; |
|
|
|
} |
|
|
|
Pushp(); |
|
|
|
if (graph_info_map_.count(g) == 0) { |
|
|
|
graph_info_map_[g] = GraphInfo(); |
|
|
|
} |
|
|
|
@@ -1171,22 +1175,25 @@ void PynativeExecutor::SetTupleParam(const py::object &obj, const AnfNodePtr &pa |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); } |
|
|
|
void PynativeExecutor::Pushp() { graph_context_.push(curr_g_); } |
|
|
|
|
|
|
|
void PynativeExecutor::Popp() { |
|
|
|
if (graph_p_.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Stack graph_p_ is empty"; |
|
|
|
if (graph_context_.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Stack graph_context_ is empty"; |
|
|
|
} |
|
|
|
graph_context_.pop(); |
|
|
|
if (!graph_context_.empty()) { |
|
|
|
curr_g_ = graph_context_.top(); |
|
|
|
} |
|
|
|
curr_g_ = graph_p_.top(); |
|
|
|
graph_p_.pop(); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { |
|
|
|
auto cell_id = GetCellId(cell, args); |
|
|
|
if (cell_graph_map_.count(cell_id) != 0) { |
|
|
|
if (cell_graph_map_.count(cell_id) != 0 && graph_context_.empty()) { |
|
|
|
MS_LOG(DEBUG) << "Endgraph already compiled"; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
cell_graph_map_[cell_id] = curr_g_; |
|
|
|
auto out_id = GetId(out); |
|
|
|
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) { |
|
|
|
@@ -1246,7 +1253,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje |
|
|
|
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_))); |
|
|
|
} |
|
|
|
} |
|
|
|
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_); |
|
|
|
auto newfg = ad::Grad(curr_g_, resource_, graph_context_.size() == 1); |
|
|
|
if (need_replace_param) { |
|
|
|
auto params = newfg->parameters(); |
|
|
|
auto manager = Manage({newfg}, false); |
|
|
|
@@ -1257,26 +1264,29 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje |
|
|
|
} |
|
|
|
} |
|
|
|
graph_info_map_.erase(curr_g_); |
|
|
|
if (curr_g_ != top_g_) { |
|
|
|
if (graph_context_.size() > 1) { |
|
|
|
Popp(); |
|
|
|
// connect the previous graph to the inside graph |
|
|
|
auto graph_prev = graph_context_.top(); |
|
|
|
for (size_t i = 0; i < args.size(); i++) { |
|
|
|
auto input = GetInput(args[i], false); |
|
|
|
inputs.push_back(input); |
|
|
|
} |
|
|
|
auto out_cnode = curr_g_->NewCNode(inputs); |
|
|
|
set_pyobj(curr_g_, GetCellId(cell, args)); |
|
|
|
auto out_cnode = graph_prev->NewCNode(inputs); |
|
|
|
set_pyobj(graph_prev, GetCellId(cell, args)); |
|
|
|
if (py::isinstance<py::tuple>(out)) { |
|
|
|
auto out_list = py::cast<py::tuple>(out); |
|
|
|
auto out_size = static_cast<int>(out_list.size()); |
|
|
|
for (int i = 0; i < out_size; i++) { |
|
|
|
set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i); |
|
|
|
set_obj_node_map(graph_prev, GetId(out_list[i]), out_cnode, i); |
|
|
|
SetTupleOutput(out_list[i], out_cnode, std::vector<int>{i}); |
|
|
|
} |
|
|
|
} |
|
|
|
set_obj_node_map(curr_g_, GetId(out), out_cnode); |
|
|
|
set_obj_node_map(graph_prev, GetId(out), out_cnode); |
|
|
|
} else { |
|
|
|
parse::ResolveFuncGraph(newfg, resource_); |
|
|
|
resource_->set_func_graph(newfg); |
|
|
|
Popp(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1348,14 +1358,36 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args |
|
|
|
void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, |
|
|
|
const py::args &args) { |
|
|
|
MS_LOG(INFO) << "GradNet start" << args.size(); |
|
|
|
|
|
|
|
std::size_t size = args.size(); |
|
|
|
std::string cell_id = GetCellId(cell, args); |
|
|
|
if (graph_map_.count(cell_id) != 0) { |
|
|
|
MS_LOG(DEBUG) << "GradNet already compiled"; |
|
|
|
return; |
|
|
|
} |
|
|
|
size_t forward_args_count = args.size(); |
|
|
|
if (grad->sens_param()) { |
|
|
|
forward_args_count = forward_args_count - 1; |
|
|
|
} |
|
|
|
py::tuple forward_args(forward_args_count); |
|
|
|
for (size_t i = 0; i < forward_args_count; i++) { |
|
|
|
forward_args[i] = args[i]; |
|
|
|
} |
|
|
|
std::string forward_cell_id = GetCellId(cell, forward_args); |
|
|
|
MS_LOG(DEBUG) << "Forward cell_id:" << forward_cell_id; |
|
|
|
if (df_builder_map_.find(forward_cell_id) == df_builder_map_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Cannot find df builder"; |
|
|
|
} |
|
|
|
df_builder_ = df_builder_map_[forward_cell_id]; |
|
|
|
if (df_builder_ == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Got unexpected null df builder"; |
|
|
|
} |
|
|
|
|
|
|
|
if (cell_resource_map_.find(forward_cell_id) == cell_resource_map_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Cannot find resource for " << forward_cell_id; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "GradNet first compiled"; |
|
|
|
resource_ = cell_resource_map_[forward_cell_id]; |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> new_params; |
|
|
|
for (size_t i = 0; i < size; i++) { |
|
|
|
ParameterPtr p = std::make_shared<Parameter>(df_builder_); |
|
|
|
@@ -1368,6 +1400,10 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights); |
|
|
|
MS_EXCEPTION_IF_NULL(resource_->func_graph()); |
|
|
|
if (cell_graph_map_.find(forward_cell_id) == cell_graph_map_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Could not find top graph by cellid: " << forward_cell_id; |
|
|
|
} |
|
|
|
top_g_ = cell_graph_map_[forward_cell_id]; |
|
|
|
auto g = GradGraph(resource_->func_graph(), grad, w_args, size); |
|
|
|
resource_->set_func_graph(g); |
|
|
|
resource_->manager()->KeepRoots({g}); |
|
|
|
@@ -1409,6 +1445,7 @@ void PynativeExecutor::Clear(const std::string &flag) { |
|
|
|
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&graph_map_, flag); |
|
|
|
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&cell_graph_map_, flag); |
|
|
|
MapClear<std::unordered_map<std::string, ResourcePtr>>(&cell_resource_map_, flag); |
|
|
|
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(&df_builder_map_, flag); |
|
|
|
Clean(); |
|
|
|
// Maybe exit in the pynative runing op, so need reset pynative flag. |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
@@ -1431,7 +1468,7 @@ void PynativeExecutor::Clear(const std::string &flag) { |
|
|
|
graph_info_map_.clear(); |
|
|
|
op_id_map_.clear(); |
|
|
|
obj_to_forward_id_.clear(); |
|
|
|
std::stack<FuncGraphPtr>().swap(graph_p_); |
|
|
|
std::stack<FuncGraphPtr>().swap(graph_context_); |
|
|
|
ConfigManager::GetInstance().ResetIterNum(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1509,7 +1546,6 @@ py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) |
|
|
|
} |
|
|
|
|
|
|
|
std::string backend = MsContext::GetInstance()->backend_policy(); |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Eval run" << backend; |
|
|
|
BaseRef value = (*run)(arg_list); |
|
|
|
MS_LOG(DEBUG) << "Run end" << value.ToString(); |
|
|
|
|