GitOrigin-RevId: f159f49208
tags/v1.8.0
| @@ -97,7 +97,7 @@ struct LITE_API Options { | |||||
| bool no_profiling_on_shape_change = false; | bool no_profiling_on_shape_change = false; | ||||
| uint8_t jit_level = 0; | uint8_t jit_level = 0; | ||||
| uint8_t comp_node_seq_record_level = 0; | uint8_t comp_node_seq_record_level = 0; | ||||
| uint8_t graph_opt_level = 0; | |||||
| uint8_t graph_opt_level = 2; | |||||
| uint16_t async_exec_level = 1; | uint16_t async_exec_level = 1; | ||||
| //! layout transform options | //! layout transform options | ||||
| @@ -368,7 +368,6 @@ public: | |||||
| const std::shared_ptr<Network> src_network); | const std::shared_ptr<Network> src_network); | ||||
| //! set global layout transform optimization for network | //! set global layout transform optimization for network | ||||
| static void enable_global_layout_transform(std::shared_ptr<Network> network); | static void enable_global_layout_transform(std::shared_ptr<Network> network); | ||||
| //! dump network after global layout transform optimization | //! dump network after global layout transform optimization | ||||
| @@ -362,6 +362,8 @@ class _NetworkAPI(_LiteCObjBase): | |||||
| ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]), | ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]), | ||||
| ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]), | ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]), | ||||
| ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | ||||
| ("LITE_enable_global_layout_transform", [_Cnetwork]), | |||||
| ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]), | |||||
| ] | ] | ||||
| @@ -610,3 +612,10 @@ class LiteNetwork(object): | |||||
| def get_static_memory_alloc_info(self, log_dir="logs/test"): | def get_static_memory_alloc_info(self, log_dir="logs/test"): | ||||
| c_log_dir = log_dir.encode("utf-8") | c_log_dir = log_dir.encode("utf-8") | ||||
| self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir) | self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir) | ||||
| def enable_global_layout_transform(self): | |||||
| self._api.LITE_enable_global_layout_transform(self._network) | |||||
| def dump_layout_transform_model(self, model_file): | |||||
| c_file = model_file.encode("utf-8") | |||||
| self._api.LITE_dump_layout_transform_model(self._network, c_file) | |||||
| @@ -451,3 +451,20 @@ class TestNetwork(TestShuffleNet): | |||||
| network.wait() | network.wait() | ||||
| self.check_correct(out_array) | self.check_correct(out_array) | ||||
| def test_enable_global_layout_transform(self): | |||||
| network = LiteNetwork() | |||||
| network.enable_global_layout_transform() | |||||
| network.load(self.model_path) | |||||
| self.do_forward(network) | |||||
| def test_dump_layout_transform_model(self): | |||||
| network = LiteNetwork() | |||||
| network.enable_global_layout_transform() | |||||
| network.load(self.model_path) | |||||
| network.dump_layout_transform_model("./model_afer_layoutTrans.mgb") | |||||
| self.do_forward(network) | |||||
| fi = open("./model_afer_layoutTrans.mgb", "r") | |||||
| fi.close() | |||||
| os.remove("./model_afer_layoutTrans.mgb") | |||||
| @@ -272,3 +272,22 @@ class TestNetwork(TestShuffleNetCuda): | |||||
| | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE | | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE | ||||
| ) | ) | ||||
| self.do_forward(network) | self.do_forward(network) | ||||
| @require_cuda() | |||||
| def test_enable_global_layout_transform(self): | |||||
| network = LiteNetwork() | |||||
| network.enable_global_layout_transform() | |||||
| network.load(self.model_path) | |||||
| self.do_forward(network) | |||||
| @require_cuda() | |||||
| def test_dump_layout_transform_model(self): | |||||
| network = LiteNetwork() | |||||
| network.enable_global_layout_transform() | |||||
| network.load(self.model_path) | |||||
| network.dump_layout_transform_model("./model_afer_layoutTrans.mgb") | |||||
| self.do_forward(network) | |||||
| fi = open("./model_afer_layoutTrans.mgb", "r") | |||||
| fi.close() | |||||
| os.remove("./model_afer_layoutTrans.mgb") | |||||
| @@ -406,7 +406,7 @@ void NetworkImplDft::load_model( | |||||
| use_tensorrt(); | use_tensorrt(); | ||||
| } | } | ||||
| m_load_result = m_loader->load(m_load_config, false); | |||||
| m_load_result = m_loader->load(m_load_config, true); | |||||
| global_layout_transform(); | global_layout_transform(); | ||||
| @@ -910,7 +910,6 @@ TEST(TestNetWork, LoadPackedModel) { | |||||
| } | } | ||||
| TEST(TestNetWork, GlabalLayoutTransform) { | TEST(TestNetWork, GlabalLayoutTransform) { | ||||
| // set_log_level(LiteLogLevel::DEBUG); | |||||
| auto tensor = get_input_data("./input_data.npy"); | auto tensor = get_input_data("./input_data.npy"); | ||||
| std::string model_path = "./shufflenet.mge"; | std::string model_path = "./shufflenet.mge"; | ||||
| std::string input_name = "data"; | std::string input_name = "data"; | ||||
| @@ -931,6 +930,7 @@ TEST(TestNetWork, GlabalLayoutTransform) { | |||||
| network->forward(); | network->forward(); | ||||
| network->wait(); | network->wait(); | ||||
| ASSERT_TRUE(fopen(dump_model_name.c_str(), "r")); | ASSERT_TRUE(fopen(dump_model_name.c_str(), "r")); | ||||
| remove(dump_model_name.c_str()); | |||||
| } | } | ||||
| TEST(TestNetWork, GetDeviceType) { | TEST(TestNetWork, GetDeviceType) { | ||||