From 82450afa9eae528a1d2230e78e7cfe253c4f19b0 Mon Sep 17 00:00:00 2001 From: tanghuikang Date: Mon, 7 Dec 2020 09:54:39 +0800 Subject: [PATCH] Optimize memory usage in pynative mode --- .../ccsrc/backend/session/ascend_session.cc | 1 + mindspore/ccsrc/backend/session/executor.cc | 14 +++++++ mindspore/ccsrc/backend/session/executor.h | 13 +++++- .../ccsrc/backend/session/session_basic.cc | 11 +++++ .../ccsrc/backend/session/session_basic.h | 3 ++ .../ccsrc/frontend/optimizer/ad/dfunctor.cc | 1 + .../optimizer/irpass/special_op_eliminate.h | 40 +++++++++++-------- .../pipeline/pynative/pynative_execute.cc | 13 ++++++ .../pipeline/pynative/pynative_execute.h | 1 + mindspore/core/ir/anf.h | 4 ++ 10 files changed, 84 insertions(+), 17 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index b680b6b782..1123333b6c 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -319,6 +319,7 @@ void HandleOpInputs(const std::set &input_kernel, std::mapToString() << " with " << input_value.first; auto input_value_node = NewValueNode(input_value.first); input_value_node->set_has_new_value(true); + input_value_node->set_used_graph_count(para_ref_size); manager->Replace(paras[i], input_value_node); } } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index 00fcef715b..0e67df279d 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -298,18 +298,6 @@ class PynativeEliminater : public OptimizerCaller { return out; } - void OnlySaveAbstractInfo(const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - auto &value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - auto tensor = value->cast(); - 
MS_EXCEPTION_IF_NULL(tensor); - auto new_tensor = std::make_shared(tensor->Dtype()->type_id(), tensor->shape()); - value_node->set_value(MakeValue(new_tensor)); - } - } - public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4); @@ -356,11 +344,31 @@ class PynativeEliminater : public OptimizerCaller { // {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout} PatternNode binop_grad_common; PatternNode getitem_vnode; - PatternNode arg1; - PatternNode arg2; - PatternNode arg3; - PatternNode arg4; + std::vector> args(4); + auto resolve_binop = PPrimitive(prim::kPrimResolve, symbol_str_vnode, binop_grad_common); + auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]); + if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && + CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) { + for (size_t i = 0; i < 2; i++) { + auto rep = (args[i]).GetNode(node); + if (rep != nullptr && rep->isa()) { + auto value_node = rep->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto &value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + // when the use count of the value node equals one, it is only used in the binop_grad_common function + if (value->isa() && value_node->used_graph_count() == 1) { + auto tensor = value->cast(); + MS_EXCEPTION_IF_NULL(tensor); + auto new_tensor = std::make_shared(tensor->Dtype()->type_id(), tensor->shape()); + value_node->set_value(new_tensor); + } + } + } + return nullptr; + } // resolve(CommonOPS, getitem)((tensors), 3) + PatternNode arg1; auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode); auto pattern2 = PCNode(resolve2, arg, arg1); if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") && diff --git 
a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 2543fc9878..c4650b2821 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1039,6 +1039,18 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { } } +void PynativeExecutor::CleanTensorsInValueNode() { + // Cleaning is only needed in ms backend policy, and session should not be nullptr in ms backend. + if (session == nullptr) { + return; + } + auto useless_tensors = std::make_shared>(); + for (const auto &id_tensor_pair : tensor_id_with_tensor_) { + std::copy(id_tensor_pair.second.begin(), id_tensor_pair.second.end(), std::back_inserter(*useless_tensors)); + } + session->CleanUselessTensors(useless_tensors); +} + AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { auto &out = graph_info_map_[curr_g_].node_map[obj_id]; if (out.second.size() == 1 && out.second[0] == -1) { @@ -2027,6 +2039,7 @@ py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) MS_LOG(DEBUG) << "Eval run " << backend; grad_is_running = true; BaseRef value = (*run)(arg_list); + CleanTensorsInValueNode(); grad_is_running = false; MS_LOG(DEBUG) << "Run end " << value.ToString(); return BaseRefToPyData(value); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index ad89c92c47..7d87ecb234 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -141,6 +141,7 @@ class PynativeExecutor : public std::enable_shared_from_this { // Update the abstract and device address info of value node and tensors in bprop graph void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); void SaveTensorsInValueNode(const ResourcePtr &resource); + void 
CleanTensorsInValueNode(); // construct grad graph void PushCurrentGraphToStack(); diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 71fab38187..df6cf924fa 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -381,6 +381,9 @@ class ValueNode : public ANode { void set_has_new_value(bool flag) { has_new_value_ = flag; } bool has_new_value() const { return has_new_value_; } + size_t used_graph_count() const { return used_graph_count_; } + void set_used_graph_count(size_t used_graph_count) { used_graph_count_ = used_graph_count; } + std::string ToString() const override; std::string DebugString(int recursive_level = 1) const override; std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } @@ -401,6 +404,7 @@ class ValueNode : public ANode { private: ValuePtr value_; bool has_new_value_ = false; + size_t used_graph_count_{0}; }; template