| @@ -122,6 +122,11 @@ public: | |||||
| * these algorithms to speed up fastrun. | * these algorithms to speed up fastrun. | ||||
| * */ | * */ | ||||
| NAIVE = 1 << 1, | NAIVE = 1 << 1, | ||||
| /** | |||||
| * \brief set when the usability of the algo depends on the shape, i.e. | |||||
| * the algo may stop being usable once the shape changes. | |||||
| * */ | |||||
| USABLE_DEPEND_ON_SHAPE = 1 << 2, | |||||
| }; | }; | ||||
| /** | /** | ||||
| @@ -35,7 +35,8 @@ public: | |||||
| class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -146,7 +147,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -220,7 +222,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -235,7 +238,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { | const char* name() const override { | ||||
| return "AARCH64_INT8X8X16_MK4_16X12X4"; | return "AARCH64_INT8X8X16_MK4_16X12X4"; | ||||
| @@ -253,7 +257,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { | const char* name() const override { | ||||
| return "AARCH64_INT8X8X16_MK4_K8X8X8"; | return "AARCH64_INT8X8X16_MK4_K8X8X8"; | ||||
| @@ -271,7 +276,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -330,7 +336,8 @@ public: | |||||
| class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -34,7 +34,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -50,7 +51,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -67,7 +69,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -102,7 +105,8 @@ public: | |||||
| class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -35,7 +35,8 @@ public: | |||||
| class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } | const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -224,7 +225,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } | const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -266,7 +268,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } | const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -18,7 +18,8 @@ using namespace megdnn; | |||||
| #define FOREACH_ALGO_ATTRIBUTE(cb) \ | #define FOREACH_ALGO_ATTRIBUTE(cb) \ | ||||
| cb(DEFAULT) \ | cb(DEFAULT) \ | ||||
| cb(REPRODUCIBLE) \ | cb(REPRODUCIBLE) \ | ||||
| cb(NAIVE) | |||||
| cb(NAIVE) \ | |||||
| cb(USABLE_DEPEND_ON_SHAPE) | |||||
| namespace { | namespace { | ||||
| inline const char* attr_str(const AlgoAttribute& attr) { | inline const char* attr_str(const AlgoAttribute& attr) { | ||||
| @@ -184,7 +184,8 @@ public: | |||||
| const char* name() const override { return "CHANNEL_WISE_SMALL"; } | const char* name() const override { return "CHANNEL_WISE_SMALL"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -89,7 +89,8 @@ public: | |||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { | const char* name() const override { | ||||
| @@ -108,7 +109,8 @@ public: | |||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { | const char* name() const override { | ||||
| @@ -114,7 +114,8 @@ public: | |||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -231,7 +232,8 @@ public: | |||||
| const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | ||||
| @@ -100,7 +100,8 @@ public: | |||||
| const char* name() const override { return "BLAS"; } | const char* name() const override { return "BLAS"; } | ||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | ||||
| }; | }; | ||||
| @@ -135,7 +135,8 @@ public: | |||||
| class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { | class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "X86_F32MK8_8X8"; } | const char* name() const override { return "X86_F32MK8_8X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -276,21 +276,6 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| //! return pair<positive_attr, negative_attr> | |||||
| std::pair<AlgoAttribute, AlgoAttribute> | |||||
| extract_algo_attribute_from_execution_strategy( | |||||
| const ExecutionStrategy& strategy) { | |||||
| std::pair<AlgoAttribute, AlgoAttribute> ret = | |||||
| std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||||
| if (strategy & ExecutionStrategy::REPRODUCIBLE) { | |||||
| ret.first |= AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| if (strategy & ExecutionStrategy::OPTIMIZED) { | |||||
| ret.second |= AlgoAttribute::NAIVE; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -303,9 +288,9 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
| return; | return; | ||||
| AlgoChooserProfileCache::Result prof_rst; | AlgoChooserProfileCache::Result prof_rst; | ||||
| auto target_attr = | |||||
| extract_algo_attribute_from_execution_strategy(selected_strategy); | |||||
| std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | |||||
| auto target_attr = ctx.extract_algo_attribute(selected_strategy); | |||||
| std::string layouts_str = | |||||
| format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | |||||
| double cur_timeout = 0; | double cur_timeout = 0; | ||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| @@ -558,16 +543,15 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
| if (prof.empty()) | if (prof.empty()) | ||||
| return {}; | return {}; | ||||
| auto attr_from_strategy = | |||||
| extract_algo_attribute_from_execution_strategy(selected_strategy); | |||||
| auto target_attr = extract_algo_attribute(selected_strategy); | |||||
| for (auto&& i : prof) { | for (auto&& i : prof) { | ||||
| auto attr_of_algo = | auto attr_of_algo = | ||||
| static_cast<megdnn::Algorithm::Attribute>(i.attribute); | static_cast<megdnn::Algorithm::Attribute>(i.attribute); | ||||
| bool contain_attr_all_positive = | bool contain_attr_all_positive = | ||||
| (attr_from_strategy.first == | |||||
| (attr_of_algo & attr_from_strategy.first)); | |||||
| (target_attr.first == | |||||
| (attr_of_algo & target_attr.first)); | |||||
| bool contain_attr_any_negative = | bool contain_attr_any_negative = | ||||
| static_cast<bool>(attr_of_algo & attr_from_strategy.second); | |||||
| static_cast<bool>(attr_of_algo & target_attr.second); | |||||
| if (contain_attr_all_positive && !contain_attr_any_negative) { | if (contain_attr_all_positive && !contain_attr_any_negative) { | ||||
| auto iter = algo_map.find(i.algo); | auto iter = algo_map.find(i.algo); | ||||
| mgb_assert(iter != algo_map.end(), | mgb_assert(iter != algo_map.end(), | ||||
| @@ -586,8 +570,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
| mgb_log_error( | mgb_log_error( | ||||
| "algos read from cache could not satisfy attribute with %s and " | "algos read from cache could not satisfy attribute with %s and " | ||||
| "without %s", | "without %s", | ||||
| Algorithm::attribute_str(attr_from_strategy.first).c_str(), | |||||
| Algorithm::attribute_str(attr_from_strategy.second).c_str()); | |||||
| Algorithm::attribute_str(target_attr.first).c_str(), | |||||
| Algorithm::attribute_str(target_attr.second).c_str()); | |||||
| mgb_trap(); | mgb_trap(); | ||||
| MIDOUT_E | MIDOUT_E | ||||
| @@ -606,8 +590,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||||
| } | } | ||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
| auto attr = | |||||
| extract_algo_attribute_from_execution_strategy(selected_strategy); | |||||
| auto attr = extract_algo_attribute(selected_strategy); | |||||
| ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
| policy.algo = | policy.algo = | ||||
| APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | ||||
| @@ -668,9 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
| if (retrive_from_cache) { | if (retrive_from_cache) { | ||||
| policy.algo = get_profile_result_from_cache(selected_strategy).desc; | policy.algo = get_profile_result_from_cache(selected_strategy).desc; | ||||
| if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
| auto target_attr = | |||||
| extract_algo_attribute_from_execution_strategy( | |||||
| selected_strategy); | |||||
| auto target_attr = extract_algo_attribute(selected_strategy); | |||||
| std::string layouts_str = | std::string layouts_str = | ||||
| format_fixlayouts<Opr>(m_layouts, arity_in, arity_out); | format_fixlayouts<Opr>(m_layouts, arity_in, arity_out); | ||||
| std::string msg = ssprintf( | std::string msg = ssprintf( | ||||
| @@ -692,8 +673,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
| auto attr = extract_algo_attribute_from_execution_strategy( | |||||
| selected_strategy); | |||||
| auto attr = extract_algo_attribute(selected_strategy); | |||||
| policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | ||||
| args..., workspace_limit, attr.first, | args..., workspace_limit, attr.first, | ||||
| attr.second), | attr.second), | ||||
| @@ -837,6 +817,24 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||||
| return result; | return result; | ||||
| } | } | ||||
| template <typename Opr> | |||||
| std::pair<AlgoAttribute, AlgoAttribute> | |||||
| AlgoChooser<Opr>::ExeContext::extract_algo_attribute( | |||||
| const ExecutionStrategy& strategy) const { | |||||
| std::pair<AlgoAttribute, AlgoAttribute> ret = | |||||
| std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||||
| //! from strategy: first = attributes required, second = attributes rejected -- NOTE(review): enumerator is spelled OPTMIZED below but OPTIMIZED in the removed free helper; confirm the spelling declared by ExecutionStrategy | |||||
| if (strategy & ExecutionStrategy::REPRODUCIBLE) { | |||||
| ret.first |= AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| if (strategy & ExecutionStrategy::OPTMIZED) { | |||||
| ret.second |= AlgoAttribute::NAIVE; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| #define INST(Opr) \ | #define INST(Opr) \ | ||||
| template AlgoChooser<megdnn::Opr>::ExeContext::ExeContext( \ | template AlgoChooser<megdnn::Opr>::ExeContext::ExeContext( \ | ||||
| const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ | const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ | ||||
| @@ -865,7 +863,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||||
| AlgoChooser<megdnn::Opr>::ExeContext::profile_single_algo( \ | AlgoChooser<megdnn::Opr>::ExeContext::profile_single_algo( \ | ||||
| const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \ | const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \ | ||||
| policy, \ | policy, \ | ||||
| double& timeout) const; | |||||
| double& timeout) const; \ | |||||
| template std::pair<AlgoAttribute, AlgoAttribute> \ | |||||
| AlgoChooser<megdnn::Opr>::ExeContext::extract_algo_attribute( \ | |||||
| const ExecutionStrategy& strategy) const; | |||||
| MGB_FOREACH_FASTRUN_OPR(INST) | MGB_FOREACH_FASTRUN_OPR(INST) | ||||
| @@ -149,6 +149,16 @@ public: | |||||
| ImplExecutionPolicy& policy, | ImplExecutionPolicy& policy, | ||||
| bool retrive_from_cache = true) const; | bool retrive_from_cache = true) const; | ||||
| /** | |||||
| * \brief derive the algo attribute constraints implied by an execution | |||||
| * strategy (and, per the original note, graph option). | |||||
| * | |||||
| * NOTE(review): the definition visible alongside this declaration only | |||||
| * inspects \p strategy; confirm where graph options are taken into | |||||
| * account. | |||||
| * | |||||
| * \param strategy execution strategy that the selected algo must match | |||||
| * \return pair<positive_attr, negative_attr>; the profile-cache lookup | |||||
| * accepts an algo only if it carries every positive attribute and | |||||
| * none of the negative attributes | |||||
| */ | |||||
| std::pair<AlgoAttribute, AlgoAttribute> extract_algo_attribute( | |||||
| const ExecutionStrategy& strategy) const; | |||||
| private: | private: | ||||
| Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | ||||
| }; | }; | ||||