GitOrigin-RevId: 6046a2db0c
tags/v1.5.0
| @@ -25,79 +25,9 @@ | |||||
| using namespace mgb; | using namespace mgb; | ||||
| namespace { | |||||
| class InMemoryPersistentCache final: public PersistentCache { | |||||
| struct BlobStorage: public Blob { | |||||
| std::unique_ptr<uint8_t[]> data_refhold; | |||||
| size_t hash = 0; | |||||
| BlobStorage& init_data_ref(const Blob &b) { | |||||
| data_refhold = std::make_unique<uint8_t[]>(b.size + 1); | |||||
| memcpy(data_refhold.get(), b.ptr, b.size); | |||||
| data_refhold.get()[b.size] = 0; // for C-string safety | |||||
| ptr = data_refhold.get(); | |||||
| size = b.size; | |||||
| return *this; | |||||
| } | |||||
| BlobStorage& init_hash() { | |||||
| hash = XXHash{}.update(ptr, size).digest(); | |||||
| return *this; | |||||
| } | |||||
| bool operator == (const BlobStorage &rhs) const { | |||||
| return size == rhs.size && !memcmp(ptr, rhs.ptr, size); | |||||
| } | |||||
| struct Hash { | |||||
| size_t operator() (const BlobStorage &b) const { | |||||
| return b.hash; | |||||
| } | |||||
| }; | |||||
| }; | |||||
| std::unordered_map<std::string, | |||||
| std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||||
| m_cache; | |||||
| std::mutex m_mtx; | |||||
| Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||||
| decltype(m_cache.begin()) iter0; | |||||
| { | |||||
| MGB_LOCK_GUARD(m_mtx); | |||||
| iter0 = m_cache.find(category); | |||||
| if (iter0 == m_cache.end()) | |||||
| return None; | |||||
| } | |||||
| BlobStorage key_storage; | |||||
| key_storage.Blob::operator=(key); | |||||
| key_storage.init_hash(); | |||||
| MGB_LOCK_GUARD(m_mtx); | |||||
| auto iter1 = iter0->second.find(key_storage); | |||||
| if (iter1 == iter0->second.end()) | |||||
| return None; | |||||
| return iter1->second; | |||||
| } | |||||
| void put(const std::string& category, const Blob& key, | |||||
| const Blob& value) override { | |||||
| BlobStorage key_storage; | |||||
| key_storage.init_data_ref(key).init_hash(); | |||||
| MGB_LOCK_GUARD(m_mtx); | |||||
| auto size0 = m_cache.size(); | |||||
| m_cache[category][std::move(key_storage)].init_data_ref(value); | |||||
| if (m_cache.size() > size0) { | |||||
| mgb_log_debug("new cache category: %s", category.c_str()); | |||||
| } | |||||
| } | |||||
| }; | |||||
| } | |||||
| // ================= PersistentCache ====================== | |||||
| std::shared_ptr<PersistentCache> PersistentCache::sm_impl = | std::shared_ptr<PersistentCache> PersistentCache::sm_impl = | ||||
| std::make_shared<InMemoryPersistentCache>(); | |||||
| std::make_shared<InMemoryPersistentCache>(); | |||||
| std::shared_ptr<PersistentCache> PersistentCache::set_impl( | std::shared_ptr<PersistentCache> PersistentCache::set_impl( | ||||
| std::shared_ptr<PersistentCache> impl) { | std::shared_ptr<PersistentCache> impl) { | ||||
| @@ -141,6 +71,65 @@ std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) { | |||||
| } | } | ||||
| } | } | ||||
| // ================= InMemoryPersistentCache ================== | |||||
| using Blob = PersistentCache::Blob; | |||||
| InMemoryPersistentCache::BlobStorage& | |||||
| InMemoryPersistentCache::BlobStorage::init_data_ref(const Blob& b) { | |||||
| data_refhold = std::make_unique<uint8_t[]>(b.size + 1); | |||||
| memcpy(data_refhold.get(), b.ptr, b.size); | |||||
| data_refhold.get()[b.size] = 0; // for C-string safety | |||||
| ptr = data_refhold.get(); | |||||
| size = b.size; | |||||
| return *this; | |||||
| } | |||||
| InMemoryPersistentCache::BlobStorage& | |||||
| InMemoryPersistentCache::BlobStorage::init_hash() { | |||||
| hash = XXHash{}.update(ptr, size).digest(); | |||||
| return *this; | |||||
| } | |||||
| bool InMemoryPersistentCache::BlobStorage::operator==( | |||||
| const BlobStorage& rhs) const { | |||||
| return size == rhs.size && !memcmp(ptr, rhs.ptr, size); | |||||
| } | |||||
| Maybe<Blob> InMemoryPersistentCache::get(const std::string& category, | |||||
| const Blob& key) { | |||||
| decltype(m_cache.begin()) iter0; | |||||
| { | |||||
| MGB_LOCK_GUARD(m_mtx); | |||||
| iter0 = m_cache.find(category); | |||||
| if (iter0 == m_cache.end()) | |||||
| return None; | |||||
| } | |||||
| BlobStorage key_storage; | |||||
| key_storage.Blob::operator=(key); | |||||
| key_storage.init_hash(); | |||||
| MGB_LOCK_GUARD(m_mtx); | |||||
| auto iter1 = iter0->second.find(key_storage); | |||||
| if (iter1 == iter0->second.end()) | |||||
| return None; | |||||
| return iter1->second; | |||||
| } | |||||
| void InMemoryPersistentCache::put(const std::string& category, const Blob& key, | |||||
| const Blob& value) { | |||||
| BlobStorage key_storage; | |||||
| key_storage.init_data_ref(key).init_hash(); | |||||
| MGB_LOCK_GUARD(m_mtx); | |||||
| auto size0 = m_cache.size(); | |||||
| m_cache[category][std::move(key_storage)].init_data_ref(value); | |||||
| if (m_cache.size() > size0) { | |||||
| mgb_log_debug("new cache category: %s", category.c_str()); | |||||
| } | |||||
| } | |||||
| // ================= AlgoChooserProfileCache ================== | |||||
| AlgoChooserProfileCache::AlgoChooserProfileCache( | AlgoChooserProfileCache::AlgoChooserProfileCache( | ||||
| CompNode cn, const char *opr_type) { | CompNode cn, const char *opr_type) { | ||||
| m_category = "profile:"; | m_category = "profile:"; | ||||
| @@ -55,6 +55,37 @@ namespace mgb { | |||||
| static std::string make_category_from_comp_node(CompNode comp_node); | static std::string make_category_from_comp_node(CompNode comp_node); | ||||
| }; | }; | ||||
| /*! | |||||
| * \brief persistent cache that keep in memory | |||||
| * The implementation is thread safe. | |||||
| */ | |||||
| class InMemoryPersistentCache final : public PersistentCache { | |||||
| struct BlobStorage : public PersistentCache::Blob { | |||||
| std::unique_ptr<uint8_t[]> data_refhold; | |||||
| size_t hash = 0; | |||||
| BlobStorage& init_data_ref(const Blob& b); | |||||
| BlobStorage& init_hash(); | |||||
| bool operator==(const BlobStorage& rhs) const; | |||||
| struct Hash { | |||||
| size_t operator()(const BlobStorage& b) const { return b.hash; } | |||||
| }; | |||||
| }; | |||||
| Maybe<Blob> get(const std::string& category, const Blob& key) override; | |||||
| void put(const std::string& category, const Blob& key, | |||||
| const Blob& value) override; | |||||
| std::unordered_map< | |||||
| std::string, | |||||
| std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||||
| m_cache; | |||||
| std::mutex m_mtx; | |||||
| }; | |||||
| /*! | /*! | ||||
| * \brief proxy PersistentCache to be better suited for managing profiling | * \brief proxy PersistentCache to be better suited for managing profiling | ||||
| * results of operator impl algorithms | * results of operator impl algorithms | ||||
| @@ -68,7 +68,6 @@ std::string format_fixlayouts( | |||||
| ret.append(", "); | ret.append(", "); | ||||
| } | } | ||||
| ret.append(layouts[i].to_string() + " "); | ret.append(layouts[i].to_string() + " "); | ||||
| ret.append(layouts[i].dtype.name()); | |||||
| } | } | ||||
| ret.append(") -> ("); | ret.append(") -> ("); | ||||
| for (size_t i = 0; i < arity_out; ++i) { | for (size_t i = 0; i < arity_out; ++i) { | ||||
| @@ -76,7 +75,6 @@ std::string format_fixlayouts( | |||||
| ret.append(", "); | ret.append(", "); | ||||
| } | } | ||||
| ret.append(layouts[i + arity_in].to_string() + " "); | ret.append(layouts[i + arity_in].to_string() + " "); | ||||
| ret.append(layouts[i + arity_in].dtype.name()); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -420,6 +418,7 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | |||||
| AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); | AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); | ||||
| }); | }); | ||||
| } | } | ||||
| typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | ||||
| ctx.construct_execution_policy(selected_strategy, policy); | ctx.construct_execution_policy(selected_strategy, policy); | ||||
| return policy; | return policy; | ||||
| @@ -660,8 +659,28 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
| bool retrive_from_cache) const { | bool retrive_from_cache) const { | ||||
| if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
| if (retrive_from_cache) { | if (retrive_from_cache) { | ||||
| policy.algo = | |||||
| get_profile_result_from_cache(selected_strategy).desc; | |||||
| policy.algo = get_profile_result_from_cache(selected_strategy).desc; | |||||
| if (!policy.algo.valid()) { | |||||
| auto target_attr = | |||||
| extract_algo_attribute_from_execution_strategy( | |||||
| selected_strategy); | |||||
| std::string layouts_str = | |||||
| format_fixlayouts<Opr>(m_layouts, arity_in, arity_out); | |||||
| std::string msg = ssprintf( | |||||
| "(mbg_opr : %s, layouts %s, with attribute(%s) and " | |||||
| "without attribute(%s)", | |||||
| m_base_mgb_opr->dyn_typeinfo()->name, | |||||
| layouts_str.c_str(), | |||||
| Algorithm::attribute_str(target_attr.first).c_str(), | |||||
| Algorithm::attribute_str(target_attr.second).c_str()); | |||||
| mgb_log_warn( | |||||
| "No algo get from cache for %s. This may caused by " | |||||
| "mismatch with model and cache file. ex. profiling " | |||||
| "with version1, but inferencing on version2 or " | |||||
| "profiling modelA but inferencing modelB", | |||||
| msg.c_str()); | |||||
| return; | |||||
| } | |||||
| } else { | } else { | ||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
| @@ -673,10 +692,12 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
| attr.second), | attr.second), | ||||
| m_layouts) | m_layouts) | ||||
| .desc; | .desc; | ||||
| mgb_assert(policy.algo.valid(), | |||||
| "No algo found from heuristic with strategy %u and " | |||||
| "workspace limit %zu", | |||||
| static_cast<uint32_t>(selected_strategy), | |||||
| workspace_limit); | |||||
| } | } | ||||
| mgb_assert(policy.algo.valid(), | |||||
| "No algo found from cache or heuristic, maybe some error " | |||||
| "occured"); | |||||
| } | } | ||||
| Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | ||||
| @@ -697,9 +718,13 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
| sub_ctx.construct_execution_policy(selected_strategy, | sub_ctx.construct_execution_policy(selected_strategy, | ||||
| policy.sub_policy.back(), | policy.sub_policy.back(), | ||||
| retrive_from_cache); | retrive_from_cache); | ||||
| if (!policy.sub_policy.back().algo.valid()) { | |||||
| // means sub_ctx.construct_execution_policy fails. clean up | |||||
| // policy.algo and return | |||||
| policy = {}; | |||||
| return; | |||||
| } | |||||
| }); | }); | ||||
| return; | |||||
| } | } | ||||
| template <typename Opr> | template <typename Opr> | ||||
| @@ -140,9 +140,10 @@ public: | |||||
| * \brief construct execution policy from cache or heuristic. | * \brief construct execution policy from cache or heuristic. | ||||
| * | * | ||||
| * \param selected_strategy select algo which matched this strategy | * \param selected_strategy select algo which matched this strategy | ||||
| * \param policy execution policy | |||||
| * \param [out] policy execution policy | |||||
| * \param retrive_from_cache retrive algo from cache if set True, get | * \param retrive_from_cache retrive algo from cache if set True, get | ||||
| * from heuristic otherwise. | * from heuristic otherwise. | ||||
| * \note When construction fails, the policy will be cleaned. | |||||
| */ | */ | ||||
| void construct_execution_policy(ExecutionStrategy selected_strategy, | void construct_execution_policy(ExecutionStrategy selected_strategy, | ||||
| ImplExecutionPolicy& policy, | ImplExecutionPolicy& policy, | ||||
| @@ -152,14 +153,13 @@ public: | |||||
| Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | ||||
| }; | }; | ||||
| template<typename U> | |||||
| template <typename U> | |||||
| friend class AlgoChooser; | friend class AlgoChooser; | ||||
| private: | private: | ||||
| //! entrance for getting algorithm according to execution strategy | //! entrance for getting algorithm according to execution strategy | ||||
| static ImplExecutionPolicy get_policy(ExeContext& ctx); | static ImplExecutionPolicy get_policy(ExeContext& ctx); | ||||
| //! profile and save to cache | //! profile and save to cache | ||||
| static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); | static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); | ||||
| @@ -30,7 +30,6 @@ | |||||
| #include <random> | #include <random> | ||||
| using namespace mgb; | using namespace mgb; | ||||
| namespace { | namespace { | ||||
| using Param = opr::Convolution::Param; | using Param = opr::Convolution::Param; | ||||
| @@ -354,21 +353,26 @@ TEST(TestOprDNN, ConvBiasExePolicy) { | |||||
| auto cn = CompNode::load("cpux"); | auto cn = CompNode::load("cpux"); | ||||
| auto orig_impl = PersistentCache::set_impl( | |||||
| std::make_shared<InMemoryPersistentCache>()); | |||||
| #if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
| for (auto strategy : | for (auto strategy : | ||||
| SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | ||||
| S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) { | |||||
| S::PROFILE | S::HEURISTIC}) { | |||||
| #else | #else | ||||
| for (auto strategy : | for (auto strategy : | ||||
| SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | ||||
| #endif | #endif | ||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| auto mkvar = [&](const char* name, const TensorShape& shp, | auto mkvar = [&](const char* name, const TensorShape& shp, | ||||
| const DType& dtype) { | const DType& dtype) { | ||||
| return opr::TypeCvt::make( | return opr::TypeCvt::make( | ||||
| opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name), | |||||
| opr::Host2DeviceCopy::make(*graph, gen(shp), cn) | |||||
| .rename(name), | |||||
| dtype); | dtype); | ||||
| }; | }; | ||||
| @@ -388,7 +392,11 @@ TEST(TestOprDNN, ConvBiasExePolicy) { | |||||
| HostTensorND host_y; | HostTensorND host_y; | ||||
| auto func = graph->compile({make_callback_copy(conv_bias, host_y)}); | auto func = graph->compile({make_callback_copy(conv_bias, host_y)}); | ||||
| func->execute(); | func->execute(); | ||||
| //! set a new cache | |||||
| PersistentCache::set_impl(std::make_shared<InMemoryPersistentCache>()); | |||||
| } | } | ||||
| PersistentCache::set_impl(orig_impl); | |||||
| } | } | ||||
| TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { | TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { | ||||
| @@ -401,19 +409,21 @@ TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { | |||||
| for (auto strategy : | for (auto strategy : | ||||
| SmallVector<S>{S::PROFILE, S::PROFILE | S::REPRODUCIBLE}) { | SmallVector<S>{S::PROFILE, S::PROFILE | S::REPRODUCIBLE}) { | ||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| auto mkvar = [&](const char* name, const TensorShape& shp, | auto mkvar = [&](const char* name, const TensorShape& shp, | ||||
| const DType& dtype) { | const DType& dtype) { | ||||
| return opr::TypeCvt::make( | return opr::TypeCvt::make( | ||||
| opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name), | |||||
| opr::Host2DeviceCopy::make(*graph, gen(shp), cn) | |||||
| .rename(name), | |||||
| dtype); | dtype); | ||||
| }; | }; | ||||
| auto x = mkvar("x", {20, 50, 50, 16}, dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); | |||||
| auto w = mkvar("w", {24, 3, 3, 16}, dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); | |||||
| auto x = mkvar("x", {20, 50, 50, 16}, | |||||
| dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); | |||||
| auto w = mkvar("w", {24, 3, 3, 16}, | |||||
| dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); | |||||
| auto bias = mkvar("bias", {1, 1, 1, 24}, dtype::QuantizedS32(6.25f)); | auto bias = mkvar("bias", {1, 1, 1, 24}, dtype::QuantizedS32(6.25f)); | ||||
| param.nonlineMode = Param::NonlineMode::RELU; | param.nonlineMode = Param::NonlineMode::RELU; | ||||