|
|
|
@@ -1467,13 +1467,16 @@ bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::ClearResidualRes() { |
|
|
|
void PynativeExecutor::ClearResidualRes(const std::string &cell_id) { |
|
|
|
if (top_cell_list_.empty() && !graph_stack_.empty()) { |
|
|
|
graph_id_ = 0; |
|
|
|
graph_info_map_.clear(); |
|
|
|
cell_sw_map_.clear(); |
|
|
|
cell_graph_list_.clear(); |
|
|
|
top_cell_list_.clear(); |
|
|
|
std::stack<FuncGraphPtr>().swap(graph_stack_); |
|
|
|
} |
|
|
|
if (dynamic_cell_) { |
|
|
|
VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1486,8 +1489,8 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) { |
|
|
|
if (grad_order_ == 0 || grad_order_ == 1) { |
|
|
|
return top_cell_list_.back().df_builder; |
|
|
|
} |
|
|
|
if (top_cell_list_.size() < grad_order_) { |
|
|
|
MS_LOG(EXCEPTION) << "Get wrong grad order"; |
|
|
|
if (top_cell_list_.size() < 2) { |
|
|
|
MS_LOG(EXCEPTION) << "Top cell list size must greater than 2"; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size(); |
|
|
|
// Grad order greater than 2 |
|
|
|
@@ -1517,8 +1520,8 @@ ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) { |
|
|
|
if (grad_order_ == 0 || grad_order_ == 1) { |
|
|
|
return top_cell_list_.back().resource; |
|
|
|
} |
|
|
|
if (top_cell_list_.size() < grad_order_) { |
|
|
|
MS_LOG(EXCEPTION) << "Get wrong grad order"; |
|
|
|
if (top_cell_list_.size() < 2) { |
|
|
|
MS_LOG(EXCEPTION) << "Top cell list size must greater than 2"; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size(); |
|
|
|
// Grad order greater than 2 |
|
|
|
@@ -1718,7 +1721,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
// init resource for constructing forward graph and grad graph |
|
|
|
auto g = std::make_shared<FuncGraph>(); |
|
|
|
curr_g_ = g; |
|
|
|
ClearResidualRes(); |
|
|
|
ClearResidualRes(cell_id); |
|
|
|
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) { |
|
|
|
MakeNewTopGraph(cell_id, args, g); |
|
|
|
} |
|
|
|
@@ -2030,10 +2033,6 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje |
|
|
|
|
|
|
|
// Set all params(input+weights) |
|
|
|
SetGradGraphParams(df_builder, resource, size); |
|
|
|
// Clone df_builder and resource at first time |
|
|
|
if (CloneDfbuiler(cell_id, df_builder, resource)) { |
|
|
|
df_builder = GetDfbuilder(cell_id); |
|
|
|
} |
|
|
|
// Get params(weights) require derivative |
|
|
|
auto w_args = GetWeightsArgs(weights, df_builder); |
|
|
|
// Get the parameters items and add the value to args_spec |
|
|
|
@@ -2239,26 +2238,6 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args |
|
|
|
return args_spec; |
|
|
|
} |
|
|
|
|
|
|
|
bool PynativeExecutor::CloneDfbuiler(const std::string &cell_id, const FuncGraphPtr &df_builder, |
|
|
|
const ResourcePtr &resource) { |
|
|
|
bool is_cloned = false; |
|
|
|
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), |
|
|
|
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; }); |
|
|
|
if (it == top_cell_list_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Get top cell failed"; |
|
|
|
} |
|
|
|
if (it->bg == nullptr) { |
|
|
|
auto cloned_df_newfg = BasicClone(resource->func_graph()); |
|
|
|
it->bg = cloned_df_newfg; |
|
|
|
MS_LOG(DEBUG) << "Cloned df newfg"; |
|
|
|
is_cloned = false; |
|
|
|
} else { |
|
|
|
resource->set_func_graph(it->bg); |
|
|
|
MS_LOG(DEBUG) << "Used cloned df newfg"; |
|
|
|
} |
|
|
|
return is_cloned; |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, |
|
|
|
const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) { |
|
|
|
FuncGraphPtr top_g = nullptr; |
|
|
|
@@ -2433,28 +2412,6 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg |
|
|
|
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void MapClear(T *map, const std::string &cell_id) { |
|
|
|
for (auto it = map->begin(); it != map->end();) { |
|
|
|
if (it->first.find(cell_id) != std::string::npos) { |
|
|
|
it = map->erase(it); |
|
|
|
} else { |
|
|
|
it++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void VectorClear(T *vec, const std::string &cell_id) { |
|
|
|
for (auto it = vec->begin(); it != vec->end();) { |
|
|
|
if (it->cell_id.find(cell_id) != std::string::npos) { |
|
|
|
it = vec->erase(it); |
|
|
|
} else { |
|
|
|
it++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::Clear(const std::string &cell_id) { |
|
|
|
if (cell_id.empty()) { |
|
|
|
Clean(); |
|
|
|
|