| @@ -304,7 +304,7 @@ void ExecutorPy::DelNetRes(const std::string &id) { | |||
| void ExecutorPy::ClearRes() { | |||
| MS_LOG(INFO) << "Clean executor resource!"; | |||
| Resource::ClearPrimitivePyPythonObj(); | |||
| Resource::mem_cleaner().ClearPrimitivePyPythonObj(); | |||
| executor_ = nullptr; | |||
| } | |||
| @@ -275,39 +275,78 @@ Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) { | |||
| return GetMethodOrAttr(name, type_id, attr_map); | |||
| } | |||
| std::unordered_map<PrimitivePy *, bool> Resource::py_objs_ = std::unordered_map<PrimitivePy *, bool>(); | |||
| void Resource::RecordPrimitivePy(PrimitivePy *prim) { | |||
| MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner(); | |||
| void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) { | |||
| if (prim == nullptr) { | |||
| return; | |||
| } | |||
| py_objs_[prim] = true; | |||
| all_primitives_[prim] = true; | |||
| } | |||
| void Resource::ErasePrimitivePy(PrimitivePy *prim) { | |||
| void MemoryCleaner::ErasePrimitivePy(PrimitivePy *prim) { | |||
| if (prim == nullptr) { | |||
| return; | |||
| } | |||
| auto it = py_objs_.find(prim); | |||
| if (it == py_objs_.end()) { | |||
| auto it = all_primitives_.find(prim); | |||
| if (it == all_primitives_.end()) { | |||
| return; | |||
| } | |||
| // If flag is false,the pointer hased been released, so it can't be visited. | |||
| if (!it->second) { | |||
| return; | |||
| } | |||
| py_objs_[prim] = false; | |||
| all_primitives_[prim] = false; | |||
| prim->SetPyObj(py::none()); | |||
| } | |||
| void Resource::ClearPrimitivePyPythonObj() { | |||
| for (auto &it : py_objs_) { | |||
| void MemoryCleaner::ClearPrimitivePyPythonObj() { | |||
| for (auto &it : all_primitives_) { | |||
| if (it.second) { | |||
| it.first->SetPyObj(py::none()); | |||
| } | |||
| } | |||
| py_objs_.clear(); | |||
| all_primitives_.clear(); | |||
| } | |||
| void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) { | |||
| if (prim == nullptr) { | |||
| return; | |||
| } | |||
| if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) { | |||
| return; | |||
| } | |||
| MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString(); | |||
| pynative_short_life_primitives_.insert(prim); | |||
| } | |||
| void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) { | |||
| if (prim == nullptr) { | |||
| return; | |||
| } | |||
| if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) { | |||
| return; | |||
| } | |||
| MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString(); | |||
| ErasePrimitivePy(prim); | |||
| } | |||
| void MemoryCleaner::ClearPynativeShortLifePrimitivePy() { | |||
| for (auto &primitive : pynative_short_life_primitives_) { | |||
| ErasePynativeShortLifePrimitivePy(primitive); | |||
| } | |||
| pynative_short_life_primitives_.clear(); | |||
| } | |||
| void MemoryCleaner::EnterPynativeConstructProcess() { pynative_in_construct_process_ = true; } | |||
| void MemoryCleaner::LeavePynativeConstructProcess() { | |||
| pynative_in_construct_process_ = false; | |||
| ClearPynativeShortLifePrimitivePy(); | |||
| } | |||
| bool MemoryCleaner::IsInPynativeConstructProcess() const { return pynative_in_construct_process_; } | |||
| void MemoryCleaner::EnterPynativeEndGraphProcess() { pynative_in_end_graph_process_ = true; } | |||
| void MemoryCleaner::LeavePynativeEndGraphProcess() { pynative_in_end_graph_process_ = false; } | |||
| bool MemoryCleaner::IsInPynativeEndGraphProcess() const { return pynative_in_end_graph_process_; } | |||
| void Resource::Clean() { | |||
| // AbstractTensor->elements() will be saved in AbstractBasePtrList | |||
| args_spec_.clear(); | |||
| @@ -22,6 +22,7 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <unordered_set> | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| @@ -52,6 +53,34 @@ BuiltInTypeMap &GetMethodMap(); | |||
| BuiltInTypeMap &GetAttrMap(); | |||
| class MemoryCleaner { | |||
| public: | |||
| MemoryCleaner() = default; | |||
| ~MemoryCleaner() = default; | |||
| void RecordPrimitivePy(PrimitivePy *prim); | |||
| void ErasePrimitivePy(PrimitivePy *prim); | |||
| void ClearPrimitivePyPythonObj(); | |||
| void RecordPynativeShortLifePrimitivePy(PrimitivePy *prim); | |||
| void ErasePynativeShortLifePrimitivePy(PrimitivePy *prim); | |||
| void ClearPynativeShortLifePrimitivePy(); | |||
| void EnterPynativeConstructProcess(); | |||
| void LeavePynativeConstructProcess(); | |||
| bool IsInPynativeConstructProcess() const; | |||
| void EnterPynativeEndGraphProcess(); | |||
| void LeavePynativeEndGraphProcess(); | |||
| bool IsInPynativeEndGraphProcess() const; | |||
| private: | |||
| std::unordered_map<PrimitivePy *, bool> all_primitives_; | |||
| // PrimitivePy objects that created in pynative construct process.These primitives should be released after construct | |||
| // finished. | |||
| std::unordered_set<PrimitivePy *> pynative_short_life_primitives_; | |||
| bool pynative_in_construct_process_{false}; | |||
| bool pynative_in_end_graph_process_{false}; | |||
| }; | |||
| class Resource : public ResourceBase { | |||
| public: | |||
| explicit Resource(const py::object &obj = py::none()); | |||
| @@ -80,13 +109,11 @@ class Resource : public ResourceBase { | |||
| } | |||
| bool gpu_loopsink_flag() { return gpu_loopsink_flag_; } | |||
| int64_t gpu_loopsink_size() { return gpu_loopsink_size_; } | |||
| static void RecordPrimitivePy(PrimitivePy *prim); | |||
| static void ErasePrimitivePy(PrimitivePy *prim); | |||
| static void ClearPrimitivePyPythonObj(); | |||
| // Reclaim resource and clear the cache. | |||
| // ExecutorPy::Compile() can be called multiple times, so cache | |||
| // should be cleared. | |||
| void Clean(); | |||
| static MemoryCleaner &mem_cleaner() { return mem_cleaner_; } | |||
| private: | |||
| abstract::AnalysisEnginePtr engine_; | |||
| @@ -96,7 +123,8 @@ class Resource : public ResourceBase { | |||
| bool is_cleaned_; | |||
| bool gpu_loopsink_flag_{false}; | |||
| int64_t gpu_loopsink_size_{1}; | |||
| static std::unordered_map<PrimitivePy *, bool> py_objs_; | |||
| // Used to handle mem leak objects. | |||
| static MemoryCleaner mem_cleaner_; | |||
| }; | |||
| using ResourcePtr = std::shared_ptr<pipeline::Resource>; | |||
| @@ -2476,7 +2476,12 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { | |||
| } | |||
| void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { | |||
| MS_LOG(DEBUG) << "Enter end graph process."; | |||
| auto &mem_cleaner = pipeline::Resource::mem_cleaner(); | |||
| mem_cleaner.EnterPynativeEndGraphProcess(); | |||
| PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); | |||
| mem_cleaner.LeavePynativeEndGraphProcess(); | |||
| MS_LOG(DEBUG) << "Leave end graph process."; | |||
| } | |||
| void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, | |||
| @@ -2491,6 +2496,24 @@ void PynativeExecutor::Sync() { | |||
| session->SyncStream(); | |||
| } | |||
| void PynativeExecutor::EnterConstruct(const py::object &cell) { | |||
| if (top_cell_ != nullptr) { | |||
| return; | |||
| } | |||
| top_cell_ = cell.ptr(); | |||
| pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess(); | |||
| MS_LOG(DEBUG) << "Enter construct process."; | |||
| } | |||
| void PynativeExecutor::LeaveConstruct(const py::object &cell) { | |||
| if (top_cell_ != cell.ptr()) { | |||
| return; | |||
| } | |||
| top_cell_ = nullptr; | |||
| pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess(); | |||
| MS_LOG(DEBUG) << "Leave construct process."; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { | |||
| (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_") | |||
| .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") | |||
| @@ -2502,6 +2525,10 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { | |||
| .def("sync", &PynativeExecutor::Sync, "pynative sync stream.") | |||
| .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.") | |||
| .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), | |||
| "Executor set grad flag."); | |||
| "Executor set grad flag.") | |||
| .def("enter_construct", &PynativeExecutor::EnterConstruct, | |||
| "Do something before enter construct function.") | |||
| .def("leave_construct", &PynativeExecutor::LeaveConstruct, | |||
| "Do something after leave construct function."); | |||
| })); | |||
| } // namespace mindspore::pynative | |||
| @@ -108,6 +108,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||
| bool need_replace_forward() const { return need_replace_forward_; } | |||
| bool grad_flag() const { return grad_flag_; } | |||
| void set_grad_flag(bool flag) { grad_flag_ = flag; } | |||
| void EnterConstruct(const py::object &cell); | |||
| void LeaveConstruct(const py::object &cell); | |||
| py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); | |||
| OpExecInfoPtr GenerateOpExecInfo(const py::args &args); | |||
| @@ -263,6 +265,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||
| bool dynamic_cell_{false}; | |||
| bool grad_is_running_{false}; | |||
| bool need_replace_forward_{true}; | |||
| // Pointer to the top Python Cell object, which is always the network (a class inheriting from Cell) run in the | |||
| // Python test script, such as Resnet50(Cell) or LeNet(Cell). This pointer is used to distinguish temporary | |||
| // primitives from global primitives to control memory release. Global primitives are always created in the top | |||
| // cell's '__init__' function and temporary primitives are always created elsewhere. Temporary primitives will be | |||
| // released after executing the top cell's 'construct' function, but global primitives will not. | |||
| PyObject *top_cell_{nullptr}; | |||
| // Used for construct grad graph | |||
| FuncGraphPtr curr_g_{nullptr}; | |||
| @@ -54,9 +54,18 @@ std::map<std::string, py::object> PrimitivePy::hook_grad_; | |||
| PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj) | |||
| : Primitive(name, false), python_obj_(python_obj), signatures_() { | |||
| pipeline::Resource::RecordPrimitivePy(this); | |||
| auto &mem_cleaner = pipeline::Resource::mem_cleaner(); | |||
| mem_cleaner.RecordPrimitivePy(this); | |||
| if (mem_cleaner.IsInPynativeConstructProcess() && !mem_cleaner.IsInPynativeEndGraphProcess()) { | |||
| mem_cleaner.RecordPynativeShortLifePrimitivePy(this); | |||
| } | |||
| } | |||
| PrimitivePy::~PrimitivePy() { | |||
| // Erase primitive here to set released flag false, to avoid calling released pointer when clear primitives in | |||
| // resource. | |||
| pipeline::Resource::mem_cleaner().ErasePrimitivePy(this); | |||
| MS_LOG(DEBUG) << "Release:" << ToString(); | |||
| } | |||
| PrimitivePy::~PrimitivePy() { pipeline::Resource::ErasePrimitivePy(this); } | |||
| void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; } | |||
| void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) { | |||
| signatures_ = signatures; | |||
| @@ -321,6 +321,12 @@ class _PynativeExecutor: | |||
def set_grad_flag(self, flag):
    """Forward `flag` to the wrapped executor's `set_grad_flag`."""
    backend = self._executor
    backend.set_grad_flag(flag)
def enter_construct(self, cell):
    """Forward `cell` to the wrapped executor's `enter_construct`."""
    backend = self._executor
    backend.enter_construct(cell)
def leave_construct(self, cell):
    """Forward `cell` to the wrapped executor's `leave_construct`."""
    backend = self._executor
    backend.leave_construct(cell)
| def __call__(self, obj, *args, **kwargs): | |||
| args = args + tuple(kwargs.values()) | |||
| return self._executor(obj, args, "") | |||
| @@ -352,9 +352,13 @@ class Cell(Cell_): | |||
| if not cast_inputs: | |||
| cast_inputs = inputs | |||
| if self.enable_hook: | |||
| _pynative_exec.enter_construct(self) | |||
| output = self._hook_construct(*cast_inputs, **kwargs) | |||
| _pynative_exec.leave_construct(self) | |||
| else: | |||
| _pynative_exec.enter_construct(self) | |||
| output = self.construct(*cast_inputs, **kwargs) | |||
| _pynative_exec.leave_construct(self) | |||
| if isinstance(output, Parameter): | |||
| output = output.data | |||
| if self.requires_grad is True: | |||
| @@ -19,7 +19,6 @@ import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops.composite import GradOperation | |||
| @@ -29,7 +28,6 @@ class NetSigmoidGrad(nn.Cell): | |||
| super(NetSigmoidGrad, self).__init__() | |||
| self.sigmoid_grad = G.SigmoidGrad() | |||
| @ms_function | |||
| def construct(self, y, dy): | |||
| return self.sigmoid_grad(y, dy) | |||
| @@ -40,7 +38,6 @@ class Grad(nn.Cell): | |||
| self.grad = GradOperation(get_all=True, sens_param=True) | |||
| self.network = network | |||
| @ms_function | |||
| def construct(self, y, y_grad, dout): | |||
| return self.grad(self.network)(y, y_grad, dout) | |||