Browse Source

!9868 Fix pynative second order grad memory

From: @zjun3021
Reviewed-by: @kisnwang,@chujinjin
Signed-off-by: @chujinjin
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
521b059608
2 changed files with 8 additions and 8 deletions
  1. +6
    -6
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +2
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h

+ 6
- 6
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -1433,8 +1433,8 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
} }


bool PynativeExecutor::IsNotNestedGrad() const { bool PynativeExecutor::IsNotNestedGrad() const {
MS_LOG(DEBUG) << "Grad nested count is " << grad_count_;
return grad_count_ <= 1;
MS_LOG(DEBUG) << "Grad nested count is " << grad_order_;
return grad_order_ <= 1;
} }


bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
@@ -1446,8 +1446,8 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
} }


void PynativeExecutor::SubNestedGradCount() { void PynativeExecutor::SubNestedGradCount() {
if (grad_count_ > 0) {
--grad_count_;
if (grad_order_ > 0) {
--grad_order_;
} }
} }


@@ -1828,7 +1828,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string


void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) { void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) {
FuncGraphPtr tmp = curr_g_; FuncGraphPtr tmp = curr_g_;
if (need_cloned) {
if (need_cloned && !IsNotNestedGrad()) {
auto cloned_curr_g = BasicClone(curr_g_); auto cloned_curr_g = BasicClone(curr_g_);
graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_); graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_);
tmp = cloned_curr_g; tmp = cloned_curr_g;
@@ -2365,7 +2365,7 @@ void PynativeExecutor::Clean() {
void PynativeExecutor::ClearRes() { void PynativeExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear all res"; MS_LOG(DEBUG) << "Clear all res";
Clean(); Clean();
grad_count_ = 0;
grad_order_ = 0;
grad_flag_ = false; grad_flag_ = false;
dynamic_cell_ = false; dynamic_cell_ = false;
grad_is_running_ = false; grad_is_running_ = false;


+ 2
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -147,7 +147,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void PopGraphStack(); void PopGraphStack();
FuncGraphPtr GetDfbuilder(const std::string &cell_id = ""); FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
ResourcePtr GetResource(const std::string &cell_id = ""); ResourcePtr GetResource(const std::string &cell_id = "");
void AddNestedGradCount() { ++grad_count_; }
void AddNestedGradCount() { ++grad_order_; }
void SubNestedGradCount(); void SubNestedGradCount();
bool IsNotNestedGrad() const; bool IsNotNestedGrad() const;
bool IsTopGraph(const std::string &cell_id); bool IsTopGraph(const std::string &cell_id);
@@ -204,7 +204,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static std::shared_ptr<PynativeExecutor> executor_; static std::shared_ptr<PynativeExecutor> executor_;
static std::mutex instance_lock_; static std::mutex instance_lock_;
static int64_t graph_id_; static int64_t graph_id_;
int64_t grad_count_{0};
int64_t grad_order_{0};
bool grad_flag_{false}; bool grad_flag_{false};
bool dynamic_cell_{false}; bool dynamic_cell_{false};
bool grad_is_running_{false}; bool grad_is_running_{false};


Loading…
Cancel
Save