From: @zhangzhaoju Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qhpull/15525/MERGE
| @@ -25,12 +25,11 @@ from .parser import (Parser, create_obj_instance, generate_scope, | |||||
| get_ast_namespace_symbol, get_operation_namespace_symbol, | get_ast_namespace_symbol, get_operation_namespace_symbol, | ||||
| get_parse_method_of_class, get_scope_name, expand_expr_statement, | get_parse_method_of_class, get_scope_name, expand_expr_statement, | ||||
| is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description) | is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description) | ||||
| from .serialize import * | |||||
| __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', | __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', | ||||
| 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type', | 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type', | ||||
| 'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol', | 'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol', | ||||
| 'get_args', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', | 'get_args', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', | ||||
| 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', | 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', | ||||
| 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', | |||||
| 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name', | |||||
| 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement'] | 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement'] | ||||
| @@ -1,45 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """The functions in this file is used to dump and load python object in anf graphs.""" | |||||
| import pickle | |||||
| import os | |||||
| import stat | |||||
| def dump_obj(obj, path): | |||||
| """Dump object to file.""" | |||||
| file_name = hex(id(obj)) | |||||
| file_path = path + file_name | |||||
| with open(file_path, 'wb') as f: | |||||
| os.chmod(file_path, stat.S_IWUSR | stat.S_IRUSR) | |||||
| pickle.dump(obj, f) | |||||
| return file_name | |||||
| def load_obj(file_path): | |||||
| """Load object from file.""" | |||||
| obj = None | |||||
| try: | |||||
| real_file_path = os.path.realpath(file_path) | |||||
| except Exception as ex: | |||||
| raise RuntimeError(ex) | |||||
| with open(real_file_path, 'rb') as f: | |||||
| obj = pickle.load(f) | |||||
| return obj | |||||
| __all__ = ['dump_obj', 'load_obj'] | |||||
| @@ -45,42 +45,6 @@ | |||||
| using mindspore::tensor::TensorPy; | using mindspore::tensor::TensorPy; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| // max number of elements in sequence | |||||
| const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF; | |||||
| // ============================================== MindSpore IR Common ============================================== | |||||
| // get MindSpore Intermediate Representation Path | |||||
| std::string GetMsIrPath(void) { | |||||
| std::string path; | |||||
| const char *path_ptr = getenv("MS_IR_PATH"); | |||||
| if (path_ptr != nullptr) { | |||||
| path = path_ptr; | |||||
| char real_path[PATH_MAX] = {0}; | |||||
| #if defined(_WIN32) || defined(_WIN64) | |||||
| if (path.size() > PATH_MAX || _fullpath(real_path, path.c_str(), PATH_MAX) == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "MS IR Path error, " << path_ptr; | |||||
| } | |||||
| #else | |||||
| if (path.size() > PATH_MAX || nullptr == realpath(path.c_str(), real_path)) { | |||||
| MS_LOG(EXCEPTION) << "MS IR path error, " << path_ptr; | |||||
| } | |||||
| #endif | |||||
| path = real_path; | |||||
| } | |||||
| return path; | |||||
| } | |||||
| std::string dump_obj(const py::object &obj, const std::string &path) { | |||||
| py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | |||||
| py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path)); | |||||
| return py::str(name); | |||||
| } | |||||
| py::object load_obj(const std::string &path) { | |||||
| py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | |||||
| py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path)); | |||||
| return obj; | |||||
| } | |||||
| // ============================================= MindSpore IR Exporter ============================================= | // ============================================= MindSpore IR Exporter ============================================= | ||||
| @@ -98,17 +62,6 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) { | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const { | |||||
| std::string pkl_path = GetMsIrPath(); | |||||
| // if not specified env 'MS_IR_PATH', do not create any files | |||||
| if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) { | |||||
| return "null"; | |||||
| } | |||||
| std::string file_prefix = id_ + "." + category; | |||||
| std::string file_name = dump_obj(obj, pkl_path + "/" + file_prefix); | |||||
| return file_prefix + file_name; | |||||
| } | |||||
| int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp) { | int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp) { | ||||
| if (func_graph == nullptr || param == nullptr) { | if (func_graph == nullptr || param == nullptr) { | ||||
| return -1; | return -1; | ||||
| @@ -181,9 +134,6 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap | |||||
| oss << py_func.first[i]->DumpText(); | oss << py_func.first[i]->DumpText(); | ||||
| } | } | ||||
| oss << ")"; | oss << ")"; | ||||
| // dump Python Function object | |||||
| oss << "@" << DumpObject(py_func.second, "F"); | |||||
| } | } | ||||
| oss << "}"; | oss << "}"; | ||||
| @@ -263,17 +213,8 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) { | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| oss << prim->type_name() << "::" << prim->name(); | oss << prim->type_name() << "::" << prim->name(); | ||||
| // need to serialize internal python function of PrimitivePy and record its prim_type | |||||
| if (prim->isa<PrimitivePy>()) { | |||||
| PrimitivePyPtr primpy = prim->cast<PrimitivePyPtr>(); | |||||
| // dump related function in PrimitivePy | |||||
| oss << "@" << DumpObject(primpy->GetPyObj(), "P"); | |||||
| // output primitive type | |||||
| oss << "{prim_type=" << static_cast<int>(prim->prim_type()) << "}"; | |||||
| } | |||||
| // output primitive type | |||||
| oss << "{prim_type=" << static_cast<int>(prim->prim_type()) << "}"; | |||||
| // output primitive attributes | // output primitive attributes | ||||
| oss << prim->GetAttrsText(); | oss << prim->GetAttrsText(); | ||||
| @@ -296,7 +237,7 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) { | |||||
| } | } | ||||
| // dump related module information in Namespace | // dump related module information in Namespace | ||||
| oss << ns->type_name() << "::" << ns->module() << "@" << DumpObject(ns->obj(), "N"); | |||||
| oss << ns->type_name() << "::" << ns->module(); | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| @@ -399,8 +340,7 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const Valu | |||||
| } else if (value->isa<Scalar>() || value->isa<StringImm>()) { | } else if (value->isa<Scalar>() || value->isa<StringImm>()) { | ||||
| oss << value->DumpText(); | oss << value->DumpText(); | ||||
| } else if (value->isa<tensor::Tensor>()) { | } else if (value->isa<tensor::Tensor>()) { | ||||
| auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||||
| oss << value->DumpText() << "@" << DumpObject(TensorPy::AsNumpy(*tensor_ptr), "T"); | |||||
| oss << value->DumpText(); | |||||
| } else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>()) { | } else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>()) { | ||||
| oss << value->DumpText(); | oss << value->DumpText(); | ||||
| } else if (value->isa<ValueSequeue>()) { | } else if (value->isa<ValueSequeue>()) { | ||||
| @@ -477,20 +417,8 @@ void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNode | |||||
| } else { | } else { | ||||
| ofs << "%para" << param_index << " : " << type_info; | ofs << "%para" << param_index << " : " << type_info; | ||||
| } | } | ||||
| // dump Default value of parameter if exists | |||||
| const ParameterPtr param_ptr = dyn_cast<Parameter>(param); | |||||
| if (param_ptr == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Param could not cast to parameter"; | |||||
| } | |||||
| if (param_ptr->has_default()) { | |||||
| auto param_value = param_ptr->default_param(); | |||||
| ofs << " = @" << DumpObject(py::cast(param_value), "D"); | |||||
| } | |||||
| // output comment | // output comment | ||||
| ofs << " # " << param->DumpText() << "\n"; | ofs << " # " << param->DumpText() << "\n"; | ||||
| param_index += 1; | param_index += 1; | ||||
| } | } | ||||
| } | } | ||||
| @@ -702,13 +630,13 @@ void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector | |||||
| } | } | ||||
| #ifdef ENABLE_DUMP_IR | #ifdef ENABLE_DUMP_IR | ||||
| void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) { | |||||
| void ExportIR(const std::string &filename, const FuncGraphPtr &func_graph) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto real_filename = pipeline::GetSaveGraphsPathName(Common::AddId(filename, ".dat")); | auto real_filename = pipeline::GetSaveGraphsPathName(Common::AddId(filename, ".dat")); | ||||
| AnfExporter exporter(id); | |||||
| AnfExporter exporter; | |||||
| ChangeFileMode(real_filename, S_IRWXU); | ChangeFileMode(real_filename, S_IRWXU); | ||||
| exporter.ExportFuncGraph(real_filename, func_graph); | exporter.ExportFuncGraph(real_filename, func_graph); | ||||
| // set file mode to read only by user | // set file mode to read only by user | ||||
| @@ -64,8 +64,8 @@ struct ParamPtrHasher { | |||||
| class AnfExporter { | class AnfExporter { | ||||
| public: | public: | ||||
| explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false) | |||||
| : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { | |||||
| explicit AnfExporter(bool export_used = true, bool check_integrity = false) | |||||
| : param_index(-1), export_used_(export_used), check_integrity_(check_integrity) { | |||||
| func_graph_set.clear(); | func_graph_set.clear(); | ||||
| exported.clear(); | exported.clear(); | ||||
| } | } | ||||
| @@ -78,7 +78,6 @@ class AnfExporter { | |||||
| virtual std::string GetNodeType(const AnfNodePtr &nd); | virtual std::string GetNodeType(const AnfNodePtr &nd); | ||||
| int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp = true); | int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp = true); | ||||
| int GetParamIndexFromExported(const AnfNodePtr ¶m); | int GetParamIndexFromExported(const AnfNodePtr ¶m); | ||||
| std::string DumpObject(const py::object &obj, const std::string &category) const; | |||||
| std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node); | std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node); | ||||
| std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph); | std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph); | ||||
| std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst); | std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst); | ||||
| @@ -102,14 +101,12 @@ class AnfExporter { | |||||
| int param_index; | int param_index; | ||||
| OrderedSet<FuncGraphPtr> func_graph_set{}; | OrderedSet<FuncGraphPtr> func_graph_set{}; | ||||
| OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported; | OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported; | ||||
| std::string id_; | |||||
| bool export_used_ = true; // whether export function graphs used in current exporting function graph | bool export_used_ = true; // whether export function graphs used in current exporting function graph | ||||
| bool check_integrity_ = false; // whether check integrity or not, when dumping ir for loading, must set it to true | bool check_integrity_ = false; // whether check integrity or not, when dumping ir for loading, must set it to true | ||||
| TaggedNodeMap tagged_cnodes_; | TaggedNodeMap tagged_cnodes_; | ||||
| abstract::AnfNodeConfigPtr node_cfg_ = nullptr; | |||||
| }; | }; | ||||
| void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph); | |||||
| void ExportIR(const std::string &filename, const FuncGraphPtr &func_graph); | |||||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs); | void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs); | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -118,7 +118,7 @@ void TraceGraphEval() { | |||||
| class AnalyzedFuncGraphExporter : public AnfExporter { | class AnalyzedFuncGraphExporter : public AnfExporter { | ||||
| public: | public: | ||||
| AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} | |||||
| AnalyzedFuncGraphExporter() : AnfExporter(true, false) {} | |||||
| ~AnalyzedFuncGraphExporter() override = default; | ~AnalyzedFuncGraphExporter() override = default; | ||||
| void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs); | void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs); | ||||
| @@ -282,7 +282,7 @@ bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, con | |||||
| DumpIR(fg_name + ".ir", func_graph); | DumpIR(fg_name + ".ir", func_graph); | ||||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | ||||
| func_graph->DumpFuncGraph(fg_name); | func_graph->DumpFuncGraph(fg_name); | ||||
| ExportIR(fg_name + ".dat", "", func_graph); | |||||
| ExportIR(fg_name + ".dat", func_graph); | |||||
| } | } | ||||
| } | } | ||||
| @@ -202,7 +202,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| DumpIR(fg_name + ".ir", func_graph); | DumpIR(fg_name + ".ir", func_graph); | ||||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | ||||
| func_graph->DumpFuncGraph(fg_name); | func_graph->DumpFuncGraph(fg_name); | ||||
| ExportIR(fg_name + ".dat", "", func_graph); | |||||
| ExportIR(fg_name + ".dat", func_graph); | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; | MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; | ||||
| } | } | ||||
| @@ -49,7 +49,7 @@ void DumpGraph(const FuncGraphPtr &root, const std::string &name) { | |||||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | ||||
| draw::Draw(name + ".dot", root); | draw::Draw(name + ".dot", root); | ||||
| DumpIR(name + ".ir", root); | DumpIR(name + ".ir", root); | ||||
| ExportIR(name + ".dat", "0", root); | |||||
| ExportIR(name + ".dat", root); | |||||
| } | } | ||||
| } | } | ||||
| @@ -423,7 +423,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| func_graph->DumpFuncGraph(fg_name); | func_graph->DumpFuncGraph(fg_name); | ||||
| DumpIR(fg_name + ".ir", func_graph); | DumpIR(fg_name + ".ir", func_graph); | ||||
| ExportIR(fg_name + ".dat", "", func_graph); | |||||
| ExportIR(fg_name + ".dat", func_graph); | |||||
| MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; | MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; | ||||
| } | } | ||||
| counter++; | counter++; | ||||
| @@ -728,7 +728,7 @@ void Pipeline::Run() { | |||||
| DumpIR(base_name + ".ir", graph, false, kTopStack); | DumpIR(base_name + ".ir", graph, false, kTopStack); | ||||
| } | } | ||||
| // generate IR file in a heavily commented format, which can also be reloaded | // generate IR file in a heavily commented format, which can also be reloaded | ||||
| ExportIR(base_name + ".dat", std::to_string(i), graph); | |||||
| ExportIR(base_name + ".dat", graph); | |||||
| } | } | ||||
| i++; | i++; | ||||
| #ifdef ENABLE_TIMELINE | #ifdef ENABLE_TIMELINE | ||||
| @@ -1,31 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| @File : test_parse.py | |||||
| @Date : 2019-11-05 14:49 | |||||
| @Desc : | |||||
| """ | |||||
| import os | |||||
| from mindspore._extends.parse import dump_obj | |||||
| from mindspore._extends.parse import load_obj | |||||
| def test_load_dump(): | |||||
| data = (1, 3, 2, 7, 9) | |||||
| file_name = dump_obj(data, "./") | |||||
| obj = load_obj("./" + file_name) | |||||
| os.remove(f'./{file_name}') | |||||
| assert data == obj | |||||