From: @zhangzhaoju Reviewed-by: @zh_qh,@hwhewei,@zh_qh Signed-off-by: @zh_qh,@zh_qhpull/15216/MERGE
| @@ -1997,20 +1997,25 @@ class IrParser { | |||
| // restore python function of PrimitivePy from serialized file | |||
| py::object py_obj = LoadObject(lexer_.GetTokenText()); | |||
| PrimitivePyPtr ptr = nullptr; | |||
| py::object py_adapter = py_obj; | |||
| static auto len = strlen("PrimitivePy::"); | |||
| bool cloned = false; | |||
| if (py::hasattr(py_obj, "__setattr_flag__") && py::hasattr(py_obj, "_clone")) { | |||
| auto clone_fn = py_obj.attr("_clone"); | |||
| py::object new_obj = clone_fn(); | |||
| ptr = new_obj.cast<PrimitivePyPtr>(); | |||
| if (ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cast to type 'PrimitivePyPtr' error"; | |||
| } | |||
| } else { | |||
| auto len = strlen("PrimitivePy::"); | |||
| if (id.size() < len) { | |||
| return TOK_ERROR; | |||
| } | |||
| ptr = std::make_shared<PrimitivePy>(id.substr(len), py_obj); | |||
| py_adapter = clone_fn(); | |||
| cloned = true; | |||
| } else if (id.size() < len) { | |||
| return TOK_ERROR; | |||
| } | |||
| auto prim_adapter = py_adapter.cast<PrimitivePyAdapterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim_adapter); | |||
| if (!cloned) { | |||
| prim_adapter->set_name(id.substr(len)); | |||
| } | |||
| PrimitivePyPtr ptr = prim_adapter->attached_primitive(); | |||
| if (ptr == nullptr) { | |||
| ptr = std::make_shared<PrimitivePy>(py_adapter, prim_adapter); | |||
| prim_adapter->set_attached_primitive(ptr); | |||
| } | |||
| *val_ptr = ptr; | |||
| @@ -366,8 +366,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res | |||
| auto func_graph = std::make_shared<FuncGraph>(); | |||
| std::vector<AnfNodePtr> outputs; | |||
| auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object()); | |||
| auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut"); | |||
| bprop_cut->CopyHookFunction(prim); | |||
| auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id")); | |||
| @@ -157,12 +157,10 @@ REGISTER_PYBIND_DEFINE( | |||
| (void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>()); | |||
| (void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>()); | |||
| (void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr()) | |||
| .def(py::init<vector<PrimitivePyPtr>, string>()) | |||
| .def(py::init<vector<string>, string>()); | |||
| .def(py::init<vector<py::object>, string>()); | |||
| (void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_") | |||
| .def(py::init<PatternPtr, vector<PatternPtr>>()) | |||
| .def(py::init<PrimitivePyPtr, vector<PatternPtr>>()) | |||
| .def(py::init<string, vector<PatternPtr>>()); | |||
| .def(py::init<py::object, vector<PatternPtr>>()); | |||
| (void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>()); | |||
| (void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>()); | |||
| (void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_") | |||
| @@ -87,16 +87,18 @@ class Prim : public Pattern { | |||
| public: | |||
| Prim() { unique_name_ = std::to_string(g_id_++); } | |||
| ~Prim() = default; | |||
| Prim(vector<PrimitivePyPtr> prims, string name) : primitives_(prims), name_(name) { | |||
| Prim(vector<py::object> prim_objs, string name) : name_(name) { | |||
| unique_name_ = std::to_string(g_id_++) + "Prim_" + name; | |||
| // Default using the first prim to build target | |||
| matched_prim_ = primitives_[0]; | |||
| } | |||
| Prim(vector<string> types, string name) : types_(types), name_(name) { | |||
| unique_name_ = std::to_string(g_id_++) + "Prim_" + name; | |||
| // Make primitives_ | |||
| for (auto &iter : types) { | |||
| primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr))); | |||
| for (auto &prim_obj : prim_objs) { | |||
| if (py::isinstance<PrimitivePyAdapter>(prim_obj)) { | |||
| auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>(); | |||
| primitives_.push_back(std::make_shared<PrimitivePy>(prim_obj, prim_adapter)); | |||
| } else if (py::isinstance<py::str>(prim_obj)) { | |||
| std::string prim_name = prim_obj.cast<py::str>(); | |||
| primitives_.push_back(std::make_shared<PrimitivePy>(prim_name)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Parameter of Prim::__init__ must be Primitive_ type or Prim name, please check input."; | |||
| } | |||
| } | |||
| // Default using the first prim to build target | |||
| matched_prim_ = primitives_[0]; | |||
| @@ -111,7 +113,6 @@ class Prim : public Pattern { | |||
| } | |||
| private: | |||
| vector<string> types_; | |||
| vector<PrimitivePyPtr> primitives_; | |||
| string name_; | |||
| PrimitivePyPtr matched_prim_{nullptr}; | |||
| @@ -127,16 +128,19 @@ class Call : public Pattern { | |||
| unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name(); | |||
| inputs_ = inputs; | |||
| } | |||
| Call(PrimitivePyPtr prim, vector<PatternPtr> inputs) { | |||
| prim_ = prim; | |||
| Call(py::object prim_obj, vector<PatternPtr> inputs) { | |||
| if (py::isinstance<PrimitivePyAdapter>(prim_obj)) { | |||
| auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>(); | |||
| prim_ = std::make_shared<PrimitivePy>(prim_obj, prim_adapter); | |||
| } else if (py::isinstance<py::str>(prim_obj)) { | |||
| std::string prim_name = prim_obj.cast<py::str>(); | |||
| prim_ = std::make_shared<PrimitivePy>(prim_name); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Parameter of Call::__init__ must be Primitive_ type or Prim name, please check input."; | |||
| } | |||
| unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString(); | |||
| inputs_ = inputs; | |||
| } | |||
| Call(string prim_str, vector<PatternPtr> inputs) { | |||
| prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr)); | |||
| unique_name_ = std::to_string(g_id_++) + "CallStr_" + prim_->ToString(); | |||
| inputs_ = inputs; | |||
| } | |||
| MS_DECLARE_PARENT(Call, Pattern); | |||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||
| PrimitivePtr prim_value() { return prim_; } | |||
| @@ -119,7 +119,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) { | |||
| auto bprop_graph = std::make_shared<FuncGraph>(); | |||
| std::vector<AnfNodePtr> outputs; | |||
| auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object()); | |||
| auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut"); | |||
| fake_bprop->set_hook(bprop_func); | |||
| (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true)); | |||
| outputs.push_back(NewValueNode(fake_bprop)); | |||
| @@ -236,16 +236,21 @@ ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) { | |||
| // desc has format "<class xxxx>", strip the '<' and '>' by offset 1; | |||
| return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1)); | |||
| } | |||
| auto primitive = obj.cast<PrimitivePyPtr>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null"; | |||
| return nullptr; | |||
| py::object adapter_obj = obj; | |||
| if (py::hasattr(obj, "__setattr_flag__")) { | |||
| if (py::hasattr(obj, "_clone")) { | |||
| auto clone_fn = obj.attr("_clone"); | |||
| adapter_obj = clone_fn(); | |||
| } | |||
| } | |||
| if (py::hasattr(obj, "__setattr_flag__") && py::hasattr(obj, "_clone")) { | |||
| auto clone_fn = obj.attr("_clone"); | |||
| py::object new_obj = clone_fn(); | |||
| primitive = new_obj.cast<PrimitivePyPtr>(); | |||
| auto prim_adapter = adapter_obj.cast<PrimitivePyAdapterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim_adapter); | |||
| auto primitive = prim_adapter->attached_primitive(); | |||
| if (primitive == nullptr) { | |||
| primitive = std::make_shared<PrimitivePy>(adapter_obj, prim_adapter); | |||
| prim_adapter->set_attached_primitive(primitive); | |||
| } | |||
| if (use_signature) { | |||
| return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive); | |||
| } | |||
| @@ -371,7 +371,6 @@ void ExecutorPy::DelNetRes(const std::string &id) { | |||
| void ExecutorPy::ClearRes() { | |||
| MS_LOG(INFO) << "Clean executor resource!"; | |||
| Resource::mem_cleaner().ClearPrimitivePyPythonObj(); | |||
| #ifdef ENABLE_DUMP_IR | |||
| mindspore::RDR::ClearAll(); | |||
| #endif | |||
| @@ -384,7 +383,7 @@ ExecutorPy::~ExecutorPy() { | |||
| } | |||
| void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node, | |||
| std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table) { | |||
| std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table) { | |||
| std::string weight_name; | |||
| auto x = root_node->input(1); | |||
| if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) { | |||
| @@ -437,15 +436,15 @@ void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weig | |||
| return; | |||
| } | |||
| auto quant_op = quant_op_value->cast<PrimitivePyPtr>(); | |||
| (*fake_quant_table)[weight_name] = std::make_pair(quant_op, fakequant_min_node_name); | |||
| (*fake_quant_table)[weight_name] = std::make_pair(quant_op->adapter(), fakequant_min_node_name); | |||
| } | |||
| std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport( | |||
| std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> ExecutorPy::FetchInfoForQuantExport( | |||
| const std::string &phase_s) { | |||
| FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; | |||
| std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table; | |||
| std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> fake_quant_table; | |||
| auto filter = [](const AnfNodePtr &node) { | |||
| return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) || | |||
| IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative)); | |||
| @@ -472,7 +471,6 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI | |||
| } | |||
| GetWeightInfo(root_node, weight_node, &fake_quant_table); | |||
| } | |||
| return fake_quant_table; | |||
| } | |||
| @@ -1162,9 +1160,6 @@ void StartUpProfiling() { | |||
| } | |||
| void InitPipeline() { | |||
| // If previous pipeline exit with exception, memory cleaner's flags maybe unpredictable, so init when a new pipeline | |||
| // start. | |||
| pipeline::Resource::mem_cleaner().Init(); | |||
| // set python env flag | |||
| mindspore::parse::python_adapter::set_python_env_flag(true); | |||
| // Startup profiling before open tsd | |||
| @@ -105,13 +105,14 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> { | |||
| static void DebugTerminate(bool val) { debugger_terminate_ = val; } | |||
| void TerminateDebugger(); | |||
| std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(const std::string &phase_s); | |||
| std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> FetchInfoForQuantExport( | |||
| const std::string &phase_s); | |||
| private: | |||
| ExecutorPy(); | |||
| void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors); | |||
| void GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node, | |||
| std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table); | |||
| std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table); | |||
| void GetGeBackendPolicy() const; | |||
| // filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after | |||
| // 'validate' stage | |||
| @@ -308,91 +308,5 @@ void Resource::Clean() { | |||
| is_cleaned_ = true; | |||
| } | |||
| void MemoryCleaner::Init() { | |||
| pynative_in_construct_process_ = false; | |||
| pynative_in_end_graph_process_ = false; | |||
| pynative_released_history_.clear(); | |||
| pynative_new_primtives_squence_.clear(); | |||
| } | |||
| MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner(); | |||
| void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) { | |||
| if (prim == nullptr) { | |||
| return; | |||
| } | |||
| all_primitives_[prim] = true; | |||
| } | |||
| void MemoryCleaner::ReleasePrimitivePyObj(PrimitivePy *prim) { | |||
| if (prim == nullptr) { | |||
| return; | |||
| } | |||
| auto it = all_primitives_.find(prim); | |||
| if (it == all_primitives_.end()) { | |||
| return; | |||
| } | |||
| // If flag is false,the pointer hased been released, so it can't be visited. | |||
| if (!it->second) { | |||
| return; | |||
| } | |||
| all_primitives_[prim] = false; | |||
| prim->SetPyObj(py::none()); | |||
| } | |||
| void MemoryCleaner::ClearPrimitivePyPythonObj() { | |||
| for (auto &it : all_primitives_) { | |||
| if (it.second) { | |||
| it.first->SetPyObj(py::none()); | |||
| } | |||
| } | |||
| all_primitives_.clear(); | |||
| } | |||
| void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) { | |||
| if (prim == nullptr) { | |||
| return; | |||
| } | |||
| if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) { | |||
| return; | |||
| } | |||
| MS_LOG(DEBUG) << "Record pynative tmp primitive:" << prim->ToString(); | |||
| pynative_short_life_primitives_.insert(prim); | |||
| pynative_new_primtives_squence_.push_back(prim->ToString()); | |||
| } | |||
| void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) { | |||
| if (prim == nullptr) { | |||
| return; | |||
| } | |||
| if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) { | |||
| return; | |||
| } | |||
| pynative_short_life_primitives_.erase(prim); | |||
| MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString(); | |||
| } | |||
| void MemoryCleaner::ClearPynativeShortLifePrimitivePy() { | |||
| // If the primitives name sequence never been released before, keep the primtives alive | |||
| if (std::find(pynative_released_history_.begin(), pynative_released_history_.end(), | |||
| pynative_new_primtives_squence_) == pynative_released_history_.end()) { | |||
| pynative_released_history_.push_back(pynative_new_primtives_squence_); | |||
| } else { | |||
| for (auto &primitive : pynative_short_life_primitives_) { | |||
| ReleasePrimitivePyObj(primitive); | |||
| } | |||
| } | |||
| pynative_short_life_primitives_.clear(); | |||
| pynative_new_primtives_squence_.clear(); | |||
| } | |||
| void MemoryCleaner::EnterPynativeConstructProcess() { pynative_in_construct_process_ = true; } | |||
| void MemoryCleaner::LeavePynativeConstructProcess() { | |||
| pynative_in_construct_process_ = false; | |||
| ClearPynativeShortLifePrimitivePy(); | |||
| } | |||
| bool MemoryCleaner::IsInPynativeConstructProcess() const { return pynative_in_construct_process_; } | |||
| void MemoryCleaner::EnterPynativeEndGraphProcess() { pynative_in_end_graph_process_ = true; } | |||
| void MemoryCleaner::LeavePynativeEndGraphProcess() { pynative_in_end_graph_process_ = false; } | |||
| bool MemoryCleaner::IsInPynativeEndGraphProcess() const { return pynative_in_end_graph_process_; } | |||
| } // namespace pipeline | |||
| } // namespace mindspore | |||
| @@ -53,39 +53,6 @@ BuiltInTypeMap &GetMethodMap(); | |||
| BuiltInTypeMap &GetAttrMap(); | |||
| class MemoryCleaner { | |||
| public: | |||
| MemoryCleaner() = default; | |||
| ~MemoryCleaner() = default; | |||
| void Init(); | |||
| void RecordPrimitivePy(PrimitivePy *prim); | |||
| void ReleasePrimitivePyObj(PrimitivePy *prim); | |||
| void ClearPrimitivePyPythonObj(); | |||
| void RecordPynativeShortLifePrimitivePy(PrimitivePy *prim); | |||
| void ErasePynativeShortLifePrimitivePy(PrimitivePy *prim); | |||
| void ClearPynativeShortLifePrimitivePy(); | |||
| void EnterPynativeConstructProcess(); | |||
| void LeavePynativeConstructProcess(); | |||
| bool IsInPynativeConstructProcess() const; | |||
| void EnterPynativeEndGraphProcess(); | |||
| void LeavePynativeEndGraphProcess(); | |||
| bool IsInPynativeEndGraphProcess() const; | |||
| private: | |||
| std::unordered_map<PrimitivePy *, bool> all_primitives_; | |||
| // PrimitivePy objects that created in pynative construct process.These primitives should be released after construct | |||
| // finished. | |||
| std::unordered_set<PrimitivePy *> pynative_short_life_primitives_; | |||
| // Sequence of primtive names in one construct process. | |||
| std::vector<std::string> pynative_new_primtives_squence_; | |||
| std::vector<std::vector<std::string>> pynative_released_history_; | |||
| bool pynative_in_construct_process_{false}; | |||
| bool pynative_in_end_graph_process_{false}; | |||
| }; | |||
| class Resource : public ResourceBase { | |||
| public: | |||
| explicit Resource(const py::object &obj = py::none()); | |||
| @@ -118,7 +85,6 @@ class Resource : public ResourceBase { | |||
| // ExecutorPy::Compile() can be called multiple times, so cache | |||
| // should be cleared. | |||
| void Clean(); | |||
| static MemoryCleaner &mem_cleaner() { return mem_cleaner_; } | |||
| private: | |||
| abstract::AnalysisEnginePtr engine_; | |||
| @@ -128,8 +94,6 @@ class Resource : public ResourceBase { | |||
| bool is_cleaned_; | |||
| bool gpu_loopsink_flag_{false}; | |||
| int64_t gpu_loopsink_size_{1}; | |||
| // Used to handle mem leak objects. | |||
| static MemoryCleaner mem_cleaner_; | |||
| }; | |||
| using ResourcePtr = std::shared_ptr<pipeline::Resource>; | |||
| @@ -56,7 +56,6 @@ struct OpExecInfo { | |||
| AbstractBasePtr abstract; | |||
| py::list op_inputs; | |||
| py::dict op_attrs; | |||
| std::vector<int64_t> inputs_mask; | |||
| bool is_dynamic_shape = false; | |||
| std::string next_op_name = ""; | |||
| @@ -689,13 +689,18 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) { | |||
| } | |||
| grad()->op_index_map()[op_name]++; | |||
| } | |||
| auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto adapter = py::cast<PrimitivePyAdapterPtr>(args[PY_PRIM]); | |||
| MS_EXCEPTION_IF_NULL(adapter); | |||
| auto prim = adapter->attached_primitive(); | |||
| if (prim == nullptr) { | |||
| prim = std::make_shared<PrimitivePy>(args[PY_PRIM], adapter); | |||
| adapter->set_attached_primitive(prim); | |||
| } | |||
| if (!prim->HasPyObj()) { | |||
| MS_LOG(EXCEPTION) << "Pyobj is empty"; | |||
| } | |||
| op_exec_info->py_primitive = prim; | |||
| op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); | |||
| op_exec_info->op_inputs = args[PY_INPUTS]; | |||
| return op_exec_info; | |||
| } | |||
| @@ -3264,10 +3269,7 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { | |||
| void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { | |||
| MS_LOG(DEBUG) << "Enter end graph process."; | |||
| py::object *ret = nullptr; | |||
| auto &mem_cleaner = pipeline::Resource::mem_cleaner(); | |||
| mem_cleaner.EnterPynativeEndGraphProcess(); | |||
| PynativeExecutorTry(grad_executor()->LinkGraph, ret, cell, out, args); | |||
| mem_cleaner.LeavePynativeEndGraphProcess(); | |||
| MS_LOG(DEBUG) << "Leave end graph process."; | |||
| } | |||
| @@ -3289,7 +3291,6 @@ void PynativeExecutor::EnterConstruct(const py::object &cell) { | |||
| return; | |||
| } | |||
| py_top_cell_ = cell.ptr(); | |||
| pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess(); | |||
| MS_LOG(DEBUG) << "Enter construct process."; | |||
| } | |||
| @@ -3298,7 +3299,6 @@ void PynativeExecutor::LeaveConstruct(const py::object &cell) { | |||
| return; | |||
| } | |||
| py_top_cell_ = nullptr; | |||
| pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess(); | |||
| MS_LOG(DEBUG) << "Leave construct process."; | |||
| } | |||
| @@ -18,6 +18,7 @@ | |||
| #include <mutex> | |||
| #include <map> | |||
| #include <utility> | |||
| #include "ir/signature.h" | |||
| #include "pipeline/jit/parse/data_converter.h" | |||
| #include "pipeline/jit/parse/python_adapter.h" | |||
| @@ -57,22 +58,21 @@ void SyncData(const py::object &arg) { | |||
| } // namespace | |||
| std::map<std::string, py::object> PrimitivePy::hook_grad_; | |||
| PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj) | |||
| : Primitive(name, false), python_obj_(python_obj), signatures_() { | |||
| auto &mem_cleaner = pipeline::Resource::mem_cleaner(); | |||
| mem_cleaner.RecordPrimitivePy(this); | |||
| MS_LOG(DEBUG) << "New primitive:" << name; | |||
| if (mem_cleaner.IsInPynativeConstructProcess() && !mem_cleaner.IsInPynativeEndGraphProcess()) { | |||
| mem_cleaner.RecordPynativeShortLifePrimitivePy(this); | |||
| } | |||
| } | |||
| PrimitivePy::~PrimitivePy() { | |||
| // Erase primitive here to set released flag false, to avoid calling released pointer when clear primitives in | |||
| // resource. | |||
| pipeline::Resource::mem_cleaner().ReleasePrimitivePyObj(this); | |||
| MS_LOG(DEBUG) << "Release:" << ToString(); | |||
| PrimitivePy::PrimitivePy(const std::string &name) : Primitive(name, false), python_obj_(py::none()) {} | |||
| PrimitivePy::PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter) | |||
| : Primitive(adapter->name_, false), python_obj_(python_obj), adapter_(adapter) { | |||
| MS_LOG(DEBUG) << "New primitive:" << adapter->name_; | |||
| set_signatures(adapter->signatures_); | |||
| Primitive::SetAttrs(adapter->attrs_); | |||
| Primitive::set_prim_type(adapter->prim_type_); | |||
| Primitive::set_const_prim(adapter->is_const_prim_); | |||
| Primitive::set_const_input_indexes(adapter->const_input_indexes_); | |||
| set_hook(adapter->hook_); | |||
| set_instance_name(adapter->instance_name_); | |||
| } | |||
| void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; } | |||
| PrimitivePy::~PrimitivePy() { MS_LOG(DEBUG) << "Release:" << ToString(); } | |||
| void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) { | |||
| signatures_ = signatures; | |||
| set_has_signature(!signatures.empty()); | |||
| @@ -272,29 +272,6 @@ py::function PrimitivePy::GetComputeFunction() const { | |||
| return vm_fn; | |||
| } | |||
| void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { | |||
| std::string attr_name = name; | |||
| ValuePtr converted_ret = nullptr; | |||
| if (py::isinstance<py::module>(obj)) { | |||
| MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; | |||
| } | |||
| bool converted = parse::ConvertData(obj, &converted_ret); | |||
| if (!converted) { | |||
| MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); | |||
| } | |||
| if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) { | |||
| attr_name = kOpAttrNameReplaceMap[attr_name]; | |||
| } | |||
| const std::string &prim_name = this->name(); | |||
| CheckAndConvertUtils::ConvertAttrValueToInt(prim_name, attr_name, &converted_ret); | |||
| (void)this->AddAttr(attr_name, converted_ret); | |||
| } | |||
| void PrimitivePy::DelPyAttr(const py::str &name) { | |||
| std::string attr_name = name; | |||
| (void)this->DelAttr(attr_name); | |||
| } | |||
| py::dict PrimitivePy::GetAttrDict() { | |||
| py::dict attr_dict; | |||
| for (auto &attr : attrs_) { | |||
| @@ -338,9 +315,11 @@ bool PrimitivePy::HasComputeFunction() const { | |||
| PrimitivePtr PrimitivePy::Clone() { | |||
| auto clone_fn = python_obj_.attr("_clone"); | |||
| py::object new_obj = clone_fn(); | |||
| auto cloned_prim = new_obj.cast<PrimitivePyPtr>(); | |||
| return cloned_prim; | |||
| py::object obj_adapter = clone_fn(); | |||
| auto prim_adapter = obj_adapter.cast<PrimitivePyAdapterPtr>(); | |||
| auto prim = std::make_shared<PrimitivePy>(obj_adapter, prim_adapter); | |||
| prim_adapter->set_attached_primitive(prim); | |||
| return prim; | |||
| } | |||
| py::dict PrimitivePy::RunInfer(const py::tuple &args) { | |||
| @@ -379,6 +358,113 @@ py::object PrimitivePy::RunInferValue(const py::tuple &args) { | |||
| return infer_value(*args); | |||
| } | |||
| PrimitivePyAdapter::PrimitivePyAdapter(const py::str &name) : name_(name) {} | |||
| void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) { | |||
| std::string attr_name = name; | |||
| ValuePtr converted_ret = nullptr; | |||
| if (py::isinstance<py::module>(obj)) { | |||
| MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; | |||
| } | |||
| bool converted = parse::ConvertData(obj, &converted_ret); | |||
| if (!converted) { | |||
| MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); | |||
| } | |||
| if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) { | |||
| attr_name = kOpAttrNameReplaceMap[attr_name]; | |||
| } | |||
| CheckAndConvertUtils::ConvertAttrValueToInt(name_, name, &converted_ret); | |||
| auto prim = attached_primitive_.lock(); | |||
| if (prim != nullptr) { | |||
| prim->AddAttr(attr_name, converted_ret); | |||
| } else { | |||
| attrs_[attr_name] = converted_ret; | |||
| } | |||
| } | |||
| void PrimitivePyAdapter::DelPyAttr(const py::str &name) { | |||
| std::string attr_name = name; | |||
| auto prim = attached_primitive_.lock(); | |||
| if (prim != nullptr) { | |||
| prim->DelAttr(attr_name); | |||
| } else { | |||
| attrs_.erase(attr_name); | |||
| } | |||
| } | |||
| py::dict PrimitivePyAdapter::GetAttrDict() { | |||
| auto prim = attached_primitive_.lock(); | |||
| if (prim != nullptr) { | |||
| return prim->GetAttrDict(); | |||
| } | |||
| py::dict attr_dict; | |||
| for (auto &attr : attrs_) { | |||
| attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); | |||
| } | |||
| return attr_dict; | |||
| } | |||
| void PrimitivePyAdapter::set_prim_type(const PrimType t) { | |||
| auto prim = attached_primitive_.lock(); | |||
| if (prim != nullptr) { | |||
| prim->set_prim_type(t); | |||
| } else { | |||
| prim_type_ = t; | |||
| } | |||
| } | |||
| void PrimitivePyAdapter::set_const_prim(bool is_const_prim) { | |||
| auto prim = attached_primitive_.lock(); | |||
| if (prim != nullptr) { | |||
| prim->set_const_prim(is_const_prim); | |||
| } else { | |||
| is_const_prim_ = is_const_prim; | |||
| } | |||
| } | |||
| void PrimitivePyAdapter::set_const_input_indexes(const std::vector<size_t> &const_input_indexes) { | |||
| auto prim = attached_primitive_.lock(); | |||
| if (prim != nullptr) { | |||
| prim->set_const_input_indexes(const_input_indexes); | |||
| } else { | |||
| const_input_indexes_ = const_input_indexes; | |||
| } | |||
| } | |||
| void PrimitivePyAdapter::set_signatures(const std::vector<Signature> &signatures) { | |||
| auto prim = attached_primitive_.lock(); | |||
| if (prim != nullptr) { | |||
| prim->set_signatures(signatures); | |||
| } else { | |||
| signatures_ = signatures; | |||
| } | |||
| } | |||
| void PrimitivePyAdapter::set_hook(const py::function &hook) { | |||
| auto prim = attached_primitive_.lock(); | |||
| if (prim != nullptr) { | |||
| prim->set_hook(hook); | |||
| } else { | |||
| hook_ = hook; | |||
| } | |||
| } | |||
| void PrimitivePyAdapter::set_instance_name(const std::string &s) { | |||
| auto prim = attached_primitive_.lock(); | |||
| if (prim != nullptr) { | |||
| prim->set_instance_name(s); | |||
| } else { | |||
| instance_name_ = s; | |||
| } | |||
| } | |||
| void PrimitivePyAdapter::set_attached_primitive(const PrimitivePyPtr &prim) { | |||
| if (attached_primitive_.lock() != nullptr) { | |||
| MS_LOG(EXCEPTION) << "PrimitivePyAdapter can't attach to multi Primitive."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| attached_primitive_ = prim; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | |||
| (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | |||
| .value("unknown", PrimType::kPrimTypeUnknown) | |||
| @@ -386,18 +472,20 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | |||
| .value("py_infer_shape", PrimType::kPrimTypePyInferShape) | |||
| .value("user_custom", PrimType::kPrimTypeUserCustom) | |||
| .value("py_infer_check", PrimType::kPrimTypePyInferCheck); | |||
| (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_") | |||
| .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) | |||
| .def(py::init<py::str &, py::object>()) | |||
| .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") | |||
| .def("del_attr", &PrimitivePy::DelPyAttr, "del primitive attr") | |||
| .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") | |||
| .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") | |||
| .def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.") | |||
| .def("set_const_input_indexes", &PrimitivePy::set_const_input_indexes, | |||
| (void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_") | |||
| .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_) | |||
| .def(py::init<py::str &>()) | |||
| .def("add_attr", &PrimitivePyAdapter::AddPyAttr, "add primitive attr") | |||
| .def("del_attr", &PrimitivePyAdapter::DelPyAttr, "del primitive attr") | |||
| .def("get_attr_dict", &PrimitivePyAdapter::GetAttrDict, "get primitive attr") | |||
| .def("set_prim_type", &PrimitivePyAdapter::set_prim_type, "Set primitive type.") | |||
| .def("set_const_prim", &PrimitivePyAdapter::set_const_prim, "Set primitive is const.") | |||
| .def("set_const_input_indexes", &PrimitivePyAdapter::set_const_input_indexes, | |||
| "Set primitive const input indexes.") | |||
| .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") | |||
| .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") | |||
| .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); | |||
| .def("set_signatures", &PrimitivePyAdapter::set_signatures, | |||
| "Set primitive inputs signature.") | |||
| .def("register_hook", &PrimitivePyAdapter::set_hook, "Set primitive hook function.") | |||
| .def("set_instance_name", &PrimitivePyAdapter::set_instance_name, | |||
| "Set primitive instance name."); | |||
| })); | |||
| } // namespace mindspore | |||
| @@ -34,9 +34,18 @@ | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| class PrimitivePy; | |||
| using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; | |||
| using PrimitivePyWeakPtr = std::weak_ptr<PrimitivePy>; | |||
| class PrimitivePyAdapter; | |||
| using PrimitivePyAdapterPtr = std::shared_ptr<PrimitivePyAdapter>; | |||
| class PrimitivePy : public Primitive { | |||
| public: | |||
| PrimitivePy(const py::str &name, const py::object &python_obj); | |||
| explicit PrimitivePy(const std::string &name); | |||
| PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter); | |||
| ~PrimitivePy() override; | |||
| MS_DECLARE_PARENT(PrimitivePy, Primitive); | |||
| py::function GetBpropFunction(); | |||
| @@ -47,10 +56,6 @@ class PrimitivePy : public Primitive { | |||
| void CopyHookFunction(const PrimitivePtr &primitive) override; | |||
| void AddPyAttr(const py::str &name, const py::object &obj); | |||
| void DelPyAttr(const py::str &name); | |||
| py::dict GetAttrDict(); | |||
| void set_hook(const py::function &hook) { hook_ = hook; } | |||
| py::function hook() const { return hook_; } | |||
| @@ -61,13 +66,13 @@ class PrimitivePy : public Primitive { | |||
| bool HasComputeFunction() const; | |||
| const bool parse_info_ = true; | |||
| const py::object &GetPyObj() const { return python_obj_; } | |||
| void SetPyObj(const py::object &obj); | |||
| py::dict RunInfer(const py::tuple &args); | |||
| void RunCheck(const py::tuple &args); | |||
| py::object RunInferValue(const py::tuple &args); | |||
| bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); } | |||
| bool HasPyObj() { return python_obj_.operator bool(); } | |||
| PrimitivePtr Clone() override; | |||
| PrimitivePyAdapterPtr adapter() const { return adapter_; } | |||
| bool is_tuple_input_ = false; | |||
| private: | |||
| @@ -75,11 +80,41 @@ class PrimitivePy : public Primitive { | |||
| void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const; | |||
| void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const; | |||
| py::object python_obj_; | |||
| PrimitivePyAdapterPtr adapter_; | |||
| py::function hook_; | |||
| std::vector<Signature> signatures_; | |||
| static std::map<std::string, py::object> hook_grad_; | |||
| }; | |||
| using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; | |||
| class PrimitivePyAdapter { | |||
| public: | |||
| explicit PrimitivePyAdapter(const py::str &name); | |||
| ~PrimitivePyAdapter() = default; | |||
| void AddPyAttr(const py::str &name, const py::object &obj); | |||
| void DelPyAttr(const py::str &name); | |||
| py::dict GetAttrDict(); | |||
| void set_prim_type(const PrimType t); | |||
| void set_const_prim(bool is_const_prim); | |||
| void set_const_input_indexes(const std::vector<size_t> &const_input_indexes); | |||
| void set_signatures(const std::vector<Signature> &signatures); | |||
| void set_hook(const py::function &hook); | |||
| void set_instance_name(const std::string &s); | |||
| void set_attached_primitive(const PrimitivePyPtr &prim); | |||
| PrimitivePyPtr attached_primitive() { return attached_primitive_.lock(); } | |||
| void set_name(const std::string &name) { name_ = name; } | |||
| const bool parse_info_ = true; | |||
| private: | |||
| friend PrimitivePy; | |||
| std::string name_; | |||
| PrimitivePyWeakPtr attached_primitive_; | |||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||
| PrimType prim_type_{kPrimTypeBuiltIn}; | |||
| bool is_const_prim_{false}; | |||
| std::vector<size_t> const_input_indexes_; | |||
| std::vector<Signature> signatures_; | |||
| py::function hook_; | |||
| std::string instance_name_; | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_ | |||
| @@ -99,7 +99,7 @@ class Primitive : public Named { | |||
| } | |||
| void set_prim_type(const PrimType t) { prim_type_ = t; } | |||
| virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); } | |||
| void set_instance_name(const std::string s) { instance_name_ = s; } | |||
| void set_instance_name(const std::string &s) { instance_name_ = s; } | |||
| bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } | |||
| bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } | |||
| bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } | |||
| @@ -50,7 +50,7 @@ class Primitive(Primitive_): | |||
| self.attrs = {} | |||
| self.init_attrs = {"name": name} | |||
| self._update_parameter = False | |||
| Primitive_.__init__(self, name, self) | |||
| Primitive_.__init__(self, name) | |||
| if hasattr(self.__class__, '__mindspore_signature__'): | |||
| out = self._fill_signature(self.__class__.__mindspore_signature__) | |||
| self.set_signatures(out) | |||
| @@ -94,13 +94,13 @@ TEST_F(TestCompileSegmentRunner, test_if) { | |||
| TEST_F(TestCompileSegmentRunner, test_RunOperation1) { | |||
| VectorRef args({1}); | |||
| auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name()), py::none()), args); | |||
| auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name())), args); | |||
| ASSERT_EQ(py::cast<int>(BaseRefToPyData(res)), 1); | |||
| } | |||
| TEST_F(TestCompileSegmentRunner, test_RunOperation2) { | |||
| VectorRef args({1, 2}); | |||
| auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name()), py::none()), args); | |||
| auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name())), args); | |||
| ASSERT_EQ(py::cast<bool>(BaseRefToPyData(res)), false); | |||
| } | |||
| } // namespace compile | |||