GitOrigin-RevId: cd155a1fcf
tags/v1.10.0
@@ -117,6 +117,17 @@ struct LITE_API Config {
     Options options = {};
 };
 
+/*!
+ * \brief Extra Configuration for a network
+ *
+ * \param disable_configure_by_model_info disable the configuration dumped
+ * with the model; if set to true, none of the configuration in the model
+ * will be applied, and users should configure the network themselves.
+ */
+struct LITE_API ExtraConfig {
+    bool disable_configure_by_model_info = false;
+};
+
 /*!
  * \brief config the network input and output item
  *
@@ -275,6 +286,12 @@ public:
     //! get static peak memory info showed by Graph visualization
     void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const;
 
+    /** @brief set the extra configuration into the network
+     *
+     * @param extra_config the extra configuration to apply
+     */
+    void extra_configure(const ExtraConfig& extra_config);
+
 public:
     friend class NetworkHelper;
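A minimal C++ usage sketch of the new interface (not part of the diff; the header path and model path are assumptions), mirroring the test added further down — the extra configuration must be applied before the model is loaded:

    #include <memory>
    #include <string>

    #include "lite/network.h"

    void load_without_model_config(const std::string& model_path) {
        lite::Config config;
        auto network = std::make_shared<lite::Network>(config);

        lite::ExtraConfig extra_config;
        extra_config.disable_configure_by_model_info = true;
        //! must run before load_model; afterwards the LITE_ASSERT in
        //! Network::extra_configure fires
        network->extra_configure(extra_config);

        network->load_model(model_path);
    }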
@@ -288,6 +305,7 @@ private:
 private:
     bool m_loaded = false;
     Config m_config;
+    ExtraConfig m_extra_config;
     NetworkIO m_network_io;
     std::unique_ptr<NetworkImplBase> m_impl;
     std::string m_extra_info;
@@ -113,6 +113,17 @@ typedef struct LiteConfig {
 //! get default config
 LITE_API LiteConfig* default_config();
 
+/*!
+ * \brief Extra Configuration for a network
+ *
+ * \param disable_configure_by_model_info disable the configuration dumped
+ * with the model; if set to true, none of the configuration in the model
+ * will be applied, and users should configure the network themselves.
+ */
+typedef struct LiteExtraConfig {
+    int disable_configure_by_model_info;
+} LiteExtraConfig;
+
 /*!
  * \brief config the network input and output item
  *
@@ -599,6 +610,12 @@ LITE_API int LITE_get_model_io_info_by_memory(
         const void* model_mem, size_t size, const LiteConfig config,
         LiteNetworkIO* ios);
 
+/** @brief set the extra configuration into the network
+ *
+ * @param extra_config the extra configuration to apply
+ */
+LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config);
+
 #ifdef __cplusplus
 }
 #endif
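The same flow through the C API, as a hedged sketch — LITE_make_default_network appears in this diff, while LITE_load_model_from_path and the header path are assumptions about the existing lite-c surface:

    #include "lite-c/network_c.h"

    int load_without_model_config_c(const char* model_path) {
        LiteNetwork network;
        if (LITE_make_default_network(&network) != 0)
            return -1;

        LiteExtraConfig extra = {1};  /* disable_configure_by_model_info = true */
        if (LITE_extra_configure(network, extra) != 0)
            return -1;

        return LITE_load_model_from_path(network, model_path);
    }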
@@ -181,6 +181,12 @@ InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
     return innner_io;
 }
 
+lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) {
+    lite::ExtraConfig ret;
+    ret.disable_configure_by_model_info = extra_config.disable_configure_by_model_info;
+    return ret;
+}
+
 int LITE_make_default_network(LiteNetwork* network) {
     LITE_CAPI_BEGIN();
     LITE_ASSERT(network, "The network pass to LITE api is null");
@@ -734,4 +740,12 @@ int LITE_get_model_io_info_by_memory(
     LITE_CAPI_END();
 }
 
+LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config) {
+    LITE_CAPI_BEGIN();
+    LITE_ASSERT(network, "The network pass to LITE api is null");
+    static_cast<lite::Network*>(network)->extra_configure(
+            convert_extra_config(extra_config));
+    LITE_CAPI_END();
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -134,6 +134,31 @@ class LiteConfig(Structure):
         return data.__repr__()
 
 
+class LiteExtraConfig(Structure):
+    """
+    Extra configuration used when loading and compiling the graph.
+
+    disable_configure_by_model_info: disable the configuration dumped with
+    the model; if set to True, none of the configuration in the model will
+    be applied, and users should configure the network themselves.
+    """
+
+    _fields_ = [
+        ("disable_configure_by_model_info", c_int),
+    ]
+
+    def __init__(self, disable_model_config=False):
+        self.disable_configure_by_model_info = disable_model_config
+
+    def __repr__(self):
+        data = {
+            "disable_configure_by_model_info": bool(
+                self.disable_configure_by_model_info
+            ),
+        }
+        return data.__repr__()
+
+
 class LiteIO(Structure):
     """
     config the network input and output item
@@ -365,6 +390,7 @@ class _NetworkAPI(_LiteCObjBase):
             "LITE_get_model_io_info_by_memory",
             [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
         ),
+        ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
     ]
@@ -541,6 +567,12 @@ class LiteNetwork(object):
         ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
         return ret_name
 
+    def extra_configure(self, extra_config):
+        """
+        apply extra configuration to the network
+        """
+        self._api.LITE_extra_configure(self._network, extra_config)
+
     def share_weights_with(self, src_network):
         """
         share weights with the loaded network
@@ -112,6 +112,13 @@ class TestNetwork(TestShuffleNet):
         network.load(model_path)
         self.do_forward(network)
 
+    def test_disable_model_config(self):
+        model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite")
+        network = LiteNetwork()
+        network.extra_configure(LiteExtraConfig(True))
+        network.load(model_path)
+        self.do_forward(network)
+
     def test_pack_cache_to_model(self):
         model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite")
         network = LiteNetwork()
@@ -31,7 +31,6 @@ using namespace mgb;
 LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);
 
 void NetworkImplDft::set_config(const Config& config) {
-    m_user_config = std::make_unique<Config>();
     *m_user_config = config;
     m_compnode_locator = to_compnode_locator(m_user_config->device_type);
     m_compnode_locator.device = config.device_id;
@@ -428,8 +427,11 @@ void NetworkImplDft::load_model(
     global_layout_transform();
 
+    //! some optimization options may be invalid in some cases, so here we
+    //! automatically determine whether these options can be applied
+    adapt_option_valid();
     //! find how many compnode the model has, this should call before update_io
     cross_compnode_model_detect();
 
     //! update the IO of the network
@@ -496,7 +498,6 @@ void NetworkImplDft::finish() const {
 }
 
 void NetworkImplDft::set_io(const NetworkIO& network_io) {
-    m_network_io = std::make_unique<NetworkIOInner>();
     for (auto&& in : network_io.inputs) {
         m_network_io->inputs.emplace_back(in);
     }
@@ -29,7 +29,11 @@ class NetworkImplDft final : public Network::NetworkImplBase {
     LITE_DYN_TYPE_OBJ_FINAL_DECL;
 
 public:
-    NetworkImplDft() { m_load_config.comp_graph = mgb::ComputingGraph::make(); }
+    NetworkImplDft() {
+        m_load_config.comp_graph = mgb::ComputingGraph::make();
+        m_user_config = std::make_unique<Config>();
+        m_network_io = std::make_unique<NetworkIOInner>();
+    }
     using S = megdnn::param::ExecutionPolicy::Strategy;
     using Var = mgb::cg::SymbolVar;
     //! set the config of the network, include:
@@ -80,14 +80,17 @@ void Network::prase_model(std::shared_ptr<void> model_data, size_t size) {
     ModelParser model_parser(model_data, size);
     //! parse the model info
     if (model_parser.parse_model_info(
-                m_config, m_network_io, separate_config_map, m_extra_info)) {
+                m_config, m_network_io, separate_config_map, m_extra_info,
+                !m_extra_config.disable_configure_by_model_info)) {
         if (m_config.backend == LiteBackend::LITE_DEFAULT &&
             m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) {
             m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>(
                     "parse_model"));
         }
-        m_impl->set_config(m_config);
-        m_impl->set_io(m_network_io);
+        if (!m_extra_config.disable_configure_by_model_info) {
+            m_impl->set_config(m_config);
+            m_impl->set_io(m_network_io);
+        }
     }
     //! decryption the model
     size_t model_length;
@@ -290,6 +293,18 @@ void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
     LITE_ERROR_HANDLER_END
 }
 
+void Network::extra_configure(const ExtraConfig& extra_config) {
+    LITE_ERROR_HANDLER_BEGIN
+    if (extra_config.disable_configure_by_model_info) {
+        LITE_ASSERT(
+                !m_loaded,
+                "disable_configure_by_model_info should be set before the "
+                "model is loaded.");
+    }
+    m_extra_config = extra_config;
+    LITE_ERROR_HANDLER_END
+}
+
 /*********************** MGE special network function ***************/
 
 void Runtime::set_cpu_threads_number(
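For clarity, a hedged sketch of the call order this assertion rejects (names reused from the earlier sketch; header path is an assumption):

    #include <memory>
    #include <string>

    #include "lite/network.h"

    void wrong_order(const std::string& model_path) {
        auto network = std::make_shared<lite::Network>(lite::Config{});
        network->load_model(model_path);
        //! too late: any model-embedded configuration was already applied
        //! during load_model, so this call trips the LITE_ASSERT above
        network->extra_configure({true});
    }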
@@ -43,7 +43,7 @@ void ModelParser::parse_header() {
 bool ModelParser::parse_model_info(
         Config& network_config, NetworkIO& network_io,
         std::unordered_map<std::string, LiteAny>& isolated_config_map,
-        std::string& extra_info) const {
+        std::string& extra_info, bool configure_valid) const {
     //! no model info, no parse, direct return
     if (m_is_bare_model || !m_info) {
         return false;
@@ -78,7 +78,7 @@ bool ModelParser::parse_model_info(
         }
     }
     //! parse ModelInfo::algo_policy
-    if (m_info->algo_policy()) {
+    if (m_info->algo_policy() && configure_valid) {
         size_t cache_length = m_info->algo_policy()->size();
         const uint8_t* cache = m_info->algo_policy()->Data();
         if (m_info_cache_parse_func_name == "LITE_parse_cache") {
@@ -93,6 +93,10 @@ bool ModelParser::parse_model_info(
             } else {
                 LITE_THROW("opencl binary cache is not given");
             }
+        } else {
+            LITE_THROW(ssprintf(
+                    "model cache parse function of %s is not defined.",
+                    m_info_cache_parse_func_name.c_str()));
         }
     }
     return true;
@@ -25,7 +25,7 @@ public:
     bool parse_model_info(
             Config& network_config, NetworkIO& network_io,
             std::unordered_map<std::string, LiteAny>& isolated_config_map,
-            std::string& extra_info) const;
+            std::string& extra_info, bool configure_valid) const;
 
     //! parse the model and decrypt the model
     std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const;
@@ -7,6 +7,8 @@
 #include "lite/global.h"
 #include "megbrain/tensor.h"
+#include "megbrain/utils/infile_persistent_cache.h"
+#include "megbrain/utils/persistent_cache.h"
 #include "test_common.h"
 
 #include <string.h>
@@ -173,6 +175,29 @@ TEST(TestNetWorkOptions, test_cache) {
     compare_lite_tensor<float>(output_tensor, result_mgb);
 }
 
+TEST(TestNetWorkOptions, DisableModelInfo) {
+    //! clear the cache set by other tests
+    mgb::PersistentCache::inst().set_impl(
+            std::make_shared<mgb::InMemoryPersistentCache>());
+    Config config;
+    auto tensor = get_input_data("./input_data.npy");
+    std::string model_path = "./test_pack_cache_to_model.lite";
+    std::string model_path2 = "./test_pack_cache_to_model.lite";
+    std::string input_name = "data";
+
+    std::shared_ptr<Network> network = std::make_shared<Network>(config);
+    network->extra_configure({true});
+    Runtime::set_cpu_inplace_mode(network);
+    network->load_model(model_path);
+    //! the fast-run cache is not configured, so dumping the cache is not
+    //! supported
+    ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), false);
+    ASSERT_EQ(Runtime::is_cpu_inplace_mode(network), true);
+
+    std::shared_ptr<Network> network2 = std::make_shared<Network>(config);
+    network2->load_model(model_path2);
+    //! the fast-run cache is configured from the model information
+    ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), true);
+}
+
 TEST(TestNetWorkOptions, FastRunIgnorBatch) {
     Config config;
     auto tensor = get_input_data("./input_data.npy");