GitOrigin-RevId: e499f3ebf8
tags/v1.9.0
@@ -373,6 +373,14 @@ public:
    //! dump network after global layout transform optimization
    static void dump_layout_transform_model(
            std::shared_ptr<Network> network, std::string optimized_model_path);

    //! get the model IO information from the model file, before the model is loaded
    static NetworkIO get_model_io_info(
            const std::string& model_path, const Config& config = {});

    //! get the model IO information from a model in memory, before the model is loaded
    static NetworkIO get_model_io_info(
            const void* model_mem, size_t size, const Config& config = {});
};
}  // namespace lite
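A minimal usage sketch for the new overloads (the include path and model file are illustrative, not part of this change; error handling is omitted):

#include <cstdio>
#include <string>
#include "lite/network.h"  // illustrative include path

void print_model_io(const std::string& model_path) {
    // query IO metadata without constructing a Network
    lite::NetworkIO ios = lite::Runtime::get_model_io_info(model_path);
    for (auto&& in : ios.inputs)
        std::printf("input: %s\n", in.name.c_str());
    for (auto&& out : ios.outputs)
        std::printf("output: %s\n", out.name.c_str());
}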
@@ -588,6 +588,28 @@ LITE_API int LITE_enable_global_layout_transform(LiteNetwork network);
LITE_API int LITE_dump_layout_transform_model(
        LiteNetwork network, const char* dump_file_path);

/*! get the model IO information from the model file, before the model is loaded
 * \param[in] model_path The model file path
 * \param[in] config The model config for loading
 * \param[out] ios The model IO information
 * \return int, a nonzero return value means an error happened; the error
 * message can be retrieved with LITE_get_last_error
 */
LITE_API int LITE_get_model_io_info_by_path(
        const char* model_path, const LiteConfig config, LiteNetworkIO* ios);

/*! get the model IO information from a model in memory, before the model is loaded
 * \param[in] model_mem The model memory ptr
 * \param[in] size The length of the model memory in bytes
 * \param[in] config The model config for loading
 * \param[out] ios The model IO information
 * \return int, a nonzero return value means an error happened; the error
 * message can be retrieved with LITE_get_last_error
 */
LITE_API int LITE_get_model_io_info_by_memory(
        const void* model_mem, size_t size, const LiteConfig config,
        LiteNetworkIO* ios);
#ifdef __cplusplus
}
#endif
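A hedged sketch of the same query through the C API; `default_config()` mirrors the helper used in the capi tests below and stands in for a caller-supplied LiteConfig, and the include path is an assumption:

#include <cstdio>
#include "lite-c/network_c.h"  // illustrative include path

int print_model_inputs() {
    LiteNetworkIO ios;
    int err = LITE_get_model_io_info_by_path(
            "./shufflenet.mge", *default_config(), &ios);
    if (err != 0) {
        // a nonzero return means failure; the message is available via
        // LITE_get_last_error
        return err;
    }
    for (size_t i = 0; i < ios.input_size; ++i)
        std::printf("input: %s\n", ios.inputs[i].name);
    return 0;
}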
@@ -167,6 +167,31 @@ lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) {
    return network_io;
}
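//! InnerIO owns the IO name strings, so the borrowed `const char*` name
//! pointers stored in its LiteIO entries stay valid while the InnerIO lives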
struct InnerIO {
    std::vector<std::string> names;
    std::vector<LiteIO> inputs;
    std::vector<LiteIO> outputs;
};

InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
    InnerIO inner_io;
    for (size_t i = 0; i < network_io.inputs.size(); i++) {
        lite::IO io = network_io.inputs[i];
        inner_io.names.push_back(io.name);
        inner_io.inputs.push_back(
                {inner_io.names.back().c_str(), io.is_host, io.io_type,
                 convert_to_clayout(io.config_layout)});
    }
    for (size_t i = 0; i < network_io.outputs.size(); i++) {
        lite::IO io = network_io.outputs[i];
        inner_io.names.push_back(io.name);
        inner_io.outputs.push_back(
                {inner_io.names.back().c_str(), io.is_host, io.io_type,
                 convert_to_clayout(io.config_layout)});
    }
    return inner_io;
}
int LITE_make_default_network(LiteNetwork* network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network passed to LITE api is null");

@@ -665,4 +690,59 @@ int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_
    lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
    LITE_CAPI_END();
}
namespace {
static LITE_MUTEX mtx_io;
static std::unordered_map<const void*, InnerIO>& get_global_io_holder() {
    static std::unordered_map<const void*, InnerIO> global_holder;
    return global_holder;
}

int write_ios_from_cpp_io(
        const lite::NetworkIO& cpp_io, LiteNetworkIO* ios, const void* key) {
    LITE_CAPI_BEGIN();
    LITE_LOCK_GUARD(mtx_io);
    get_global_io_holder()[key] = convert_to_inner_io(cpp_io);
    auto&& inner_io = get_global_io_holder()[key];
    ios->input_size = inner_io.inputs.size();
    ios->output_size = inner_io.outputs.size();
    ios->inputs = inner_io.inputs.data();
    ios->outputs = inner_io.outputs.data();
    //! re-point the names at the strings in their final home inside the
    //! holder; the assignment above may have moved or copied them
    for (size_t i = 0; i < ios->input_size; i++) {
        ios->inputs[i].name = inner_io.names[i].c_str();
    }
    for (size_t i = 0; i < ios->output_size; i++) {
        ios->outputs[i].name = inner_io.names[ios->input_size + i].c_str();
    }
    LITE_CAPI_END();
}
}  // namespace
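The pointers published through `LiteNetworkIO` borrow from this global holder, keyed by the model path or memory pointer, so they remain valid after the call returns; querying the same model again simply overwrites the previous entry.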
int LITE_get_model_io_info_by_path(
        const char* model_path, const LiteConfig config, LiteNetworkIO* ios) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(model_path, "The model_path passed to LITE api is null");
    auto&& cpp_ios = lite::Runtime::get_model_io_info(
            std::string{model_path}, convert_to_lite_config(config));
    return write_ios_from_cpp_io(
            cpp_ios, ios, reinterpret_cast<const void*>(model_path));
    LITE_CAPI_END();
}

int LITE_get_model_io_info_by_memory(
        const void* model_mem, size_t size, const LiteConfig config,
        LiteNetworkIO* ios) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(model_mem, "The model_mem passed to LITE api is null");
    auto&& cpp_ios = lite::Runtime::get_model_io_info(
            model_mem, size, convert_to_lite_config(config));
    return write_ios_from_cpp_io(
            cpp_ios, ios, reinterpret_cast<const void*>(model_mem));
    LITE_CAPI_END();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -364,6 +364,14 @@ class _NetworkAPI(_LiteCObjBase):
        ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
        ("LITE_enable_global_layout_transform", [_Cnetwork]),
        ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
        (
            "LITE_get_model_io_info_by_path",
            [c_char_p, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        (
            "LITE_get_model_io_info_by_memory",
            [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
    ]
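Note that `c_char_p` stands in for the C `const void*` model-memory parameter here, so a Python `bytes` object can be passed directly.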
@@ -619,3 +627,27 @@ class LiteNetwork(object):
    def dump_layout_transform_model(self, model_file):
        c_file = model_file.encode("utf-8")
        self._api.LITE_dump_layout_transform_model(self._network, c_file)


def get_model_io_info(model_path, config=None):
    """
    Get the model IO information before creating the Network; this
    information can be used to configure the Network.
    """
    api = _NetworkAPI()._lib
    c_path = c_char_p(model_path.encode("utf-8"))
    ios = _LiteNetworkIO()
    if config is None:
        config = LiteConfig()
    api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))

    ret_ios = LiteNetworkIO()
    for i in range(ios.input_size):
        ret_ios.add_input(ios.inputs[i])
    for i in range(ios.output_size):
        ret_ios.add_output(ios.outputs[i])
    return ret_ios
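For example, `get_model_io_info("shufflenet.mge")` returns a `LiteNetworkIO` whose `inputs` and `outputs` lists carry the tensor names and config layouts, as exercised by `test_get_model_io_ahead` below.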
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import os

import numpy as np
@@ -200,3 +201,20 @@ def test_tensor_collect_batch_device_numpy():
    for i in range(4):
        for j in range(48):
            assert data[i][j // 8][j % 8] == i + 1


def test_get_model_io_ahead():
    source_dir = os.getenv("LITE_TEST_RESOURCE")
    model_path = os.path.join(source_dir, "shufflenet.mge")
    ios = get_model_io_info(model_path)
    assert len(ios.inputs) == 1
    assert ios.inputs[0].name == "data"
    assert ios.inputs[0].config_layout.shapes[1] == 3
    assert ios.inputs[0].config_layout.shapes[2] == 224
    assert ios.inputs[0].config_layout.shapes[3] == 224
    assert len(ios.outputs) == 1
    assert ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
    assert ios.outputs[0].config_layout.shapes[0] == 1
    assert ios.outputs[0].config_layout.shapes[1] == 1000
@@ -34,7 +34,7 @@ ADD_STATEMENT(NetworkImplDft, Dft);
}  // namespace

// if it can't find the function, ignore
-template <typename tensor_type, typename ret_type, typename... Args>
+template <typename type, typename ret_type, typename... Args>
ret_type try_call_func(std::string func_name, Args... args) {
    mark_used_variable(func_name);
    mark_used_variable(args...);

@@ -42,10 +42,10 @@ ret_type try_call_func(std::string func_name, Args... args) {
}

// if it can't find the function, throw an error
-template <typename tensor_type, typename ret_type, typename... Args>
+template <typename type, typename ret_type, typename... Args>
ret_type call_func(std::string func_name, Args... args) {
    mark_used_variable(args...);
-    auto backend_name = class_type_name<tensor_type>()();
+    auto backend_name = class_type_name<type>()();
    auto msg_info = func_name + " is not available in " + backend_name + " backend.";
    LITE_THROW(msg_info.c_str());
}
@@ -206,6 +206,26 @@ inline void call_func<NetworkImplDft, void>(
        THROW_FUNC_ERROR(func_name);
    }
}
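The two full specializations below route the string-dispatched `call_func` calls to the default (Dft) backend: a matching `func_name` forwards to `get_model_io_info_dft`, and anything else falls through to `THROW_FUNC_ERROR`.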
template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
        std::string func_name, std::string model_path, Config config) {
    if (func_name == "get_model_io_info") {
        return get_model_io_info_dft(model_path, config);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
        std::string func_name, const void* model_mem, size_t size, Config config) {
    if (func_name == "get_model_io_info") {
        return get_model_io_info_dft(model_mem, size, config);
    } else {
        THROW_FUNC_ERROR(func_name);
    }
}

#undef THROW_FUNC_ERROR
}  // namespace lite
@@ -929,5 +929,75 @@ void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_pat
                    "enable_global_layout_transform before"));
    }
}

NetworkIO lite::get_model_io_info_dft(
        const std::string& model_path, const Config& config) {
    FILE* fin = fopen(model_path.c_str(), "rb");
    LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
    fseek(fin, 0, SEEK_END);
    size_t size = ftell(fin);
    fseek(fin, 0, SEEK_SET);
    void* ptr = malloc(size);
    std::shared_ptr<void> buf{ptr, ::free};
    auto nr = fread(buf.get(), 1, size, fin);
    LITE_ASSERT(nr == size);
    fclose(fin);
    return get_model_io_info_dft(ptr, size, config);
}
NetworkIO lite::get_model_io_info_dft(
        const void* model_mem, size_t size, const Config& config) {
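    //! wrap the caller's buffer in a non-owning shared_ptr (no-op deleter);
    //! the caller must keep model_mem alive for the duration of the call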
    std::shared_ptr<void> model{const_cast<void*>(model_mem), [](void*) {}};
    auto input_file = mgb::serialization::InputFile::make_mem_proxy(model, size, false);
    auto format =
            mgb::serialization::GraphLoader::identify_graph_dump_format(*input_file);
    if (!format.valid()) {
        LITE_THROW("invalid model format");
    }
    auto loader =
            mgb::serialization::GraphLoader::make(std::move(input_file), format.val());

    mgb::serialization::GraphLoadConfig load_config;
    load_config.comp_graph = mgb::ComputingGraph::make();
    if (config.has_compression) {
        load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }
    auto compnode_locator = to_compnode_locator(config.device_type);
    load_config.comp_node_mapper = [=](mgb::CompNode::Locator& loc) {
        if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
            loc.type = compnode_locator.type;
        }
        loc.device = compnode_locator.device;
    };
    auto load_result = loader->load(load_config, true);

    NetworkIO IOs;
    for (auto&& in_tensor_iter : load_result.tensor_map) {
        IO in_io;
        in_io.name = in_tensor_iter.first;
        in_io.config_layout = to_lite_layout(in_tensor_iter.second->layout());
        IOs.inputs.push_back(in_io);
    }
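    //! output shapes are only recoverable when they can be statically
    //! inferred (CONST or RT_STATIC); otherwise only the dtype is filled in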
    auto infer_shape = [=](mgb::cg::SymbolVar var) -> const megdnn::TensorShape* {
        auto&& static_infer_mgr = load_config.comp_graph->static_infer_manager();
        using InferType = mgb::cg::static_infer::InferType;
        if (static_infer_mgr.get_infer_type(var.node()).shape &
            (InferType::CONST | InferType::RT_STATIC)) {
            return static_infer_mgr.infer_shape_fallible(var.node());
        } else {
            return nullptr;
        }
    };

    for (auto&& out : load_result.output_var_list) {
        IO out_io;
        out_io.name = out.node()->name();
        if (auto shape = infer_shape(out)) {
            out_io.config_layout = to_lite_layout(TensorLayout{*shape, out.dtype()});
        } else {
            out_io.config_layout = to_lite_layout(TensorLayout{{}, out.dtype()});
        }
        IOs.outputs.push_back(out_io);
    }
    return IOs;
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -262,6 +262,13 @@ private:
#endif
    std::unique_ptr<mgb::OprIODumpBase> m_iodump;
};

//! get the model IO information from the model file, before it is loaded by a
//! Network
NetworkIO get_model_io_info_dft(const std::string& model_path, const Config& config);

//! get the model IO information from model memory and size, before it is
//! loaded by a Network
NetworkIO get_model_io_info_dft(
        const void* model_mem, size_t size, const Config& config);

}  // namespace lite
@@ -534,4 +534,26 @@ void Runtime::dump_layout_transform_model(
    LITE_THROW("dump_layout_transform_model is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

NetworkIO Runtime::get_model_io_info(
        const std::string& model_path, const Config& config) {
    LITE_ERROR_HANDLER_BEGIN
    if (config.backend == LiteBackend::LITE_DEFAULT) {
        return call_func<NetworkImplDft, NetworkIO>(
                "get_model_io_info", model_path, config);
    }
    LITE_THROW("get_model_io_info is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

NetworkIO Runtime::get_model_io_info(
        const void* model_mem, size_t size, const Config& config) {
    LITE_ERROR_HANDLER_BEGIN
    if (config.backend == LiteBackend::LITE_DEFAULT) {
        return call_func<NetworkImplDft, NetworkIO>(
                "get_model_io_info", model_mem, size, config);
    }
    LITE_THROW("get_model_io_info is not available in the backend.");
    LITE_ERROR_HANDLER_END
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -106,6 +106,54 @@ TEST(TestNetWork, GetAllName) {
    ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
}

TEST(TestNetWork, GetAllIoInfoAhead) {
    Config config;
    std::string model_path = "./shufflenet.mge";
    auto ios = Runtime::get_model_io_info(model_path);

    FILE* fin = fopen(model_path.c_str(), "rb");
    ASSERT_TRUE(fin);
    fseek(fin, 0, SEEK_END);
    size_t size = ftell(fin);
    fseek(fin, 0, SEEK_SET);
    void* ptr = malloc(size);
    std::shared_ptr<void> buf{ptr, ::free};
    auto nr = fread(buf.get(), 1, size, fin);
    LITE_ASSERT(nr == size);
    fclose(fin);
    auto ios_mem = Runtime::get_model_io_info(ptr, size);

    ASSERT_EQ(ios.inputs.size(), ios_mem.inputs.size());
    ASSERT_EQ(ios.inputs.size(), 1);
    ASSERT_EQ(ios.outputs.size(), ios_mem.outputs.size());
    ASSERT_EQ(ios.outputs.size(), 1);

    ASSERT_TRUE(ios.inputs[0].name == "data");
    ASSERT_TRUE(ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
    ASSERT_TRUE(ios_mem.inputs[0].name == "data");
    ASSERT_TRUE(
            ios_mem.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");

    ASSERT_EQ(ios.inputs[0].config_layout.ndim, 4);
    ASSERT_EQ(ios.inputs[0].config_layout.shapes[1], 3);
    ASSERT_EQ(ios.inputs[0].config_layout.shapes[2], 224);
    ASSERT_EQ(ios.outputs[0].config_layout.ndim, 2);
    ASSERT_EQ(ios.outputs[0].config_layout.shapes[0], 1);
    ASSERT_EQ(ios.outputs[0].config_layout.shapes[1], 1000);

    ASSERT_EQ(ios_mem.inputs[0].config_layout.ndim, 4);
    ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[1], 3);
    ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[2], 224);
    ASSERT_EQ(ios_mem.outputs[0].config_layout.ndim, 2);
    ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[0], 1);
    ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[1], 1000);
}

TEST(TestNetWork, LoadFBSModel) {
    Config config;
    std::string model_path = "./ax.mge";
@@ -252,6 +252,55 @@ TEST(TestCapiNetWork, GetAllName) {
    LITE_destroy_network(c_network);
}

TEST(TestCapiNetWork, GetAllNameAhead) {
    std::string model_path = "./shufflenet.mge";
    LiteNetworkIO ios, ios_mem;
    LITE_CAPI_CHECK(LITE_get_model_io_info_by_path(
            model_path.c_str(), *default_config(), &ios));

    FILE* fin = fopen(model_path.c_str(), "rb");
    ASSERT_TRUE(fin);
    fseek(fin, 0, SEEK_END);
    size_t size = ftell(fin);
    fseek(fin, 0, SEEK_SET);
    void* ptr = malloc(size);
    std::shared_ptr<void> buf{ptr, ::free};
    auto nr = fread(buf.get(), 1, size, fin);
    LITE_ASSERT(nr == size);
    fclose(fin);
    LITE_CAPI_CHECK(
            LITE_get_model_io_info_by_memory(ptr, size, *default_config(), &ios_mem));

    ASSERT_EQ(ios.input_size, 1);
    ASSERT_EQ(ios.output_size, 1);
    ASSERT_EQ(ios_mem.input_size, 1);
    ASSERT_EQ(ios_mem.output_size, 1);

    ASSERT_TRUE(std::string(ios.inputs->name) == "data");
    ASSERT_TRUE(ios.inputs->config_layout.ndim == 4);
    ASSERT_TRUE(ios.inputs->config_layout.shapes[1] == 3);
    ASSERT_TRUE(ios.inputs->config_layout.shapes[2] == 224);
    ASSERT_TRUE(ios.inputs->config_layout.shapes[3] == 224);
    ASSERT_TRUE(
            std::string(ios.outputs->name) ==
            "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
    ASSERT_TRUE(ios.outputs->config_layout.ndim == 2);
    ASSERT_TRUE(ios.outputs->config_layout.shapes[0] == 1);
    ASSERT_TRUE(ios.outputs->config_layout.shapes[1] == 1000);

    ASSERT_TRUE(std::string(ios_mem.inputs->name) == "data");
    ASSERT_TRUE(ios_mem.inputs->config_layout.ndim == 4);
    ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[1] == 3);
    ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[2] == 224);
    ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[3] == 224);
    ASSERT_TRUE(
            std::string(ios_mem.outputs->name) ==
            "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
    ASSERT_TRUE(ios_mem.outputs->config_layout.ndim == 2);
    ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[0] == 1);
    ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[1] == 1000);
}

#if LITE_BUILD_WITH_RKNPU
static int GetTop(