GitOrigin-RevId: cd155a1fcf
tags/v1.10.0
| @@ -117,6 +117,17 @@ struct LITE_API Config { | |||||
| Options options = {}; | Options options = {}; | ||||
| }; | }; | ||||
| /*! | |||||
| * \brief Extra Configuration for a network | |||||
| * | |||||
| * \param disable_configure_by_model_info disable the configuration dumped with model, | |||||
| * if set true, all configuration in the model will not apply, users should configure | |||||
| * the network. | |||||
| */ | |||||
| struct LITE_API ExtraConfig { | |||||
| bool disable_configure_by_model_info = false; | |||||
| }; | |||||
| /*! | /*! | ||||
| * \brief config the network input and output item | * \brief config the network input and output item | ||||
| * | * | ||||
| @@ -275,6 +286,12 @@ public: | |||||
| //! get static peak memory info showed by Graph visualization | //! get static peak memory info showed by Graph visualization | ||||
| void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const; | void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const; | ||||
| /** @brief the extra configuration | |||||
| * | |||||
| * @param extra_config the extra configuration to set into the network | |||||
| */ | |||||
| void extra_configure(const ExtraConfig& extra_config); | |||||
| public: | public: | ||||
| friend class NetworkHelper; | friend class NetworkHelper; | ||||
| @@ -288,6 +305,7 @@ private: | |||||
| private: | private: | ||||
| bool m_loaded = false; | bool m_loaded = false; | ||||
| Config m_config; | Config m_config; | ||||
| ExtraConfig m_extra_config; | |||||
| NetworkIO m_network_io; | NetworkIO m_network_io; | ||||
| std::unique_ptr<NetworkImplBase> m_impl; | std::unique_ptr<NetworkImplBase> m_impl; | ||||
| std::string m_extra_info; | std::string m_extra_info; | ||||
| @@ -113,6 +113,17 @@ typedef struct LiteConfig { | |||||
| //! get default config | //! get default config | ||||
| LITE_API LiteConfig* default_config(); | LITE_API LiteConfig* default_config(); | ||||
| /*! | |||||
| * \brief Exetra Configuration for a network | |||||
| * | |||||
| * \param disable_configure_by_model_info disable the configuration dumped with model, | |||||
| * if set true, all configuration in the model will not apply, users should configure | |||||
| * the network. | |||||
| */ | |||||
| typedef struct LiteExtraConfig { | |||||
| int disable_configure_by_model_info; | |||||
| } LiteExtraConfig; | |||||
| /*! | /*! | ||||
| * \brief config the network input and output item | * \brief config the network input and output item | ||||
| * | * | ||||
| @@ -599,6 +610,12 @@ LITE_API int LITE_get_model_io_info_by_memory( | |||||
| const void* model_mem, size_t size, const LiteConfig config, | const void* model_mem, size_t size, const LiteConfig config, | ||||
| LiteNetworkIO* ios); | LiteNetworkIO* ios); | ||||
| /** @brief the extra configuration | |||||
| * | |||||
| * @param extra_config the extra configuration to set into the network | |||||
| */ | |||||
| LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -181,6 +181,12 @@ InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) { | |||||
| return innner_io; | return innner_io; | ||||
| } | } | ||||
| lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) { | |||||
| lite::ExtraConfig ret; | |||||
| ret.disable_configure_by_model_info = extra_config.disable_configure_by_model_info; | |||||
| return ret; | |||||
| } | |||||
| int LITE_make_default_network(LiteNetwork* network) { | int LITE_make_default_network(LiteNetwork* network) { | ||||
| LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
| LITE_ASSERT(network, "The network pass to LITE api is null"); | LITE_ASSERT(network, "The network pass to LITE api is null"); | ||||
| @@ -734,4 +740,12 @@ int LITE_get_model_io_info_by_memory( | |||||
| LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
| } | } | ||||
| LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config) { | |||||
| LITE_CAPI_BEGIN(); | |||||
| LITE_ASSERT(network, "The network pass to LITE api is null"); | |||||
| static_cast<lite::Network*>(network)->extra_configure( | |||||
| convert_extra_config(extra_config)); | |||||
| LITE_CAPI_END(); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -134,6 +134,31 @@ class LiteConfig(Structure): | |||||
| return data.__repr__() | return data.__repr__() | ||||
| class LiteExtraConfig(Structure): | |||||
| """ | |||||
| Extra configuration when load and compile the graph | |||||
| disable_configure_by_model_info: disable the configuration dumped with | |||||
| model, if set true, all configuration in the model will not apply, users | |||||
| should configure the network. | |||||
| """ | |||||
| _fields_ = [ | |||||
| ("disable_configure_by_model_info", c_int), | |||||
| ] | |||||
| def __init__(self, disable_model_config=False): | |||||
| self.disable_configure_by_model_info = disable_model_config | |||||
| def __repr__(self): | |||||
| data = { | |||||
| "disable_configure_by_model_info": bool( | |||||
| self.disable_configure_by_model_info | |||||
| ), | |||||
| } | |||||
| return data.__repr__() | |||||
| class LiteIO(Structure): | class LiteIO(Structure): | ||||
| """ | """ | ||||
| config the network input and output item | config the network input and output item | ||||
| @@ -365,6 +390,7 @@ class _NetworkAPI(_LiteCObjBase): | |||||
| "LITE_get_model_io_info_by_memory", | "LITE_get_model_io_info_by_memory", | ||||
| [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)], | [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)], | ||||
| ), | ), | ||||
| ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]), | |||||
| ] | ] | ||||
| @@ -541,6 +567,12 @@ class LiteNetwork(object): | |||||
| ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)] | ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)] | ||||
| return ret_name | return ret_name | ||||
| def extra_configure(self, extra_config): | |||||
| """ | |||||
| Extra Configuration to the network. | |||||
| """ | |||||
| self._api.LITE_extra_configure(self._network, extra_config) | |||||
| def share_weights_with(self, src_network): | def share_weights_with(self, src_network): | ||||
| """ | """ | ||||
| share weights with the loaded network | share weights with the loaded network | ||||
| @@ -112,6 +112,13 @@ class TestNetwork(TestShuffleNet): | |||||
| network.load(model_path) | network.load(model_path) | ||||
| self.do_forward(network) | self.do_forward(network) | ||||
| def test_disable_model_config(self): | |||||
| model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite") | |||||
| network = LiteNetwork() | |||||
| network.extra_configure(LiteExtraConfig(True)) | |||||
| network.load(model_path) | |||||
| self.do_forward(network) | |||||
| def test_pack_cache_to_model(self): | def test_pack_cache_to_model(self): | ||||
| model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite") | model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite") | ||||
| network = LiteNetwork() | network = LiteNetwork() | ||||
| @@ -31,7 +31,6 @@ using namespace mgb; | |||||
| LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft); | LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft); | ||||
| void NetworkImplDft::set_config(const Config& config) { | void NetworkImplDft::set_config(const Config& config) { | ||||
| m_user_config = std::make_unique<Config>(); | |||||
| *m_user_config = config; | *m_user_config = config; | ||||
| m_compnode_locator = to_compnode_locator(m_user_config->device_type); | m_compnode_locator = to_compnode_locator(m_user_config->device_type); | ||||
| m_compnode_locator.device = config.device_id; | m_compnode_locator.device = config.device_id; | ||||
| @@ -428,8 +427,11 @@ void NetworkImplDft::load_model( | |||||
| global_layout_transform(); | global_layout_transform(); | ||||
| //! some optimization option maybe invalid in some case, so here just | |||||
| //! auto determine whether some options will apply. | |||||
| adapt_option_valid(); | adapt_option_valid(); | ||||
| //! find how many compnode the model has, this should call before update_io | |||||
| cross_compnode_model_detect(); | cross_compnode_model_detect(); | ||||
| //! update the IO of the network | //! update the IO of the network | ||||
| @@ -496,7 +498,6 @@ void NetworkImplDft::finish() const { | |||||
| } | } | ||||
| void NetworkImplDft::set_io(const NetworkIO& network_io) { | void NetworkImplDft::set_io(const NetworkIO& network_io) { | ||||
| m_network_io = std::make_unique<NetworkIOInner>(); | |||||
| for (auto&& in : network_io.inputs) { | for (auto&& in : network_io.inputs) { | ||||
| m_network_io->inputs.emplace_back(in); | m_network_io->inputs.emplace_back(in); | ||||
| } | } | ||||
| @@ -29,7 +29,11 @@ class NetworkImplDft final : public Network::NetworkImplBase { | |||||
| LITE_DYN_TYPE_OBJ_FINAL_DECL; | LITE_DYN_TYPE_OBJ_FINAL_DECL; | ||||
| public: | public: | ||||
| NetworkImplDft() { m_load_config.comp_graph = mgb::ComputingGraph::make(); } | |||||
| NetworkImplDft() { | |||||
| m_load_config.comp_graph = mgb::ComputingGraph::make(); | |||||
| m_user_config = std::make_unique<Config>(); | |||||
| m_network_io = std::make_unique<NetworkIOInner>(); | |||||
| } | |||||
| using S = megdnn::param::ExecutionPolicy::Strategy; | using S = megdnn::param::ExecutionPolicy::Strategy; | ||||
| using Var = mgb::cg::SymbolVar; | using Var = mgb::cg::SymbolVar; | ||||
| //! set the config of the network, include: | //! set the config of the network, include: | ||||
| @@ -80,14 +80,17 @@ void Network::prase_model(std::shared_ptr<void> model_data, size_t size) { | |||||
| ModelParser model_parser(model_data, size); | ModelParser model_parser(model_data, size); | ||||
| //! parse the model info | //! parse the model info | ||||
| if (model_parser.parse_model_info( | if (model_parser.parse_model_info( | ||||
| m_config, m_network_io, separate_config_map, m_extra_info)) { | |||||
| m_config, m_network_io, separate_config_map, m_extra_info, | |||||
| !m_extra_config.disable_configure_by_model_info)) { | |||||
| if (m_config.backend == LiteBackend::LITE_DEFAULT && | if (m_config.backend == LiteBackend::LITE_DEFAULT && | ||||
| m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) { | m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) { | ||||
| m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>( | m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>( | ||||
| "parse_model")); | "parse_model")); | ||||
| } | } | ||||
| m_impl->set_config(m_config); | |||||
| m_impl->set_io(m_network_io); | |||||
| if (!m_extra_config.disable_configure_by_model_info) { | |||||
| m_impl->set_config(m_config); | |||||
| m_impl->set_io(m_network_io); | |||||
| } | |||||
| } | } | ||||
| //! decryption the model | //! decryption the model | ||||
| size_t model_length; | size_t model_length; | ||||
| @@ -290,6 +293,18 @@ void Network::get_static_memory_alloc_info(const std::string& log_dir) const { | |||||
| LITE_ERROR_HANDLER_END | LITE_ERROR_HANDLER_END | ||||
| } | } | ||||
| void Network::extra_configure(const ExtraConfig& extra_config) { | |||||
| LITE_ERROR_HANDLER_BEGIN | |||||
| if (!extra_config.disable_configure_by_model_info) { | |||||
| LITE_ASSERT( | |||||
| !m_loaded, | |||||
| "disable_configure_by_model_info should be configured before model " | |||||
| "loaded."); | |||||
| } | |||||
| m_extra_config = extra_config; | |||||
| LITE_ERROR_HANDLER_END | |||||
| } | |||||
| /*********************** MGE special network function ***************/ | /*********************** MGE special network function ***************/ | ||||
| void Runtime::set_cpu_threads_number( | void Runtime::set_cpu_threads_number( | ||||
| @@ -43,7 +43,7 @@ void ModelParser::parse_header() { | |||||
| bool ModelParser::parse_model_info( | bool ModelParser::parse_model_info( | ||||
| Config& network_config, NetworkIO& network_io, | Config& network_config, NetworkIO& network_io, | ||||
| std::unordered_map<std::string, LiteAny>& isolated_config_map, | std::unordered_map<std::string, LiteAny>& isolated_config_map, | ||||
| std::string& extra_info) const { | |||||
| std::string& extra_info, bool configure_valid) const { | |||||
| //! no model info, no parse, direct return | //! no model info, no parse, direct return | ||||
| if (m_is_bare_model || !m_info) { | if (m_is_bare_model || !m_info) { | ||||
| return false; | return false; | ||||
| @@ -78,7 +78,7 @@ bool ModelParser::parse_model_info( | |||||
| } | } | ||||
| } | } | ||||
| //! parse ModelInfo::algo_policy | //! parse ModelInfo::algo_policy | ||||
| if (m_info->algo_policy()) { | |||||
| if (m_info->algo_policy() && configure_valid) { | |||||
| size_t cache_length = m_info->algo_policy()->size(); | size_t cache_length = m_info->algo_policy()->size(); | ||||
| const uint8_t* cache = m_info->algo_policy()->Data(); | const uint8_t* cache = m_info->algo_policy()->Data(); | ||||
| if (m_info_cache_parse_func_name == "LITE_parse_cache") { | if (m_info_cache_parse_func_name == "LITE_parse_cache") { | ||||
| @@ -93,6 +93,10 @@ bool ModelParser::parse_model_info( | |||||
| } else { | } else { | ||||
| LITE_THROW("opencl binary cache is not given"); | LITE_THROW("opencl binary cache is not given"); | ||||
| } | } | ||||
| } else { | |||||
| LITE_THROW(ssprintf( | |||||
| "model cache parse function of %s is not defined.", | |||||
| m_info_cache_parse_func_name.c_str())); | |||||
| } | } | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -25,7 +25,7 @@ public: | |||||
| bool parse_model_info( | bool parse_model_info( | ||||
| Config& network_config, NetworkIO& network_io, | Config& network_config, NetworkIO& network_io, | ||||
| std::unordered_map<std::string, LiteAny>& isolated_config_map, | std::unordered_map<std::string, LiteAny>& isolated_config_map, | ||||
| std::string& extra_info) const; | |||||
| std::string& extra_info, bool configure_valid) const; | |||||
| //! parse the model and decrypt the model | //! parse the model and decrypt the model | ||||
| std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const; | std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const; | ||||
| @@ -7,6 +7,8 @@ | |||||
| #include "lite/global.h" | #include "lite/global.h" | ||||
| #include "megbrain/tensor.h" | #include "megbrain/tensor.h" | ||||
| #include "megbrain/utils/infile_persistent_cache.h" | |||||
| #include "megbrain/utils/persistent_cache.h" | |||||
| #include "test_common.h" | #include "test_common.h" | ||||
| #include <string.h> | #include <string.h> | ||||
| @@ -173,6 +175,29 @@ TEST(TestNetWorkOptions, test_cache) { | |||||
| compare_lite_tensor<float>(output_tensor, result_mgb); | compare_lite_tensor<float>(output_tensor, result_mgb); | ||||
| } | } | ||||
| TEST(TestNetWorkOptions, DisableModelInfo) { | |||||
| //! clear the cache set by other test | |||||
| mgb::PersistentCache::inst().set_impl( | |||||
| std::make_shared<mgb::InMemoryPersistentCache>()); | |||||
| Config config; | |||||
| auto tensor = get_input_data("./input_data.npy"); | |||||
| std::string model_path = "./test_pack_cache_to_model.lite"; | |||||
| std::string model_path2 = "./test_pack_cache_to_model.lite"; | |||||
| std::string input_name = "data"; | |||||
| std::shared_ptr<Network> network = std::make_shared<Network>(config); | |||||
| network->extra_configure({true}); | |||||
| Runtime::set_cpu_inplace_mode(network); | |||||
| network->load_model(model_path); | |||||
| //! the fast-run cache will not configure, so it is not support dump | |||||
| ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), false); | |||||
| ASSERT_EQ(Runtime::is_cpu_inplace_mode(network), true); | |||||
| std::shared_ptr<Network> network2 = std::make_shared<Network>(config); | |||||
| network2->load_model(model_path2); | |||||
| //! the fast-run cache is configured by the model information | |||||
| ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), true); | |||||
| } | |||||
| TEST(TestNetWorkOptions, FastRunIgnorBatch) { | TEST(TestNetWorkOptions, FastRunIgnorBatch) { | ||||
| Config config; | Config config; | ||||
| auto tensor = get_input_data("./input_data.npy"); | auto tensor = get_input_data("./input_data.npy"); | ||||