| @@ -122,6 +122,11 @@ public: | |||
| * these algorithms to speed up fastrun. | |||
| * */ | |||
| NAIVE = 1 << 1, | |||
| /** | |||
| * \brief whether the algo is usable once shape changed. | |||
| * */ | |||
| USABLE_DEPEND_ON_SHAPE = 1 << 2, | |||
| }; | |||
| /** | |||
| @@ -35,7 +35,8 @@ public: | |||
| class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -146,7 +147,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -220,7 +222,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -235,7 +238,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X16_MK4_16X12X4"; | |||
| @@ -253,7 +257,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X16_MK4_K8X8X8"; | |||
| @@ -271,7 +276,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -330,7 +336,8 @@ public: | |||
| class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -34,7 +34,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -50,7 +51,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -67,7 +69,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -102,7 +105,8 @@ public: | |||
| class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -35,7 +35,8 @@ public: | |||
| class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -224,7 +225,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -266,7 +268,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -18,7 +18,8 @@ using namespace megdnn; | |||
| #define FOREACH_ALGO_ATTRIBUTE(cb) \ | |||
| cb(DEFAULT) \ | |||
| cb(REPRODUCIBLE) \ | |||
| cb(NAIVE) | |||
| cb(NAIVE) \ | |||
| cb(USABLE_DEPEND_ON_SHAPE) | |||
| namespace { | |||
| inline const char* attr_str(const AlgoAttribute& attr) { | |||
| @@ -184,7 +184,8 @@ public: | |||
| const char* name() const override { return "CHANNEL_WISE_SMALL"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| }; | |||
| @@ -89,7 +89,8 @@ public: | |||
| void exec(const ExecArgs& args) const override; | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { | |||
| @@ -108,7 +109,8 @@ public: | |||
| void exec(const ExecArgs& args) const override; | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { | |||
| @@ -114,7 +114,8 @@ public: | |||
| void exec(const ExecArgs& args) const override; | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| }; | |||
| @@ -231,7 +232,8 @@ public: | |||
| const char* name() const override { return m_name.c_str(); } | |||
| void exec(const ExecArgs& args) const override; | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | |||
| @@ -100,7 +100,8 @@ public: | |||
| const char* name() const override { return "BLAS"; } | |||
| void exec(const ExecArgs& args) const override; | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | |||
| }; | |||
| @@ -135,7 +135,8 @@ public: | |||
| class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "X86_F32MK8_8X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -276,21 +276,6 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||
| return ret; | |||
| } | |||
| //! return pair<positive_attr, negative_attr> | |||
| std::pair<AlgoAttribute, AlgoAttribute> | |||
| extract_algo_attribute_from_execution_strategy( | |||
| const ExecutionStrategy& strategy) { | |||
| std::pair<AlgoAttribute, AlgoAttribute> ret = | |||
| std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||
| if (strategy & ExecutionStrategy::REPRODUCIBLE) { | |||
| ret.first |= AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| if (strategy & ExecutionStrategy::OPTIMIZED) { | |||
| ret.second |= AlgoAttribute::NAIVE; | |||
| } | |||
| return ret; | |||
| } | |||
| } // namespace | |||
| namespace mgb { | |||
| @@ -303,9 +288,9 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||
| return; | |||
| AlgoChooserProfileCache::Result prof_rst; | |||
| auto target_attr = | |||
| extract_algo_attribute_from_execution_strategy(selected_strategy); | |||
| std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | |||
| auto target_attr = ctx.extract_algo_attribute(selected_strategy); | |||
| std::string layouts_str = | |||
| format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | |||
| double cur_timeout = 0; | |||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
| @@ -558,16 +543,15 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
| if (prof.empty()) | |||
| return {}; | |||
| auto attr_from_strategy = | |||
| extract_algo_attribute_from_execution_strategy(selected_strategy); | |||
| auto target_attr = extract_algo_attribute(selected_strategy); | |||
| for (auto&& i : prof) { | |||
| auto attr_of_algo = | |||
| static_cast<megdnn::Algorithm::Attribute>(i.attribute); | |||
| bool contain_attr_all_positive = | |||
| (attr_from_strategy.first == | |||
| (attr_of_algo & attr_from_strategy.first)); | |||
| (target_attr.first == | |||
| (attr_of_algo & target_attr.first)); | |||
| bool contain_attr_any_negative = | |||
| static_cast<bool>(attr_of_algo & attr_from_strategy.second); | |||
| static_cast<bool>(attr_of_algo & target_attr.second); | |||
| if (contain_attr_all_positive && !contain_attr_any_negative) { | |||
| auto iter = algo_map.find(i.algo); | |||
| mgb_assert(iter != algo_map.end(), | |||
| @@ -586,8 +570,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
| mgb_log_error( | |||
| "algos read from cache could not satisfy attribute with %s and " | |||
| "without %s", | |||
| Algorithm::attribute_str(attr_from_strategy.first).c_str(), | |||
| Algorithm::attribute_str(attr_from_strategy.second).c_str()); | |||
| Algorithm::attribute_str(target_attr.first).c_str(), | |||
| Algorithm::attribute_str(target_attr.second).c_str()); | |||
| mgb_trap(); | |||
| MIDOUT_E | |||
| @@ -606,8 +590,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||
| } | |||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||
| auto attr = | |||
| extract_algo_attribute_from_execution_strategy(selected_strategy); | |||
| auto attr = extract_algo_attribute(selected_strategy); | |||
| ImplExecutionPolicy policy; | |||
| policy.algo = | |||
| APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||
| @@ -668,9 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
| if (retrive_from_cache) { | |||
| policy.algo = get_profile_result_from_cache(selected_strategy).desc; | |||
| if (!policy.algo.valid()) { | |||
| auto target_attr = | |||
| extract_algo_attribute_from_execution_strategy( | |||
| selected_strategy); | |||
| auto target_attr = extract_algo_attribute(selected_strategy); | |||
| std::string layouts_str = | |||
| format_fixlayouts<Opr>(m_layouts, arity_in, arity_out); | |||
| std::string msg = ssprintf( | |||
| @@ -692,8 +673,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||
| auto attr = extract_algo_attribute_from_execution_strategy( | |||
| selected_strategy); | |||
| auto attr = extract_algo_attribute(selected_strategy); | |||
| policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||
| args..., workspace_limit, attr.first, | |||
| attr.second), | |||
| @@ -837,6 +817,24 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||
| return result; | |||
| } | |||
| template <typename Opr> | |||
| std::pair<AlgoAttribute, AlgoAttribute> | |||
| AlgoChooser<Opr>::ExeContext::extract_algo_attribute( | |||
| const ExecutionStrategy& strategy) const { | |||
| std::pair<AlgoAttribute, AlgoAttribute> ret = | |||
| std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||
| //! from strategy | |||
| if (strategy & ExecutionStrategy::REPRODUCIBLE) { | |||
| ret.first |= AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| if (strategy & ExecutionStrategy::OPTMIZED) { | |||
| ret.second |= AlgoAttribute::NAIVE; | |||
| } | |||
| return ret; | |||
| } | |||
| #define INST(Opr) \ | |||
| template AlgoChooser<megdnn::Opr>::ExeContext::ExeContext( \ | |||
| const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ | |||
| @@ -865,7 +863,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||
| AlgoChooser<megdnn::Opr>::ExeContext::profile_single_algo( \ | |||
| const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \ | |||
| policy, \ | |||
| double& timeout) const; | |||
| double& timeout) const; \ | |||
| template std::pair<AlgoAttribute, AlgoAttribute> \ | |||
| AlgoChooser<megdnn::Opr>::ExeContext::extract_algo_attribute( \ | |||
| const ExecutionStrategy& strategy) const; | |||
| MGB_FOREACH_FASTRUN_OPR(INST) | |||
| @@ -149,6 +149,16 @@ public: | |||
| ImplExecutionPolicy& policy, | |||
| bool retrive_from_cache = true) const; | |||
| /** | |||
| * \brief extract algo attribute from execution strategy and graph | |||
| * option. | |||
| * | |||
| * \param strategy select algo which matched this strategy | |||
| * \return pair<positive_attr, negative_attr> | |||
| */ | |||
| std::pair<AlgoAttribute, AlgoAttribute> extract_algo_attribute( | |||
| const ExecutionStrategy& strategy) const; | |||
| private: | |||
| Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | |||
| }; | |||