Browse Source

!9868 Fix pynative second order grad memory

From: @zjun3021
Reviewed-by: @kisnwang,@chujinjin
Signed-off-by: @chujinjin
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
521b059608
2 changed files with 8 additions and 8 deletions
  1. +6
    -6
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +2
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h

+ 6
- 6
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -1433,8 +1433,8 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
}

bool PynativeExecutor::IsNotNestedGrad() const {
MS_LOG(DEBUG) << "Grad nested count is " << grad_count_;
return grad_count_ <= 1;
MS_LOG(DEBUG) << "Grad nested count is " << grad_order_;
return grad_order_ <= 1;
}

bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
@@ -1446,8 +1446,8 @@ bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
}

void PynativeExecutor::SubNestedGradCount() {
if (grad_count_ > 0) {
--grad_count_;
if (grad_order_ > 0) {
--grad_order_;
}
}

@@ -1828,7 +1828,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string

void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) {
FuncGraphPtr tmp = curr_g_;
if (need_cloned) {
if (need_cloned && !IsNotNestedGrad()) {
auto cloned_curr_g = BasicClone(curr_g_);
graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_);
tmp = cloned_curr_g;
@@ -2365,7 +2365,7 @@ void PynativeExecutor::Clean() {
void PynativeExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear all res";
Clean();
grad_count_ = 0;
grad_order_ = 0;
grad_flag_ = false;
dynamic_cell_ = false;
grad_is_running_ = false;


+ 2
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -147,7 +147,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void PopGraphStack();
FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
ResourcePtr GetResource(const std::string &cell_id = "");
void AddNestedGradCount() { ++grad_count_; }
void AddNestedGradCount() { ++grad_order_; }
void SubNestedGradCount();
bool IsNotNestedGrad() const;
bool IsTopGraph(const std::string &cell_id);
@@ -204,7 +204,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static std::shared_ptr<PynativeExecutor> executor_;
static std::mutex instance_lock_;
static int64_t graph_id_;
int64_t grad_count_{0};
int64_t grad_order_{0};
bool grad_flag_{false};
bool dynamic_cell_{false};
bool grad_is_running_{false};


Loading…
Cancel
Save