|
|
@@ -1433,8 +1433,8 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsNotNestedGrad() const { |
|
|
bool PynativeExecutor::IsNotNestedGrad() const { |
|
|
MS_LOG(DEBUG) << "Grad nested count is " << grad_count_; |
|
|
|
|
|
return grad_count_ <= 1; |
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Grad nested count is " << grad_order_; |
|
|
|
|
|
return grad_order_ <= 1; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { |
|
|
bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { |
|
|
@@ -1446,8 +1446,8 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void PynativeExecutor::SubNestedGradCount() { |
|
|
void PynativeExecutor::SubNestedGradCount() { |
|
|
if (grad_count_ > 0) { |
|
|
|
|
|
--grad_count_; |
|
|
|
|
|
|
|
|
if (grad_order_ > 0) { |
|
|
|
|
|
--grad_order_; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -1828,7 +1828,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string |
|
|
|
|
|
|
|
|
void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) { |
|
|
void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) { |
|
|
FuncGraphPtr tmp = curr_g_; |
|
|
FuncGraphPtr tmp = curr_g_; |
|
|
if (need_cloned) { |
|
|
|
|
|
|
|
|
if (need_cloned && !IsNotNestedGrad()) { |
|
|
auto cloned_curr_g = BasicClone(curr_g_); |
|
|
auto cloned_curr_g = BasicClone(curr_g_); |
|
|
graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_); |
|
|
graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_); |
|
|
tmp = cloned_curr_g; |
|
|
tmp = cloned_curr_g; |
|
|
@@ -2365,7 +2365,7 @@ void PynativeExecutor::Clean() { |
|
|
void PynativeExecutor::ClearRes() { |
|
|
void PynativeExecutor::ClearRes() { |
|
|
MS_LOG(DEBUG) << "Clear all res"; |
|
|
MS_LOG(DEBUG) << "Clear all res"; |
|
|
Clean(); |
|
|
Clean(); |
|
|
grad_count_ = 0; |
|
|
|
|
|
|
|
|
grad_order_ = 0; |
|
|
grad_flag_ = false; |
|
|
grad_flag_ = false; |
|
|
dynamic_cell_ = false; |
|
|
dynamic_cell_ = false; |
|
|
grad_is_running_ = false; |
|
|
grad_is_running_ = false; |
|
|
|