| @@ -304,7 +304,7 @@ void ExecutorPy::DelNetRes(const std::string &id) { | |||||
| void ExecutorPy::ClearRes() { | void ExecutorPy::ClearRes() { | ||||
| MS_LOG(INFO) << "Clean executor resource!"; | MS_LOG(INFO) << "Clean executor resource!"; | ||||
| Resource::ClearPrimitivePyPythonObj(); | |||||
| Resource::mem_cleaner().ClearPrimitivePyPythonObj(); | |||||
| executor_ = nullptr; | executor_ = nullptr; | ||||
| } | } | ||||
| @@ -275,39 +275,78 @@ Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) { | |||||
| return GetMethodOrAttr(name, type_id, attr_map); | return GetMethodOrAttr(name, type_id, attr_map); | ||||
| } | } | ||||
| std::unordered_map<PrimitivePy *, bool> Resource::py_objs_ = std::unordered_map<PrimitivePy *, bool>(); | |||||
| void Resource::RecordPrimitivePy(PrimitivePy *prim) { | |||||
| MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner(); | |||||
| void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) { | |||||
| if (prim == nullptr) { | if (prim == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| py_objs_[prim] = true; | |||||
| all_primitives_[prim] = true; | |||||
| } | } | ||||
| void Resource::ErasePrimitivePy(PrimitivePy *prim) { | |||||
| void MemoryCleaner::ErasePrimitivePy(PrimitivePy *prim) { | |||||
| if (prim == nullptr) { | if (prim == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto it = py_objs_.find(prim); | |||||
| if (it == py_objs_.end()) { | |||||
| auto it = all_primitives_.find(prim); | |||||
| if (it == all_primitives_.end()) { | |||||
| return; | return; | ||||
| } | } | ||||
| // If flag is false,the pointer hased been released, so it can't be visited. | // If flag is false,the pointer hased been released, so it can't be visited. | ||||
| if (!it->second) { | if (!it->second) { | ||||
| return; | return; | ||||
| } | } | ||||
| py_objs_[prim] = false; | |||||
| all_primitives_[prim] = false; | |||||
| prim->SetPyObj(py::none()); | prim->SetPyObj(py::none()); | ||||
| } | } | ||||
| void Resource::ClearPrimitivePyPythonObj() { | |||||
| for (auto &it : py_objs_) { | |||||
| void MemoryCleaner::ClearPrimitivePyPythonObj() { | |||||
| for (auto &it : all_primitives_) { | |||||
| if (it.second) { | if (it.second) { | ||||
| it.first->SetPyObj(py::none()); | it.first->SetPyObj(py::none()); | ||||
| } | } | ||||
| } | } | ||||
| py_objs_.clear(); | |||||
| all_primitives_.clear(); | |||||
| } | } | ||||
| void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) { | |||||
| if (prim == nullptr) { | |||||
| return; | |||||
| } | |||||
| if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) { | |||||
| return; | |||||
| } | |||||
| MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString(); | |||||
| pynative_short_life_primitives_.insert(prim); | |||||
| } | |||||
| void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) { | |||||
| if (prim == nullptr) { | |||||
| return; | |||||
| } | |||||
| if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) { | |||||
| return; | |||||
| } | |||||
| MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString(); | |||||
| ErasePrimitivePy(prim); | |||||
| } | |||||
| void MemoryCleaner::ClearPynativeShortLifePrimitivePy() { | |||||
| for (auto &primitive : pynative_short_life_primitives_) { | |||||
| ErasePynativeShortLifePrimitivePy(primitive); | |||||
| } | |||||
| pynative_short_life_primitives_.clear(); | |||||
| } | |||||
| void MemoryCleaner::EnterPynativeConstructProcess() { pynative_in_construct_process_ = true; } | |||||
| void MemoryCleaner::LeavePynativeConstructProcess() { | |||||
| pynative_in_construct_process_ = false; | |||||
| ClearPynativeShortLifePrimitivePy(); | |||||
| } | |||||
| bool MemoryCleaner::IsInPynativeConstructProcess() const { return pynative_in_construct_process_; } | |||||
| void MemoryCleaner::EnterPynativeEndGraphProcess() { pynative_in_end_graph_process_ = true; } | |||||
| void MemoryCleaner::LeavePynativeEndGraphProcess() { pynative_in_end_graph_process_ = false; } | |||||
| bool MemoryCleaner::IsInPynativeEndGraphProcess() const { return pynative_in_end_graph_process_; } | |||||
| void Resource::Clean() { | void Resource::Clean() { | ||||
| // AbstractTensor->elements() will be saved in AbstractBasePtrList | // AbstractTensor->elements() will be saved in AbstractBasePtrList | ||||
| args_spec_.clear(); | args_spec_.clear(); | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_set> | |||||
| #include "pybind11/pybind11.h" | #include "pybind11/pybind11.h" | ||||
| #include "pybind11/stl.h" | #include "pybind11/stl.h" | ||||
| @@ -52,6 +53,34 @@ BuiltInTypeMap &GetMethodMap(); | |||||
| BuiltInTypeMap &GetAttrMap(); | BuiltInTypeMap &GetAttrMap(); | ||||
| class MemoryCleaner { | |||||
| public: | |||||
| MemoryCleaner() = default; | |||||
| ~MemoryCleaner() = default; | |||||
| void RecordPrimitivePy(PrimitivePy *prim); | |||||
| void ErasePrimitivePy(PrimitivePy *prim); | |||||
| void ClearPrimitivePyPythonObj(); | |||||
| void RecordPynativeShortLifePrimitivePy(PrimitivePy *prim); | |||||
| void ErasePynativeShortLifePrimitivePy(PrimitivePy *prim); | |||||
| void ClearPynativeShortLifePrimitivePy(); | |||||
| void EnterPynativeConstructProcess(); | |||||
| void LeavePynativeConstructProcess(); | |||||
| bool IsInPynativeConstructProcess() const; | |||||
| void EnterPynativeEndGraphProcess(); | |||||
| void LeavePynativeEndGraphProcess(); | |||||
| bool IsInPynativeEndGraphProcess() const; | |||||
| private: | |||||
| std::unordered_map<PrimitivePy *, bool> all_primitives_; | |||||
| // PrimitivePy objects that created in pynative construct process.These primitives should be released after construct | |||||
| // finished. | |||||
| std::unordered_set<PrimitivePy *> pynative_short_life_primitives_; | |||||
| bool pynative_in_construct_process_{false}; | |||||
| bool pynative_in_end_graph_process_{false}; | |||||
| }; | |||||
| class Resource : public ResourceBase { | class Resource : public ResourceBase { | ||||
| public: | public: | ||||
| explicit Resource(const py::object &obj = py::none()); | explicit Resource(const py::object &obj = py::none()); | ||||
| @@ -80,13 +109,11 @@ class Resource : public ResourceBase { | |||||
| } | } | ||||
| bool gpu_loopsink_flag() { return gpu_loopsink_flag_; } | bool gpu_loopsink_flag() { return gpu_loopsink_flag_; } | ||||
| int64_t gpu_loopsink_size() { return gpu_loopsink_size_; } | int64_t gpu_loopsink_size() { return gpu_loopsink_size_; } | ||||
| static void RecordPrimitivePy(PrimitivePy *prim); | |||||
| static void ErasePrimitivePy(PrimitivePy *prim); | |||||
| static void ClearPrimitivePyPythonObj(); | |||||
| // Reclaim resource and clear the cache. | // Reclaim resource and clear the cache. | ||||
| // ExecutorPy::Compile() can be called multiple times, so cache | // ExecutorPy::Compile() can be called multiple times, so cache | ||||
| // should be cleared. | // should be cleared. | ||||
| void Clean(); | void Clean(); | ||||
| static MemoryCleaner &mem_cleaner() { return mem_cleaner_; } | |||||
| private: | private: | ||||
| abstract::AnalysisEnginePtr engine_; | abstract::AnalysisEnginePtr engine_; | ||||
| @@ -96,7 +123,8 @@ class Resource : public ResourceBase { | |||||
| bool is_cleaned_; | bool is_cleaned_; | ||||
| bool gpu_loopsink_flag_{false}; | bool gpu_loopsink_flag_{false}; | ||||
| int64_t gpu_loopsink_size_{1}; | int64_t gpu_loopsink_size_{1}; | ||||
| static std::unordered_map<PrimitivePy *, bool> py_objs_; | |||||
| // Used to handle mem leak objects. | |||||
| static MemoryCleaner mem_cleaner_; | |||||
| }; | }; | ||||
| using ResourcePtr = std::shared_ptr<pipeline::Resource>; | using ResourcePtr = std::shared_ptr<pipeline::Resource>; | ||||
| @@ -2476,7 +2476,12 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { | |||||
| } | } | ||||
| void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { | void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { | ||||
| MS_LOG(DEBUG) << "Enter end graph process."; | |||||
| auto &mem_cleaner = pipeline::Resource::mem_cleaner(); | |||||
| mem_cleaner.EnterPynativeEndGraphProcess(); | |||||
| PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); | PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); | ||||
| mem_cleaner.LeavePynativeEndGraphProcess(); | |||||
| MS_LOG(DEBUG) << "Leave end graph process."; | |||||
| } | } | ||||
| void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, | void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, | ||||
| @@ -2491,6 +2496,24 @@ void PynativeExecutor::Sync() { | |||||
| session->SyncStream(); | session->SyncStream(); | ||||
| } | } | ||||
| void PynativeExecutor::EnterConstruct(const py::object &cell) { | |||||
| if (top_cell_ != nullptr) { | |||||
| return; | |||||
| } | |||||
| top_cell_ = cell.ptr(); | |||||
| pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess(); | |||||
| MS_LOG(DEBUG) << "Enter construct process."; | |||||
| } | |||||
| void PynativeExecutor::LeaveConstruct(const py::object &cell) { | |||||
| if (top_cell_ != cell.ptr()) { | |||||
| return; | |||||
| } | |||||
| top_cell_ = nullptr; | |||||
| pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess(); | |||||
| MS_LOG(DEBUG) << "Leave construct process."; | |||||
| } | |||||
| REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { | REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { | ||||
| (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_") | (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_") | ||||
| .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") | .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") | ||||
| @@ -2502,6 +2525,10 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { | |||||
| .def("sync", &PynativeExecutor::Sync, "pynative sync stream.") | .def("sync", &PynativeExecutor::Sync, "pynative sync stream.") | ||||
| .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.") | .def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.") | ||||
| .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), | .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), | ||||
| "Executor set grad flag."); | |||||
| "Executor set grad flag.") | |||||
| .def("enter_construct", &PynativeExecutor::EnterConstruct, | |||||
| "Do something before enter construct function.") | |||||
| .def("leave_construct", &PynativeExecutor::LeaveConstruct, | |||||
| "Do something after leave construct function."); | |||||
| })); | })); | ||||
| } // namespace mindspore::pynative | } // namespace mindspore::pynative | ||||
| @@ -108,6 +108,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||||
| bool need_replace_forward() const { return need_replace_forward_; } | bool need_replace_forward() const { return need_replace_forward_; } | ||||
| bool grad_flag() const { return grad_flag_; } | bool grad_flag() const { return grad_flag_; } | ||||
| void set_grad_flag(bool flag) { grad_flag_ = flag; } | void set_grad_flag(bool flag) { grad_flag_ = flag; } | ||||
| void EnterConstruct(const py::object &cell); | |||||
| void LeaveConstruct(const py::object &cell); | |||||
| py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); | py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); | ||||
| OpExecInfoPtr GenerateOpExecInfo(const py::args &args); | OpExecInfoPtr GenerateOpExecInfo(const py::args &args); | ||||
| @@ -263,6 +265,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||||
| bool dynamic_cell_{false}; | bool dynamic_cell_{false}; | ||||
| bool grad_is_running_{false}; | bool grad_is_running_{false}; | ||||
| bool need_replace_forward_{true}; | bool need_replace_forward_{true}; | ||||
| // The pointer of top python Cell object, which is always the network(inherit class Cell) ran in python test script, | |||||
| // such as Resnet50(Cell),LeNet(Cell).This pointer is used to distinguish temporary primitives from global | |||||
| // primitives to control memory release. Global primitives are always created in top cell's '__init__' function and | |||||
| // temporary primitives are always created in other place.Temporary primitives will be released after executing top | |||||
| // cell's 'construct' function but global primitives will not. | |||||
| PyObject *top_cell_{nullptr}; | |||||
| // Used for construct grad graph | // Used for construct grad graph | ||||
| FuncGraphPtr curr_g_{nullptr}; | FuncGraphPtr curr_g_{nullptr}; | ||||
| @@ -54,9 +54,18 @@ std::map<std::string, py::object> PrimitivePy::hook_grad_; | |||||
| PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj) | PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj) | ||||
| : Primitive(name, false), python_obj_(python_obj), signatures_() { | : Primitive(name, false), python_obj_(python_obj), signatures_() { | ||||
| pipeline::Resource::RecordPrimitivePy(this); | |||||
| auto &mem_cleaner = pipeline::Resource::mem_cleaner(); | |||||
| mem_cleaner.RecordPrimitivePy(this); | |||||
| if (mem_cleaner.IsInPynativeConstructProcess() && !mem_cleaner.IsInPynativeEndGraphProcess()) { | |||||
| mem_cleaner.RecordPynativeShortLifePrimitivePy(this); | |||||
| } | |||||
| } | |||||
| PrimitivePy::~PrimitivePy() { | |||||
| // Erase primitive here to set released flag false, to avoid calling released pointer when clear primitives in | |||||
| // resource. | |||||
| pipeline::Resource::mem_cleaner().ErasePrimitivePy(this); | |||||
| MS_LOG(DEBUG) << "Release:" << ToString(); | |||||
| } | } | ||||
| PrimitivePy::~PrimitivePy() { pipeline::Resource::ErasePrimitivePy(this); } | |||||
| void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; } | void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; } | ||||
| void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) { | void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) { | ||||
| signatures_ = signatures; | signatures_ = signatures; | ||||
| @@ -321,6 +321,12 @@ class _PynativeExecutor: | |||||
| def set_grad_flag(self, flag): | def set_grad_flag(self, flag): | ||||
| self._executor.set_grad_flag(flag) | self._executor.set_grad_flag(flag) | ||||
| def enter_construct(self, cell): | |||||
| self._executor.enter_construct(cell) | |||||
| def leave_construct(self, cell): | |||||
| self._executor.leave_construct(cell) | |||||
| def __call__(self, obj, *args, **kwargs): | def __call__(self, obj, *args, **kwargs): | ||||
| args = args + tuple(kwargs.values()) | args = args + tuple(kwargs.values()) | ||||
| return self._executor(obj, args, "") | return self._executor(obj, args, "") | ||||
| @@ -352,9 +352,13 @@ class Cell(Cell_): | |||||
| if not cast_inputs: | if not cast_inputs: | ||||
| cast_inputs = inputs | cast_inputs = inputs | ||||
| if self.enable_hook: | if self.enable_hook: | ||||
| _pynative_exec.enter_construct(self) | |||||
| output = self._hook_construct(*cast_inputs, **kwargs) | output = self._hook_construct(*cast_inputs, **kwargs) | ||||
| _pynative_exec.leave_construct(self) | |||||
| else: | else: | ||||
| _pynative_exec.enter_construct(self) | |||||
| output = self.construct(*cast_inputs, **kwargs) | output = self.construct(*cast_inputs, **kwargs) | ||||
| _pynative_exec.leave_construct(self) | |||||
| if isinstance(output, Parameter): | if isinstance(output, Parameter): | ||||
| output = output.data | output = output.data | ||||
| if self.requires_grad is True: | if self.requires_grad is True: | ||||
| @@ -19,7 +19,6 @@ import pytest | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.common.api import ms_function | |||||
| from mindspore.ops.operations import _grad_ops as G | from mindspore.ops.operations import _grad_ops as G | ||||
| from mindspore.ops.composite import GradOperation | from mindspore.ops.composite import GradOperation | ||||
| @@ -29,7 +28,6 @@ class NetSigmoidGrad(nn.Cell): | |||||
| super(NetSigmoidGrad, self).__init__() | super(NetSigmoidGrad, self).__init__() | ||||
| self.sigmoid_grad = G.SigmoidGrad() | self.sigmoid_grad = G.SigmoidGrad() | ||||
| @ms_function | |||||
| def construct(self, y, dy): | def construct(self, y, dy): | ||||
| return self.sigmoid_grad(y, dy) | return self.sigmoid_grad(y, dy) | ||||
| @@ -40,7 +38,6 @@ class Grad(nn.Cell): | |||||
| self.grad = GradOperation(get_all=True, sens_param=True) | self.grad = GradOperation(get_all=True, sens_param=True) | ||||
| self.network = network | self.network = network | ||||
| @ms_function | |||||
| def construct(self, y, y_grad, dout): | def construct(self, y, y_grad, dout): | ||||
| return self.grad(self.network)(y, y_grad, dout) | return self.grad(self.network)(y, y_grad, dout) | ||||