|
|
|
@@ -632,6 +632,9 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) { |
|
|
|
MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second; |
|
|
|
return iter->second; |
|
|
|
} |
|
|
|
if (!first_grad_step_) { |
|
|
|
++op_id_map_[id]; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -979,7 +982,10 @@ void ClearPyNativeSession() { session = nullptr; } |
|
|
|
|
|
|
|
PynativeExecutor::~PynativeExecutor() { ClearRes(); } |
|
|
|
|
|
|
|
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } |
|
|
|
PynativeExecutor::PynativeExecutor() { |
|
|
|
grad_flag_ = false; |
|
|
|
first_grad_step_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { |
|
|
|
auto cell_id = GetCellId(cell, args); |
|
|
|
@@ -1000,6 +1006,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
cell_resource_map_[cell_id] = resource_; |
|
|
|
df_builder_ = std::make_shared<FuncGraph>(); |
|
|
|
MS_LOG(DEBUG) << "First new graph" << top_g_.get(); |
|
|
|
first_grad_step_ = true; |
|
|
|
top_graph_cells_.insert(cell_id); |
|
|
|
Pushp(); |
|
|
|
} else { |
|
|
|
Pushp(); |
|
|
|
@@ -1181,7 +1189,9 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje |
|
|
|
MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString(); |
|
|
|
resource_->manager()->AddFuncGraph(curr_g_); |
|
|
|
// custom bprop debug |
|
|
|
bool need_replace_param = false; |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
need_replace_param = true; |
|
|
|
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size(); |
|
|
|
if (par_number > 0) { |
|
|
|
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number |
|
|
|
@@ -1195,6 +1205,15 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje |
|
|
|
} |
|
|
|
} |
|
|
|
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_); |
|
|
|
if (need_replace_param) { |
|
|
|
auto params = newfg->parameters(); |
|
|
|
auto manager = Manage({newfg}, false); |
|
|
|
for (size_t i = 0; i < params.size(); i++) { |
|
|
|
ValuePtr value = PyAttrValue(args[i]); |
|
|
|
auto v_node = NewValueNode(value); |
|
|
|
manager->Replace(params[i], v_node); |
|
|
|
} |
|
|
|
} |
|
|
|
graph_info_map_.erase(curr_g_); |
|
|
|
if (curr_g_ != top_g_) { |
|
|
|
Popp(); |
|
|
|
@@ -1355,6 +1374,9 @@ void PynativeExecutor::Clear(const std::string &flag) { |
|
|
|
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); |
|
|
|
} |
|
|
|
ConfigManager::GetInstance().ResetIterNum(); |
|
|
|
if (top_graph_cells_.find(flag) != top_graph_cells_.end()) { |
|
|
|
op_forward_map_.clear(); |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1363,6 +1385,7 @@ void PynativeExecutor::Clear(const std::string &flag) { |
|
|
|
top_g_ = nullptr; |
|
|
|
df_builder_ = nullptr; |
|
|
|
curr_g_ = nullptr; |
|
|
|
first_grad_step_ = false; |
|
|
|
graph_info_map_.clear(); |
|
|
|
op_id_map_.clear(); |
|
|
|
obj_to_forward_id_.clear(); |
|
|
|
@@ -1374,7 +1397,6 @@ void PynativeExecutor::Clean() { |
|
|
|
MS_LOG(DEBUG) << "Clean all res"; |
|
|
|
Clear(); |
|
|
|
grad_flag_ = false; |
|
|
|
op_forward_map_.clear(); |
|
|
|
ad::CleanRes(); |
|
|
|
pipeline::ReclaimOptimizer(); |
|
|
|
} |
|
|
|
|