GitOrigin-RevId: 910c8da19f
tags/v1.9.0
| @@ -40,7 +40,6 @@ public: | |||
| void wait() override; | |||
| //! enable global layout transform | |||
| void set_layout_transform(bool state) { enable_layout_transform = state; } | |||
| //! get the network of lite model | |||
| @@ -468,3 +468,29 @@ class TestNetwork(TestShuffleNet): | |||
| fi = open("./model_afer_layoutTrans.mgb", "r") | |||
| fi.close() | |||
| os.remove("./model_afer_layoutTrans.mgb") | |||
| def test_fast_run_and_global_layout_transform(self): | |||
| config_ = LiteConfig() | |||
| network = LiteNetwork(config_) | |||
| fast_run_cache = "./algo_cache" | |||
| global_layout_transform_model = "./model_afer_layoutTrans.mgb" | |||
| network.set_network_algo_policy( | |||
| LiteAlgoSelectStrategy.LITE_ALGO_PROFILE | |||
| | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED | |||
| ) | |||
| network.enable_global_layout_transform() | |||
| network.load(self.model_path) | |||
| self.do_forward(network) | |||
| network.dump_layout_transform_model(global_layout_transform_model) | |||
| LiteGlobal.dump_persistent_cache(fast_run_cache) | |||
| fi = open(fast_run_cache, "r") | |||
| fi.close() | |||
| fi = open(global_layout_transform_model, "r") | |||
| fi.close() | |||
| LiteGlobal.set_persistent_cache(path=fast_run_cache) | |||
| self.do_forward(network) | |||
| os.remove(fast_run_cache) | |||
| os.remove(global_layout_transform_model) | |||
| @@ -293,3 +293,31 @@ class TestNetwork(TestShuffleNetCuda): | |||
| fi = open("./model_afer_layoutTrans.mgb", "r") | |||
| fi.close() | |||
| os.remove("./model_afer_layoutTrans.mgb") | |||
| @require_cuda() | |||
| def test_fast_run_and_global_layout_transform(self): | |||
| config_ = LiteConfig() | |||
| config_.device_type = LiteDeviceType.LITE_CUDA | |||
| network = LiteNetwork(config_) | |||
| fast_run_cache = "./algo_cache" | |||
| global_layout_transform_model = "./model_afer_layoutTrans.mgb" | |||
| network.set_network_algo_policy( | |||
| LiteAlgoSelectStrategy.LITE_ALGO_PROFILE | |||
| | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED | |||
| ) | |||
| network.enable_global_layout_transform() | |||
| network.load(self.model_path) | |||
| self.do_forward(network) | |||
| network.dump_layout_transform_model(global_layout_transform_model) | |||
| LiteGlobal.dump_persistent_cache(fast_run_cache) | |||
| fi = open(fast_run_cache, "r") | |||
| fi.close() | |||
| fi = open(global_layout_transform_model, "r") | |||
| fi.close() | |||
| LiteGlobal.set_persistent_cache(path=fast_run_cache) | |||
| self.do_forward(network) | |||
| os.remove(fast_run_cache) | |||
| os.remove(global_layout_transform_model) | |||
| @@ -422,6 +422,8 @@ void NetworkImplDft::load_model( | |||
| m_load_result = m_loader->load(m_load_config, true); | |||
| modify_exection_policy(); | |||
| global_layout_transform(); | |||
| adapt_option_valid(); | |||
| @@ -436,7 +438,6 @@ void NetworkImplDft::load_model( | |||
| } | |||
| void NetworkImplDft::compile_graph() { | |||
| modify_exection_policy(); | |||
| replace_dev_input_pass(); | |||
| make_output_spec(); | |||
| m_execute_func = m_load_result.graph_compile(m_output_spec); | |||
| @@ -793,7 +794,8 @@ void NetworkImplDft::set_network_algo_policy( | |||
| if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) { | |||
| dst_strategy = dst_strategy | S::OPTIMIZED; | |||
| } | |||
| m_execution_policy = dst_strategy; | |||
| if (static_cast<uint32_t>(dst_strategy) != 0) | |||
| m_execution_policy = dst_strategy; | |||
| auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config; | |||
| fast_run_config.binary_equal_between_batch = binary_equal_between_batch; | |||
| @@ -808,12 +810,10 @@ void NetworkImplDft::set_network_algo_policy( | |||
| } | |||
| void NetworkImplDft::modify_exection_policy() { | |||
| mgb::SymbolVarArray vars; | |||
| for (auto i : m_output_spec) { | |||
| vars.push_back(i.first); | |||
| } | |||
| if (static_cast<uint32_t>(m_execution_policy) != 0) | |||
| auto& vars = m_load_result.output_var_list; | |||
| if (static_cast<uint32_t>(m_execution_policy) != 0) { | |||
| mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy); | |||
| } | |||
| } | |||
| //! set opr algorithm selection strategy in the network | |||
| @@ -289,21 +289,21 @@ namespace intl { | |||
| template <typename Opr> | |||
| struct OprFormatModifier; | |||
| #define INST(_Opr) \ | |||
| template <> \ | |||
| struct OprFormatModifier<_Opr> { \ | |||
| using OprFormat = typename _Opr::Param::Format; \ | |||
| static VarNode* make( \ | |||
| OprFormat opr_format, const VarNodeArray& i, \ | |||
| const cg::OperatorNodeBase* opr_) { \ | |||
| MIDOUT_B(_Opr) \ | |||
| auto&& opr = opr_->cast_final_safe<_Opr>(); \ | |||
| auto param = opr.param(); \ | |||
| param.format = opr_format; \ | |||
| return OprWithPolicyMaker<_Opr>::make( \ | |||
| i, param, opr.execution_policy(), opr.config()); \ | |||
| MIDOUT_E \ | |||
| } \ | |||
| #define INST(_Opr) \ | |||
| template <> \ | |||
| struct OprFormatModifier<_Opr> { \ | |||
| using OprFormat = typename _Opr::Param::Format; \ | |||
| static VarNode* make( \ | |||
| OprFormat opr_format, const VarNodeArray& i, \ | |||
| const cg::OperatorNodeBase* opr_) { \ | |||
| MIDOUT_B(_Opr) \ | |||
| auto&& opr = opr_->cast_final_safe<_Opr>(); \ | |||
| auto param = opr.param(); \ | |||
| param.format = opr_format; \ | |||
| return OprWithPolicyMaker<_Opr>::make( \ | |||
| i, param, opr.execution_policy_transient(), opr.config()); \ | |||
| MIDOUT_E \ | |||
| } \ | |||
| }; | |||
| INST(Convolution); | |||
| INST(ConvBiasForward); | |||