Browse Source

fix op id issue in pynative

tags/v1.0.0
kingfo 5 years ago
parent
commit
cfda024336
3 changed files with 28 additions and 2 deletions
  1. +1
    -0
      mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
  2. +24
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  3. +3
    -0
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h

+ 1
- 0
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc View File

@@ -302,6 +302,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
auto inst = pynative::PynativeExecutor::GetInstance();
inst->SaveOpForwardValue(input_value.second, input_value.first);
auto input_value_node = NewValueNode(input_value.first);
input_value_node->set_has_new_value(true);
manager->Replace(paras[i], input_value_node);
}
}


+ 24
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -632,6 +632,9 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second;
return iter->second;
}
if (!first_grad_step_) {
++op_id_map_[id];
}
return nullptr;
}

@@ -979,7 +982,10 @@ void ClearPyNativeSession() { session = nullptr; }

PynativeExecutor::~PynativeExecutor() { ClearRes(); }

PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
PynativeExecutor::PynativeExecutor() {
grad_flag_ = false;
first_grad_step_ = false;
}

void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetCellId(cell, args);
@@ -1000,6 +1006,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
cell_resource_map_[cell_id] = resource_;
df_builder_ = std::make_shared<FuncGraph>();
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
first_grad_step_ = true;
top_graph_cells_.insert(cell_id);
Pushp();
} else {
Pushp();
@@ -1181,7 +1189,9 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
resource_->manager()->AddFuncGraph(curr_g_);
// custom bprop debug
bool need_replace_param = false;
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
need_replace_param = true;
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
if (par_number > 0) {
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
@@ -1195,6 +1205,15 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
}
}
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
if (need_replace_param) {
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
for (size_t i = 0; i < params.size(); i++) {
ValuePtr value = PyAttrValue(args[i]);
auto v_node = NewValueNode(value);
manager->Replace(params[i], v_node);
}
}
graph_info_map_.erase(curr_g_);
if (curr_g_ != top_g_) {
Popp();
@@ -1355,6 +1374,9 @@ void PynativeExecutor::Clear(const std::string &flag) {
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
}
ConfigManager::GetInstance().ResetIterNum();
if (top_graph_cells_.find(flag) != top_graph_cells_.end()) {
op_forward_map_.clear();
}
return;
}

@@ -1363,6 +1385,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
top_g_ = nullptr;
df_builder_ = nullptr;
curr_g_ = nullptr;
first_grad_step_ = false;
graph_info_map_.clear();
op_id_map_.clear();
obj_to_forward_id_.clear();
@@ -1374,7 +1397,6 @@ void PynativeExecutor::Clean() {
MS_LOG(DEBUG) << "Clean all res";
Clear();
grad_flag_ = false;
op_forward_map_.clear();
ad::CleanRes();
pipeline::ReclaimOptimizer();
}


+ 3
- 0
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -24,6 +24,7 @@
#include <unordered_map>
#include <mutex>
#include <stack>
#include <set>

#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
@@ -145,6 +146,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static ResourcePtr resource_;
static int graph_id_;
bool grad_flag_;
bool first_grad_step_;
std::unordered_map<std::string, FuncGraphPtr> graph_map_;
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
@@ -158,6 +160,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
FuncGraphPtr df_builder_;
FuncGraphPtr curr_g_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
std::set<std::string> top_graph_cells_;
};

using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;


Loading…
Cancel
Save