| @@ -565,6 +565,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp | |||||
| choose_by_profile( | choose_by_profile( | ||||
| const ExecutionStrategy& selected_strategy, bool enable_update) const { | const ExecutionStrategy& selected_strategy, bool enable_update) const { | ||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile"))) | ||||
| // no_profiling_on_shape_change is usually false, no interface to change it easily | |||||
| if (m_desc.no_profiling_on_shape_change) { | if (m_desc.no_profiling_on_shape_change) { | ||||
| auto policy = m_dnn_opr->execution_policy(); | auto policy = m_dnn_opr->execution_policy(); | ||||
| if (policy.algo.valid()) { | if (policy.algo.valid()) { | ||||
| @@ -579,6 +580,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp | |||||
| } | } | ||||
| } | } | ||||
| // if update enabled, do profiling and update cache | |||||
| // enable_update = false only when using HEURISTIC_PROFILE strategy | |||||
| typename AlgoChooser<Opr>::ImplExecutionPolicy tmp_policy; | typename AlgoChooser<Opr>::ImplExecutionPolicy tmp_policy; | ||||
| bool retrive_from_cache = true; | bool retrive_from_cache = true; | ||||
| bool allow_log = false; | bool allow_log = false; | ||||
| @@ -604,6 +607,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp | |||||
| }); | }); | ||||
| } | } | ||||
| // try to retrieve algorithm from fastrun cache, this time it's guaranteed to get | |||||
| // result, retrive_from_cache = true, allow_log = true | |||||
| typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | ||||
| construct_execution_policy(selected_strategy, policy); | construct_execution_policy(selected_strategy, policy); | ||||
| return policy; | return policy; | ||||
| @@ -623,13 +628,16 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( | |||||
| m_incache_layouts.data(), m_incache_layouts.size(), &origin_param, | m_incache_layouts.data(), m_incache_layouts.size(), &origin_param, | ||||
| sizeof(origin_param)}; | sizeof(origin_param)}; | ||||
| auto&& rst = cache.get(cache_key); | auto&& rst = cache.get(cache_key); | ||||
| // failed to find a cache entry, return | |||||
| if (!rst.valid()) | if (!rst.valid()) | ||||
| return {{}, rst}; | return {{}, rst}; | ||||
| // found a cache entry(it's a vector of Result), but it's empty | |||||
| auto&& prof = rst.val(); | auto&& prof = rst.val(); | ||||
| if (prof.empty()) | if (prof.empty()) | ||||
| return {{}, rst}; | return {{}, rst}; | ||||
| // found non-empty cache result, filter it by workspace limit and attribute | |||||
| size_t workspace_limit = | size_t workspace_limit = | ||||
| m_desc.get_workspace_limit(m_cn, m_execution_policy.workspace_limit); | m_desc.get_workspace_limit(m_cn, m_execution_policy.workspace_limit); | ||||
| auto target_attr = extract_algo_attribute(selected_strategy); | auto target_attr = extract_algo_attribute(selected_strategy); | ||||
| @@ -644,6 +652,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( | |||||
| if (contain_attr_all_positive) { | if (contain_attr_all_positive) { | ||||
| if (!contain_attr_any_negative) { | if (!contain_attr_any_negative) { | ||||
| if (i.workspace <= workspace_limit) { | if (i.workspace <= workspace_limit) { | ||||
| // found a well-suited algorithm with good workspace limit and | |||||
| // correct attribute | |||||
| Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo); | Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo); | ||||
| return {algo_desc, rst}; | return {algo_desc, rst}; | ||||
| } | } | ||||
| @@ -654,9 +664,11 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( | |||||
| } | } | ||||
| } | } | ||||
| // failed to find an algorithm that satisfies the actual workspace limit | |||||
| if (skip_by_workspace) | if (skip_by_workspace) | ||||
| return {}; | return {}; | ||||
| // failed to find an algorithm that satisfies the actual attribute | |||||
| std::string layouts_str = AlgoChooser::format_fixlayouts(m_fastrun_layouts); | std::string layouts_str = AlgoChooser::format_fixlayouts(m_fastrun_layouts); | ||||
| if (skip_by_negative) { | if (skip_by_negative) { | ||||
| mgb_log_error( | mgb_log_error( | ||||
| @@ -685,9 +697,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( | |||||
| typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, bool retrive_from_cache, | typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, bool retrive_from_cache, | ||||
| bool allow_log) const { | bool allow_log) const { | ||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy"))) | ||||
| // policy.algo is always invalid when called from choose_by_profile | |||||
| // policy.algo will be valid when called from profile | |||||
| if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
| if (retrive_from_cache) { | if (retrive_from_cache) { | ||||
| policy.algo = get_profile_result_from_cache(selected_strategy).first; | policy.algo = get_profile_result_from_cache(selected_strategy).first; | ||||
| // nothing is found even with profiling | |||||
| if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
| if (allow_log) { | if (allow_log) { | ||||
| auto target_attr = extract_algo_attribute(selected_strategy); | auto target_attr = extract_algo_attribute(selected_strategy); | ||||
| @@ -710,6 +725,8 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( | |||||
| return; | return; | ||||
| } | } | ||||
| } else { | } else { | ||||
| // retrive_from_cache = false happens when using algo choose hook in | |||||
| // megbrain graph; return heuristic algorithm in this case | |||||
| auto workspace_limit = m_desc.get_workspace_limit( | auto workspace_limit = m_desc.get_workspace_limit( | ||||
| m_cn, m_execution_policy.workspace_limit); | m_cn, m_execution_policy.workspace_limit); | ||||
| @@ -727,11 +744,13 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( | |||||
| } | } | ||||
| } | } | ||||
| // construct current algorithm | |||||
| Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo); | Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo); | ||||
| mgb_assert(algo, "Unknown algo description"); | mgb_assert(algo, "Unknown algo description"); | ||||
| std::vector<Algorithm::SearchItem>&& sub_items = | std::vector<Algorithm::SearchItem>&& sub_items = | ||||
| algo->get_subopr_list(to_layout_array<Opr>(m_fastrun_layouts), m_dnn_opr); | algo->get_subopr_list(to_layout_array<Opr>(m_fastrun_layouts), m_dnn_opr); | ||||
| // construct sub oprs' algorithm | |||||
| FOREACH_OPR_TYPE_DISPATCH(sub_items, { | FOREACH_OPR_TYPE_DISPATCH(sub_items, { | ||||
| auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn); | auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn); | ||||
| megdnn_opr->param() = | megdnn_opr->param() = | ||||
| @@ -790,6 +809,8 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo> AlgoChooser< | |||||
| auto heu = choose_by_heuristic(m_execution_policy.strategy); | auto heu = choose_by_heuristic(m_execution_policy.strategy); | ||||
| auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...), m_fastrun_layouts); | auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...), m_fastrun_layouts); | ||||
| bool found = false; | bool found = false; | ||||
| // make heuristic algorithm always the first in all candidate algorithms | |||||
| // so profiling step will always run heuristic algorithm first | |||||
| for (size_t i = 0; i < ret.size(); ++i) { | for (size_t i = 0; i < ret.size(); ++i) { | ||||
| if (ret[i].desc == heu.algo) { | if (ret[i].desc == heu.algo) { | ||||
| found = true; | found = true; | ||||
| @@ -798,6 +819,7 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo> AlgoChooser< | |||||
| } | } | ||||
| } | } | ||||
| // make sure heuristic algorithm is valid | |||||
| Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(heu.algo); | Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(heu.algo); | ||||
| mgb_assert(palgo, "Unknown algo description"); | mgb_assert(palgo, "Unknown algo description"); | ||||
| mgb_assert( | mgb_assert( | ||||
| @@ -813,6 +835,7 @@ template <typename Opr> | |||||
| Maybe<AlgoChooserProfileCache::ResultEntry> AlgoChooser<Opr>::AlgoChooserHelper:: | Maybe<AlgoChooserProfileCache::ResultEntry> AlgoChooser<Opr>::AlgoChooserHelper:: | ||||
| profile_single_algo(const ImplExecutionPolicy& policy, double& timeout) const { | profile_single_algo(const ImplExecutionPolicy& policy, double& timeout) const { | ||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile_single_algo"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile_single_algo"))) | ||||
| // fill TimedProfiler<Opr>::param and run actual timed profiler | |||||
| typename TimedProfiler<Opr>::Param param; | typename TimedProfiler<Opr>::Param param; | ||||
| // force check copy size <= dest len-1 from gcc8 for safe | // force check copy size <= dest len-1 from gcc8 for safe | ||||
| param.execution_policy = | param.execution_policy = | ||||
| @@ -867,7 +890,11 @@ template <typename Opr> | |||||
| void AlgoChooser<Opr>::AlgoChooserHelper::profile( | void AlgoChooser<Opr>::AlgoChooserHelper::profile( | ||||
| const ExecutionStrategy& selected_strategy) const { | const ExecutionStrategy& selected_strategy) const { | ||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile"))) | ||||
| // some sub oprs have been profiled before | |||||
| // sub oprs won't be checked at the beginning of choose_by_profile | |||||
| auto&& rst = get_profile_result_from_cache(selected_strategy); | auto&& rst = get_profile_result_from_cache(selected_strategy); | ||||
| // rst.first.valid means there exists a valid algorithm for current opr, just return | |||||
| // otherwise need to profile | |||||
| if (rst.first.valid()) | if (rst.first.valid()) | ||||
| return; | return; | ||||
| AlgoChooserProfileCache::Result prof_rst; | AlgoChooserProfileCache::Result prof_rst; | ||||
| @@ -957,6 +984,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( | |||||
| Algorithm::attribute_str(target_attr.second).c_str(), workspace_limit); | Algorithm::attribute_str(target_attr.second).c_str(), workspace_limit); | ||||
| mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); | mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); | ||||
| // append some previous profiled results | |||||
| if (rst.second.valid()) | if (rst.second.valid()) | ||||
| prof_rst.insert( | prof_rst.insert( | ||||
| prof_rst.end(), rst.second.val().begin(), rst.second.val().end()); | prof_rst.end(), rst.second.val().begin(), rst.second.val().end()); | ||||