GitOrigin-RevId: aea62de345
tags/v1.7.2.m1
| @@ -510,7 +510,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) { | |||||
| std::vector<std::vector<std::shared_ptr<Tensor>>> inputs; | std::vector<std::vector<std::shared_ptr<Tensor>>> inputs; | ||||
| std::vector<std::vector<std::shared_ptr<Tensor>>> outputs; | std::vector<std::vector<std::shared_ptr<Tensor>>> outputs; | ||||
| std::shared_ptr<Network> network = std::make_shared<Network>(); | |||||
| Config config; | |||||
| config.options.graph_opt_level = 0; | |||||
| std::shared_ptr<Network> network = std::make_shared<Network>(config); | |||||
| network->load_model(model_path); | network->load_model(model_path); | ||||
| input_names = network->get_all_input_name(); | input_names = network->get_all_input_name(); | ||||
| @@ -559,10 +562,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) { | |||||
| outputs.push_back(net_outputs); | outputs.push_back(net_outputs); | ||||
| } | } | ||||
| Config config; | |||||
| config.options.force_output_use_user_specified_memory = true; | config.options.force_output_use_user_specified_memory = true; | ||||
| config.options.comp_node_seq_record_level = record; | config.options.comp_node_seq_record_level = record; | ||||
| config.options.const_shape = true; | config.options.const_shape = true; | ||||
| config.options.graph_opt_level = 2; | |||||
| std::shared_ptr<Network> network_record = std::make_shared<Network>(config); | std::shared_ptr<Network> network_record = std::make_shared<Network>(config); | ||||
| @@ -10,6 +10,7 @@ | |||||
| */ | */ | ||||
| #include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
| #include "megbrain/gopt/inference.h" | |||||
| #include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -27,6 +28,35 @@ std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile( | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| void GraphLoader::LoadResult::graph_compile_ahead() { | |||||
| //! when force_output_use_user_specified_memory is set, the output var may | |||||
| //! be changed by gopt, then the var in LoadResult can not exist, so here | |||||
| //! just do basic optimize_for_inference ahead, and replace the var in | |||||
| //! LoadResult | |||||
| if (graph->options().force_output_use_user_specified_memory) { | |||||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||||
| auto new_vars = gopt::optimize_for_inference(output_var_list, options); | |||||
| output_var_list = new_vars; | |||||
| output_var_map.clear(); | |||||
| for (auto& var : new_vars) { | |||||
| output_var_map[var.node()->cname()] = var; | |||||
| } | |||||
| std::unordered_map<size_t, SymbolVar> var_map_id; | |||||
| for (auto& var : new_vars) { | |||||
| bool found = false; | |||||
| for (auto& old_var_it : output_var_map_id) { | |||||
| if (old_var_it.second.node()->name() == var.node()->name()) { | |||||
| found = true; | |||||
| var_map_id[old_var_it.first] = var; | |||||
| } | |||||
| } | |||||
| mgb_assert( | |||||
| found, "can't find var name %s when optimize_for_inference. ", | |||||
| var.node()->cname()); | |||||
| } | |||||
| } | |||||
| } | |||||
| GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() { | GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() { | ||||
| SharedTensorNameMap ret; | SharedTensorNameMap ret; | ||||
| for (auto&& i : shared_tensor_id_map()) { | for (auto&& i : shared_tensor_id_map()) { | ||||
| @@ -946,6 +946,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi | |||||
| mgb_assert(fbs_end > cur); | mgb_assert(fbs_end > cur); | ||||
| // Skip to Graph end | // Skip to Graph end | ||||
| m_file->skip(fbs_end - cur); | m_file->skip(fbs_end - cur); | ||||
| result.graph_compile_ahead(); | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -63,6 +63,14 @@ public: | |||||
| */ | */ | ||||
| MGE_WIN_DECLSPEC_FUC std::unique_ptr<cg::AsyncExecutable> graph_compile( | MGE_WIN_DECLSPEC_FUC std::unique_ptr<cg::AsyncExecutable> graph_compile( | ||||
| const ComputingGraph::OutputSpec& outspec); | const ComputingGraph::OutputSpec& outspec); | ||||
| /*! | |||||
| * \brief after the graph is loaded, run a basic optimize_for_inference | |||||
| * pass ahead of time, because some dest vars may be replaced by graph | |||||
| * optimization, which causes errors when the option | |||||
| * force_output_use_user_specified_memory is on | |||||
| * | |||||
| */ | |||||
| MGE_WIN_DECLSPEC_FUC void graph_compile_ahead(); | |||||
| }; | }; | ||||
| //! helper to disable inplace arith graph optimization during | //! helper to disable inplace arith graph optimization during | ||||