Browse Source

Rebase r1.1

Signed-off-by: zjun <zhangjun0@huawei.com>
tags/v1.2.0-rc1
zjun 5 years ago
parent
commit
b76290c155
2 changed files with 156 additions and 71 deletions
  1. +150
    -70
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +6
    -1
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h

+ 150
- 70
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -640,7 +640,7 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
auto op_name = py::cast<std::string>(args[PY_NAME]);
op_exec_info->op_name = op_name;
if (grad_flag()) {
op_exec_info->op_index = op_name + std::to_string(op_index_map_[op_name]);
op_exec_info->op_index = op_name + "_" + std::to_string(op_index_map_[op_name]);
if (!cell_op_info_stack_.empty()) {
std::string &cell_op_info = cell_op_info_stack_.top();
cell_op_info += op_exec_info->op_index;
@@ -1514,9 +1514,10 @@ std::string PynativeExecutor::GetTensorCellId(const std::string &cell_id) {
}
value.emplace_back(str.substr(pre_pos));
};
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&key](const CellInfoPtr &value) {
return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos;
});
auto it =
std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(), [&key](const CellInfoPtr &value) {
return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos;
});
if (it != cell_graph_list_.end()) {
std::vector<std::string> pre_cell_id;
std::vector<std::string> cur_cell_id;
@@ -1590,16 +1591,19 @@ void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_com
if (it != top_cell_list_.end()) {
(*it)->do_vm_compiled = vm_compiled;
(*it)->forward_already_run = false;
(*it)->need_grad = true;
if ((*it)->is_topest) {
in_grad_process_ = false;
top_cell_index_ = 0;
}
}
}

bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) {
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) {
return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos;
});
return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id](const CellInfoPtr &value) {
return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos;
});
}

bool PynativeExecutor::IsFirstGradStep(const std::string &cell_id) { return !CheckCellGraph(cell_id, true); }
@@ -1611,20 +1615,21 @@ void PynativeExecutor::SubNestedGradOrder() {
}

bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) {
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfoPtr &value) {
return value->cell_id == cell_id && (!is_grad || value->is_grad);
});
return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id, is_grad](const CellInfoPtr &value) {
return value->cell_id == cell_id && (!is_grad || value->is_grad);
});
}

bool PynativeExecutor::CheckDynamicCell(const std::string &cell_id) {
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(),
return std::any_of(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_dynamic; });
}

bool PynativeExecutor::CheckRealDynamicCell(const std::string &cell_id) {
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) {
return value->cell_id == cell_id && value->is_real_dynamic;
});
return std::any_of(
cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_real_dynamic; });
}

void PynativeExecutor::ClearResidualRes(const std::string &cell_id) {
@@ -1891,25 +1896,23 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
// check whether cell needed to construct grad graph
if (graph_stack_.empty() && !top_cell_list_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) {
// Clear previous step resource
if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) {
auto init_fn = [&](bool flag) {
CleanPreMemoryInValueNode();
op_index_map_.clear();
in_grad_process_ = true;
auto top_cell = GetTopCell(cell_id);
in_bprop_process_ = false;
auto top_cell = GetTopCell(cell_id, flag);
MS_EXCEPTION_IF_NULL(top_cell);
top_cell_id_ = top_cell->cell_id;
top_cell_index_ = top_cell->top_cell_index;
top_cell->forward_already_run = true;
MS_LOG(DEBUG) << "Top cell id " << top_cell_id_;
};
if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) {
init_fn(false);
}
if (!in_grad_process_ && cell_op_info_stack_.empty()) {
CleanPreMemoryInValueNode();
op_index_map_.clear();
in_grad_process_ = true;
auto top_cell = GetTopCell(cell_id, true);
MS_EXCEPTION_IF_NULL(top_cell);
top_cell_id_ = top_cell->cell_id;
top_cell->forward_already_run = true;
MS_LOG(DEBUG) << "Top cell id " << top_cell_id_;
init_fn(true);
}
PushCurrentCellOpInfoToStack();
MS_LOG(INFO) << "NewGraph already compiled";
@@ -1918,8 +1921,12 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
// Init resource for constructing forward graph and grad graph
curr_g_ = std::make_shared<FuncGraph>();
ClearResidualRes(cell_id);
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
MakeNewTopGraph(cell_id, args);
if (graph_stack_.empty()) {
if (IsBpropGraph(cell_id)) {
in_bprop_process_ = true;
} else {
MakeNewTopGraph(cell_id, args);
}
}
PushCurrentGraphToStack();
PushCurrentCellOpInfoToStack();
@@ -1939,6 +1946,14 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
if (!has_dynamic_cell_) {
has_dynamic_cell_ = IsDynamicCell(cell);
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << has_dynamic_cell_;
if (has_dynamic_cell_ && IsBpropGraph(cell_id)) {
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[this](const CellInfoPtr &value) { return value->cell_id == top_cell_id_; });
while (it != cell_graph_list_.end()) {
(*it)->is_dynamic = true;
++it;
}
}
}
}

@@ -1976,6 +1991,15 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id);
top_cell_info->forward_already_run = true;
if (!IsTopestGraph(cell_id)) {
top_cell_info->top_cell_index = cell_graph_list_.size();
top_cell_index_ = top_cell_info->top_cell_index;
} else {
auto top_cell = GetTopCell(cell_id, true);
MS_EXCEPTION_IF_NULL(top_cell);
top_cell_info->top_cell_index = top_cell->top_cell_index;
top_cell_index_ = top_cell_info->top_cell_index;
}
top_cell_list_.emplace_back(top_cell_info);
MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get();
}
@@ -2086,11 +2110,22 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
resource->manager()->AddFuncGraph(curr_g_);
UpdateCellGraph(cell, curr_g_, cell_id, true, false);
FuncGraphPtr newfg = nullptr;
auto top_cell = GetTopCell(top_cell_id_);
MS_EXCEPTION_IF_NULL(top_cell);
// Cell no Change
if (CheckDynamicCell(cell_id) && !CheckCellChanged(cell_id)) {
MS_LOG(DEBUG) << "Cell is not dynamic, No need make ad grad";
top_cell->need_grad = false;
std::unordered_set<AnfNodePtr> node_set;
ClearCnodeRes(curr_g_->output(), &node_set);
node_set.clear();
} else {
MS_LOG(DEBUG) << "Need make ad grad";
if (!top_cell->need_grad) {
std::unordered_set<AnfNodePtr> node_set;
ClearCnodeRes(curr_g_->output(), &node_set);
node_set.clear();
}
newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args);
}

@@ -2146,7 +2181,7 @@ bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) {
MS_LOG(DEBUG) << "Cell op info is empty";
return true;
}
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
if (it == cell_graph_list_.end() || IsFirstGradStep(top_cell_id_)) {
return true;
@@ -2162,22 +2197,22 @@ bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) {
}

void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) {
for (auto &it : cell_graph_list_) {
if (it->cell_id != cell_id) {
it->is_real_dynamic = true;
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) {
if ((*it)->cell_id != cell_id) {
(*it)->is_real_dynamic = true;
continue;
}
it->is_real_dynamic = true;
(*it)->is_real_dynamic = true;
break;
}
}

void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned, bool is_grad) {
auto update_in_endgraph = need_cloned && !is_grad;
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
// Bprop just save backward graph
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
if (it != cell_graph_list_.end()) {
(*it)->is_grad = is_grad;
@@ -2196,16 +2231,23 @@ void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGr
<< " cell ops info " << GetCellOpInfo();
auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id);
cell_info->cell_ops_info.emplace_back(GetCellOpInfo());
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
if (in_bprop_process_) {
cell_graph_list_.emplace_back(cell_info);
} else {
cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info);
}
}
return;
return true;
}
return false;
}

void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned, bool is_grad) {
auto update_in_endgraph = need_cloned && !is_grad;
UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad);
if (UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad)) {
return;
}
FuncGraphPtr tmp = g;
if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) {
MS_LOG(DEBUG) << "No need cloned";
@@ -2228,9 +2270,13 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
MS_LOG(DEBUG) << "Add new cell with cloned graph " << cell_id << " cell ops info " << GetCellOpInfo();
auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, tmp, cell_id, "");
cell_info->cell_ops_info.emplace_back(GetCellOpInfo());
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
if (in_bprop_process_) {
cell_graph_list_.emplace_back(cell_info);
} else {
cell_graph_list_.insert(cell_graph_list_.begin() + top_cell_index_, cell_info);
}
} else {
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
if (it != cell_graph_list_.end()) {
(*it)->cell_ops_info.emplace_back(GetCellOpInfo());
@@ -2240,26 +2286,26 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
return;
}

for (auto &it : cell_graph_list_) {
if (it->cell_id != cell_id) {
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) {
if ((*it)->cell_id != cell_id) {
continue;
}
if (IsFirstGradStep(cell_id)) {
// no compute grad
it->is_grad = is_grad;
(*it)->is_grad = is_grad;
}
if (need_cloned) {
clone_fn();
if (it->fg != nullptr) {
graph_info_map_.erase(it->fg);
if ((*it)->fg != nullptr) {
graph_info_map_.erase((*it)->fg);
}
MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with cloned new " << tmp.get();
it->fg = tmp;
MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with cloned new " << tmp.get();
(*it)->fg = tmp;
}
if (!need_cloned && !is_grad) {
graph_info_map_.erase(it->fg);
MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with new " << tmp.get();
it->fg = tmp;
graph_info_map_.erase((*it)->fg);
MS_LOG(DEBUG) << "Update cur graph " << (*it)->fg.get() << " with new " << tmp.get();
(*it)->fg = tmp;
}
break;
}
@@ -2517,9 +2563,10 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args
auto graph_info = std::make_shared<GraphInfo>(cell_id);
graph_info_map_[df_builder] = graph_info;
auto top_cell_info = std::make_shared<TopCellInfo>(false, resource, df_builder, cell_id);
top_cell_info->top_cell_index = top_cell_index_;
top_cell_list_.emplace_back(top_cell_info);
FuncGraphPtr forward_graph = nullptr;
auto ib = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
auto ib = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
if (ib != cell_graph_list_.end()) {
forward_graph = (*ib)->fg;
@@ -2546,17 +2593,17 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const
const std::string &cell_id) {
std::vector<FuncGraphPtr> graph_before{};
bool index_find = false;
for (const auto &it : cell_graph_list_) {
if (IsBpropGraph(it->cell_id) || it->fg == nullptr) {
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) {
if (IsBpropGraph((*it)->cell_id) || (*it)->fg == nullptr) {
continue;
}
if (index_find) {
graph_before.emplace_back(it->fg);
graph_before.emplace_back((*it)->fg);
continue;
}
if (it->cell_id == cell_id) {
if ((*it)->cell_id == cell_id) {
index_find = true;
graph_before.emplace_back(it->fg);
graph_before.emplace_back((*it)->fg);
}
}

@@ -2585,6 +2632,7 @@ void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const
SetNodeMapInGraphInfoMap(df_builder, it.first, new_param);
}
}
graph_before.clear();
}

void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) {
@@ -2674,7 +2722,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op,
const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) {
FuncGraphPtr top_g = nullptr;
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
auto it = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
if (it != cell_graph_list_.end()) {
top_g = (*it)->fg;
@@ -2711,7 +2759,7 @@ void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py:
graph_info_map_.erase(df_builder);
bool has_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
bool is_dynamic_top_fist_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(cell_id);
bool is_topmost = IsTopestGraph(cell_id) && top_cell_list_.front()->cell_id == cell_id;
bool is_topmost = IsTopestGraph(cell_id);
if (has_custom_bprop || is_dynamic_top_fist_grad || !is_topmost) {
return;
}
@@ -2720,16 +2768,32 @@ void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py:
// Clear graph_info_map_
std::vector<std::string> l{};
bool index_find = false;
for (auto &it : cell_graph_list_) {
auto it_end = cell_graph_list_.end();
for (size_t i = 0; i < top_cell_list_.size(); ++i) {
if (top_cell_list_[i]->cell_id == cell_id) {
index_find = true;
continue;
}
if (index_find) {
it_end = cell_graph_list_.begin() + top_cell_list_[i]->top_cell_index;
break;
}
}
index_find = false;
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != it_end; ++it) {
if ((*it)->fg != nullptr) {
std::unordered_set<AnfNodePtr> node_set;
ClearCnodeRes((*it)->fg->output(), &node_set);
node_set.clear();
(*it)->fg = nullptr;
}
if (index_find) {
it->fg = nullptr;
l.emplace_back(it->cell_id);
l.emplace_back((*it)->cell_id);
continue;
}
if (it->cell_id == cell_id) {
if ((*it)->cell_id == cell_id) {
index_find = true;
it->fg = nullptr;
l.emplace_back(it->cell_id);
l.emplace_back((*it)->cell_id);
}
}
for (const auto &it : l) {
@@ -2753,7 +2817,7 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
const auto &cell_id = GetCellId(cell, args);
std::string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size()));
MS_LOG(DEBUG) << "Key is " << key;
for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) {
for (auto it = cell_graph_list_.begin() + top_cell_index_; it != cell_graph_list_.end(); ++it) {
MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id;
if (key != (*it)->cell_id.substr(0, std::min(PTR_LEN, (*it)->cell_id.size()))) {
continue;
@@ -2773,13 +2837,18 @@ py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::a
bool forward_run = false;
if (top_cell != nullptr) {
forward_run = top_cell->forward_already_run;
if (forward_run) {
top_cell_index_ = top_cell->top_cell_index;
}
}
MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ "
<< top_cell_id_;
return BaseRefToPyData(forward_run);
}

py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) {
void PynativeExecutor::RunInner(const py::object &cell, const py::tuple &args, const py::object &phase,
py::object *ret) {
MS_EXCEPTION_IF_NULL(ret);
auto cell_id = GetCellId(cell, args);
MS_LOG(DEBUG) << "Run start cell id " << cell_id;
bool has_sens = false;
@@ -2814,17 +2883,16 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args,
BaseRef value = (*run)(arg_list);
set_grad_runing(false);
MS_LOG(DEBUG) << "Eval run end " << value.ToString();
auto out = BaseRefToPyData(value);
*ret = BaseRefToPyData(value);
auto do_vm_compiled =
std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->do_vm_compiled; });
if (do_vm_compiled) {
if (MakeBpropNestedCnode(cell, out, cell_id)) {
return out;
if (MakeBpropNestedCnode(cell, *ret, cell_id)) {
return;
}
MakeNestedCnode(cell_id, args, resource, out, has_sens);
MakeNestedCnode(cell_id, args, resource, *ret, has_sens);
}
return out;
}

bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) {
@@ -2883,7 +2951,7 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
void PynativeExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id,
std::vector<AnfNodePtr> *inputs) {
FuncGraphPtr forward_graph = nullptr;
auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
auto ic = std::find_if(cell_graph_list_.begin() + top_cell_index_, cell_graph_list_.end(),
[&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
if (ic != cell_graph_list_.end()) {
forward_graph = (*ic)->fg;
@@ -2950,20 +3018,26 @@ void PynativeExecutor::ClearRes() {
graph_id_ = 0;
grad_order_ = 0;
grad_flag_ = false;
in_grad_process_ = false;
in_bprop_process_ = false;
has_dynamic_cell_ = false;
grad_is_running_ = false;
need_replace_forward_ = true;
curr_g_ = nullptr;

top_cell_id_.clear();
graph_info_map_.clear();
replace_weights_map_.clear();
cell_graph_list_.clear();
top_cell_list_.clear();
cell_input_args_.clear();
op_index_map_.clear();
cell_op_index_with_tensor_id_.clear();
cell_tensor_id_with_tensor_.clear();
prim_abs_list_.clear();
all_value_node_tensors_.clear();
std::stack<FuncGraphPtr>().swap(graph_stack_);
std::stack<std::string>().swap(cell_op_info_stack_);
}

void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
@@ -2984,6 +3058,12 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
}

py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) {
py::object ret;
PynativeExecutorTry(this, &PynativeExecutor::RunInner, cell, args, phase, &ret);
return ret;
}

void PynativeExecutor::Sync() {
if (session == nullptr) {
MS_EXCEPTION(NotExistsError) << "No session has been created!";


+ 6
- 1
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -96,9 +96,11 @@ class TopCellInfo {
TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid)
: is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {}

bool need_grad{true};
bool is_topest{false};
bool do_vm_compiled{false};
bool forward_already_run{false};
size_t top_cell_index{0};
ResourcePtr resource{nullptr};
FuncGraphPtr df_builder{nullptr};
FuncGraphPtr bg{nullptr}; // Backward graph
@@ -134,6 +136,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
void NewGraph(const py::object &cell, const py::args &args);
py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase);
void RunInner(const py::object &cell, const py::tuple &args, const py::object &phase, py::object *ret);
py::object CheckGraph(const py::object &cell, const py::args &args);
py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
@@ -241,7 +244,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
bool CheckDynamicCell(const std::string &cell_id);
bool CheckRealDynamicCell(const std::string &cell_id);
void UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned,
bool UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned,
bool is_grad);
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned = false, bool is_grad = false);
@@ -308,8 +311,10 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static std::mutex instance_lock_;
static int64_t graph_id_;
size_t grad_order_{0};
size_t top_cell_index_{0};
std::string top_cell_id_;
bool grad_flag_{false};
bool in_bprop_process_{false};
bool in_grad_process_{false};
bool has_dynamic_cell_{false};
bool grad_is_running_{false};


Loading…
Cancel
Save