Merge pull request !607 from fary86/optimize_flow_of_exporting_onnx_modeltags/v0.3.0-alpha
| @@ -294,6 +294,30 @@ void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) { | |||||
| MS_LOG(INFO) << "End save compiled func graph!"; | MS_LOG(INFO) << "End save compiled func graph!"; | ||||
| } | } | ||||
| void ExecutorPy::SaveCompiledGraphToPb(const std::string &phase_s) { | |||||
| #ifdef ENABLE_DUMP_IR | |||||
| // save the graph to file in protobuf format | |||||
| FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::string name_prefix = phase_s.substr(0, phase_s.find(".")); | |||||
| std::string pb_filename = std::string("ms_output_") + name_prefix + ".pb"; | |||||
| std::string filename = GetFilePathName(pb_filename); | |||||
| MS_LOG(INFO) << "Begin saving graph to file <<'" << filename << "' in protobuf formart."; | |||||
| ChangeFileMode(filename, S_IRWXU); | |||||
| std::ofstream ofs(filename); | |||||
| if (!ofs.is_open()) { | |||||
| MS_LOG(ERROR) << "Open file '" << filename << "' failed!"; | |||||
| return; | |||||
| } | |||||
| ofs << GetFuncGraphProtoString(func_graph); | |||||
| ofs.close(); | |||||
| // set file mode to read only by user | |||||
| ChangeFileMode(filename, S_IRUSR); | |||||
| MS_LOG(INFO) << "End saving graph to file in protobuf format"; | |||||
| #endif | |||||
| } | |||||
| bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { | bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { | ||||
| std::string phase_prefix = GetPhasePrefix(phase_s); | std::string phase_prefix = GetPhasePrefix(phase_s); | ||||
| @@ -365,6 +389,8 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons | |||||
| info_[phase_s] = executor_info; | info_[phase_s] = executor_info; | ||||
| pip->Run(); | pip->Run(); | ||||
| // save compile graph to file in protobuf format | |||||
| SaveCompiledGraphToPb(phase_s); | |||||
| // save the run graph func to MsPipeLine | // save the run graph func to MsPipeLine | ||||
| SaveCompiledGraph(phase_s); | SaveCompiledGraph(phase_s); | ||||
| @@ -557,20 +583,6 @@ void Pipeline::Run() { | |||||
| std::string user_graph_file = GetFilePathName("ModelDigraph.dot"); | std::string user_graph_file = GetFilePathName("ModelDigraph.dot"); | ||||
| MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file; | MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file; | ||||
| draw::DrawUserFuncGraph(user_graph_file, user_graph); | draw::DrawUserFuncGraph(user_graph_file, user_graph); | ||||
| #ifdef ENABLE_DUMP_IR | |||||
| std::string filename = GetFilePathName("ms_output.pb"); | |||||
| ChangeFileMode(filename, S_IRWXU); | |||||
| std::ofstream ofs(filename); | |||||
| if (!ofs.is_open()) { | |||||
| MS_LOG(ERROR) << "Open file '" << filename << "' failed!"; | |||||
| return; | |||||
| } | |||||
| ofs << GetFuncGraphProtoString(user_graph); | |||||
| ofs.close(); | |||||
| // set file mode to read only by user | |||||
| ChangeFileMode(filename, S_IRUSR); | |||||
| #endif | |||||
| } | } | ||||
| MS_LOG(INFO) << "End"; | MS_LOG(INFO) << "End"; | ||||
| } | } | ||||
| @@ -70,6 +70,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> { | |||||
| ~ExecutorPy(); | ~ExecutorPy(); | ||||
| void SaveCompiledGraph(const std::string &phase_s); | void SaveCompiledGraph(const std::string &phase_s); | ||||
| void SaveCompiledGraphToPb(const std::string &phase_s); | |||||
| bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); | bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); | ||||
| bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); | bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); | ||||
| @@ -158,7 +158,7 @@ void Profile::Print(void) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| PrintProfile(oss, *ctx_ptr_->time_info_); | PrintProfile(oss, *ctx_ptr_->time_info_); | ||||
| std::string text = oss.str(); | std::string text = oss.str(); | ||||
| // the length of text is too long to use MS_LOGINFO, use printf to print it | |||||
| // here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace | |||||
| (void)printf("%s", text.c_str()); | (void)printf("%s", text.c_str()); | ||||
| (void)fflush(stdout); | (void)fflush(stdout); | ||||
| } | } | ||||
| @@ -358,7 +358,7 @@ void MsProfile::Print() { | |||||
| PrintTimeStat(oss, groups[i], prefix); | PrintTimeStat(oss, groups[i], prefix); | ||||
| } | } | ||||
| std::string text = oss.str(); | std::string text = oss.str(); | ||||
| // the length of text is too long to use MS_LOGINFO, use printf to print it | |||||
| // here use printf to output profile info, not use MS_LOG(INFO) since when open log, it affects performace | |||||
| (void)printf("\nTime group info:\n%s", text.c_str()); | (void)printf("\nTime group info:\n%s", text.c_str()); | ||||
| (void)fflush(stdout); | (void)fflush(stdout); | ||||
| } | } | ||||
| @@ -328,7 +328,7 @@ class _Executor: | |||||
| raise TypeError('Parameters need OrderedDict type, but got {}'. | raise TypeError('Parameters need OrderedDict type, but got {}'. | ||||
| format(type(params))) | format(type(params))) | ||||
| def compile(self, obj, *args, phase='predict', params=None): | |||||
| def compile(self, obj, *args, phase='predict', params=None, do_convert=True): | |||||
| """ | """ | ||||
| Compiles graph. | Compiles graph. | ||||
| @@ -337,6 +337,7 @@ class _Executor: | |||||
| args (tuple): Function or cell input arguments. | args (tuple): Function or cell input arguments. | ||||
| phase (str): The name of compile phase. Default: 'predict'. | phase (str): The name of compile phase. Default: 'predict'. | ||||
| params (OrderedDict): The parameters dictionary used for init data graph. Default: None. | params (OrderedDict): The parameters dictionary used for init data graph. Default: None. | ||||
| do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph. | |||||
| Return: | Return: | ||||
| Str, the full phase of the cell. | Str, the full phase of the cell. | ||||
| @@ -368,7 +369,8 @@ class _Executor: | |||||
| if graph is None: | if graph is None: | ||||
| logger.error("%r graph compile failed.", phase) | logger.error("%r graph compile failed.", phase) | ||||
| if not do_convert: | |||||
| return phase, True | |||||
| if not enable_debug_runtime or enable_ge: | if not enable_debug_runtime or enable_ge: | ||||
| if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]: | if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]: | ||||
| obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) | obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) | ||||
| @@ -450,7 +450,7 @@ def export(net, *inputs, file_name, file_format='GEIR'): | |||||
| _executor.export(net, file_name, file_format) | _executor.export(net, file_name, file_format) | ||||
| elif file_format == 'ONNX': # file_format is 'ONNX' | elif file_format == 'ONNX': # file_format is 'ONNX' | ||||
| phase_name = 'export_onnx' | phase_name = 'export_onnx' | ||||
| graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) | |||||
| graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) | |||||
| onnx_stream = _executor._get_func_graph_proto(graph_id) | onnx_stream = _executor._get_func_graph_proto(graph_id) | ||||
| with open(file_name, 'wb') as f: | with open(file_name, 'wb') as f: | ||||
| os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) | os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) | ||||