Browse Source

!11463 Redo delete grad flag for fixing memory not enough

From: @joylvliang
Reviewed-by: @chujinjin,@zhoufeng54
Signed-off-by: @chujinjin
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
5c560c7d9b
3 changed files with 49 additions and 21 deletions
  1. +42
    -19
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +3
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h
  3. +4
    -0
      mindspore/nn/cell.py

+ 42
- 19
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -1341,7 +1341,7 @@ py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_poli
break;
}
case kMsBackendMsPrior: {
// use Ms fisrt,use others when ms failed
// use Ms first,use others when ms failed
MS_LOG(INFO) << "RunOp use Ms first backend";
result = RunOpInMs(op_exec_info, status);
if (*status != PYNATIVE_SUCCESS) {
@@ -1557,23 +1557,28 @@ bool PynativeExecutor::IsTopestGraph(const std::string &cell_id) {
[&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->is_topest; });
}

std::string PynativeExecutor::GetTopCell(const string &cell_id) {
if (IsTopestGraph(cell_id)) {
return cell_id;
}
std::string top_cell_id;
for (const auto &it : cell_graph_list_) {
if (IsTopestGraph(it->cell_id)) {
top_cell_id = it->cell_id;
TopCellInfoPtr PynativeExecutor::GetTopCell(const string &cell_id, bool find_nearest) {
auto find_top_cell = [&](const string &cell_id) -> TopCellInfoPtr {
auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfoPtr &top_cell) {
return cell_id == top_cell->cell_id && top_cell->is_topest;
});
if (iter != top_cell_list_.end()) {
return *iter;
}
if (it->cell_id == cell_id) {
break;
return nullptr;
};
TopCellInfoPtr top_cell = find_top_cell(cell_id);
// find nearest top cell
if (top_cell == nullptr && find_nearest) {
for (const auto &cell_info : cell_graph_list_) {
MS_EXCEPTION_IF_NULL(cell_info);
top_cell = find_top_cell(cell_info->cell_id);
if (cell_id == cell_info->cell_id) {
break;
}
}
}
if (top_cell_id.empty()) {
MS_LOG(EXCEPTION) << "Get top cell null";
}
return top_cell_id;
return top_cell;
}

void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled) {
@@ -1581,6 +1586,7 @@ void PynativeExecutor::UpdateTopCellInfo(const std::string &cell_id, bool vm_com
[&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; });
if (it != top_cell_list_.end()) {
(*it)->do_vm_compiled = vm_compiled;
(*it)->forward_already_run = false;
if ((*it)->is_topest) {
in_grad_process_ = false;
}
@@ -1704,7 +1710,7 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAs
py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
if (comparators_node.empty()) {
MS_LOG(DEBUG) << "Get comparators node falied!";
MS_LOG(DEBUG) << "Get comparators node failed!";
return false;
}
auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
@@ -1885,14 +1891,21 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
if (IsTopestGraph(cell_id) && cell_op_info_stack_.empty()) {
CleanPreMemoryInValueNode();
op_index_map_.clear();
top_cell_id_ = cell_id;
in_grad_process_ = true;
auto top_cell = GetTopCell(cell_id);
MS_EXCEPTION_IF_NULL(top_cell);
top_cell_id_ = top_cell->cell_id;
top_cell->forward_already_run = true;
MS_LOG(DEBUG) << "Top cell id " << top_cell_id_;
}
if (!in_grad_process_ && cell_op_info_stack_.empty()) {
CleanPreMemoryInValueNode();
op_index_map_.clear();
top_cell_id_ = GetTopCell(cell_id);
in_grad_process_ = true;
auto top_cell = GetTopCell(cell_id, true);
MS_EXCEPTION_IF_NULL(top_cell);
top_cell_id_ = top_cell->cell_id;
top_cell->forward_already_run = true;
MS_LOG(DEBUG) << "Top cell id " << top_cell_id_;
}
PushCurrentCellOpInfoToStack();
@@ -1948,12 +1961,18 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
op_index_map_.clear();
top_cell_id_ = cell_id;
in_grad_process_ = true;
// update forward already run flag with previous top cell
auto pre_top_cell = GetTopCell(cell_id);
if (pre_top_cell != nullptr) {
pre_top_cell->forward_already_run = true;
}
auto df_builder = std::make_shared<FuncGraph>();
auto graph_info = std::make_shared<GraphInfo>(cell_id);
graph_info_map_[df_builder] = graph_info;
auto resource = std::make_shared<pipeline::Resource>();
resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id);
top_cell_info->forward_already_run = true;
top_cell_list_.emplace_back(top_cell_info);
MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get();
}
@@ -2742,7 +2761,11 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &

py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) {
const auto &cell_id = GetCellId(cell, args);
bool forward_run = CheckCellGraph(cell_id) && top_cell_id_ == cell_id;
auto top_cell = GetTopCell(cell_id);
bool forward_run = false;
if (top_cell != nullptr) {
forward_run = top_cell->forward_already_run;
}
MS_LOG(DEBUG) << "Graph have already run " << forward_run << " cell id " << cell_id << " top_cell_id_ "
<< top_cell_id_;
return BaseRefToPyData(forward_run);


+ 3
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -62,7 +62,7 @@ void ClearPyNativeSession();
struct GraphInfo {
std::string cell_id;
AnfNodePtr output;
OrderedMap<std::string, ParameterPtr> params; // hold input parameters and cell weigths
OrderedMap<std::string, ParameterPtr> params; // hold input parameters and cell weights
std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
std::vector<std::string> objects;
GraphInfo() = default;
@@ -98,6 +98,7 @@ class TopCellInfo {

bool is_topest{false};
bool do_vm_compiled{false};
bool forward_already_run{false};
ResourcePtr resource{nullptr};
FuncGraphPtr df_builder{nullptr};
FuncGraphPtr bg{nullptr}; // Backward graph
@@ -250,7 +251,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
void NewGraphInner(const py::object &cell, const py::args &args);
void MakeNewTopGraph(const string &cell_id, const py::args &args);
std::string GetTopCell(const string &cell_id);
TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false);
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
const std::string &out_id, const py::args &args);


+ 4
- 0
mindspore/nn/cell.py View File

@@ -321,10 +321,12 @@ class Cell(Cell_):
for item in inputs:
if isinstance(item, numpy.ndarray):
raise TypeError("cell inputs should not be numpy array.")
origin_grad = []
if self.requires_grad is True:
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(self, *inputs, **kwargs)
for cell in self.cells():
origin_grad.append(cell.requires_grad)
cell.set_grad(True)
else:
_pynative_exec.set_grad_flag(False)
@@ -348,6 +350,8 @@ class Cell(Cell_):
output = output.data
if self.requires_grad is True:
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
for i, cell in enumerate(self.cells()):
cell.set_grad(origin_grad[i])
return output

def _add_attr(self, name, value):


Loading…
Cancel
Save