GitOrigin-RevId: 21d17e647a
tags/v1.5.0
| @@ -28,6 +28,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_STRIDE1"; } | const char* name() const override { return "ARM_POOLING_STRIDE1"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_FilterxModexStride1) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase { | ||||
| @@ -38,6 +39,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_STRIDE2"; } | const char* name() const override { return "ARM_POOLING_STRIDE2"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Filter2ModexStride2) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -47,6 +49,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; } | const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Filter3MaxStride2) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase { | ||||
| @@ -57,6 +60,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; } | const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Filter3AverageStride2) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase { | ||||
| @@ -67,6 +71,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; } | const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Filter4MaxStride2) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase { | ||||
| @@ -77,6 +82,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; } | const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Filter5MaxStride2) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase { | ||||
| @@ -87,6 +93,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; } | const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Int8Filter2MaxStride2) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase { | ||||
| @@ -97,6 +104,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; } | const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Int8Filter3MaxStride2) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase { | ||||
| @@ -107,6 +115,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Filter3ModexStridexNCHW44) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase { | ||||
| @@ -117,6 +126,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Filter2ModexStridexNCHW44) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase { | ||||
| @@ -127,6 +137,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Filter4ModexStridexNCHW44) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase { | ||||
| @@ -137,6 +148,7 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Filter5ModexStridexNCHW44) | |||||
| }; | }; | ||||
| class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -146,6 +158,17 @@ public: | |||||
| const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; } | ||||
| bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
| void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Fp32ModexStridexNCHW44) | |||||
| }; | |||||
| class PoolingImpl::AlgoFallback final : public AlgoBase { | |||||
| public: | |||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| }; | |||||
| const char* name() const override { return "FALLBACK_POOLING"; } | |||||
| bool usable(const PoolingKernSizeParam&) const override { return true; } | |||||
| void exec(const PoolingKernParam&) const override {} | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_Fallback) | |||||
| }; | }; | ||||
| WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); | WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); | ||||
| @@ -12,11 +12,14 @@ | |||||
| #include "src/arm_common/pooling/opr_impl.h" | #include "src/arm_common/pooling/opr_impl.h" | ||||
| #include "src/arm_common/pooling/algo.h" | #include "src/arm_common/pooling/algo.h" | ||||
| #include "src/common/metahelper.h" | #include "src/common/metahelper.h" | ||||
| #include "src/common/algo_chooser.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace arm_common; | using namespace arm_common; | ||||
| class PoolingImpl::AlgoPack : NonCopyableObj { | class PoolingImpl::AlgoPack : NonCopyableObj { | ||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| AlgoFilterxModexStride1 algo_filterx_modex_stride1; | AlgoFilterxModexStride1 algo_filterx_modex_stride1; | ||||
| AlgoFilter2ModexStride2 algo_filter2_modex_stride2; | AlgoFilter2ModexStride2 algo_filter2_modex_stride2; | ||||
| AlgoFilter3MaxStride2 algo_filter3_max_stride2; | AlgoFilter3MaxStride2 algo_filter3_max_stride2; | ||||
| @@ -30,6 +33,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4; | AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4; | ||||
| AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4; | AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4; | ||||
| AlgoFp32ModexStridexNCHW44 algo_fp32_modex_stridex_nchw44; | AlgoFp32ModexStridexNCHW44 algo_fp32_modex_stridex_nchw44; | ||||
| AlgoFallback algo_fallback; | |||||
| public: | public: | ||||
| AlgoPack() { | AlgoPack() { | ||||
| @@ -46,10 +50,18 @@ public: | |||||
| all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4); | all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4); | ||||
| all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4); | all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4); | ||||
| all_algos.emplace_back(&algo_fp32_modex_stridex_nchw44); | all_algos.emplace_back(&algo_fp32_modex_stridex_nchw44); | ||||
| all_algos.emplace_back(&algo_fallback); | |||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack; | |||||
| PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param( | PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param( | ||||
| fallback::PoolingImpl* opr, const TensorLayout& src, | fallback::PoolingImpl* opr, const TensorLayout& src, | ||||
| const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
| @@ -89,44 +101,36 @@ PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param( | |||||
| size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, | size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, | ||||
| const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
| bool find_algo = false; | |||||
| static AlgoPack m_algo_pack; | |||||
| auto param = make_pooling_kern_szie_param(this, src, dst); | auto param = make_pooling_kern_szie_param(this, src, dst); | ||||
| for (auto& m_algo : m_algo_pack.all_algos) { | |||||
| if (m_algo->usable(param)) { | |||||
| find_algo = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| size_t arm_common_workspace = 0; | |||||
| //! When multi-thread, every thread has its own workspace | |||||
| size_t nr_threads = static_cast<naive::HandleImpl*>(handle()) | |||||
| ->megcore_dispatcher() | |||||
| ->nr_threads(); | |||||
| if ((param.src_type.category() == DTypeCategory::FLOAT || | |||||
| param.src_type == dtype::Int8{} || | |||||
| param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
| param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && | |||||
| param.filter[0] == param.filter[1] && | |||||
| (param.filter[0] == 3 || param.filter[0] == 5) && | |||||
| param.format == Param::Format::NCHW && | |||||
| (param.mode == Mode::MAX || | |||||
| (param.mode == Mode::AVERAGE && param.filter[0] == 3)) && | |||||
| param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 && | |||||
| param.isz[1] >= 2) { | |||||
| WorkspaceBundle ws = get_bundle(param); | |||||
| arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | |||||
| } | |||||
| auto algo = get_algorithm(this, src, dst); | |||||
| if (!is_fallback_algo(algo)) { | |||||
| size_t arm_common_workspace = 0; | |||||
| if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
| param.src_type.enumv() == DTypeEnum::Int8) && | |||||
| (param.format == param::Pooling::Format::NCHW44)) { | |||||
| WorkspaceBundle ws = get_bundle_nchw44(param); | |||||
| arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | |||||
| } | |||||
| //! When multi-thread, every thread has its own workspace | |||||
| size_t nr_threads = static_cast<naive::HandleImpl*>(handle()) | |||||
| ->megcore_dispatcher() | |||||
| ->nr_threads(); | |||||
| if ((param.src_type.category() == DTypeCategory::FLOAT || | |||||
| param.src_type == dtype::Int8{} || | |||||
| param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
| param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && | |||||
| param.filter[0] == param.filter[1] && | |||||
| (param.filter[0] == 3 || param.filter[0] == 5) && | |||||
| param.format == Param::Format::NCHW && | |||||
| (param.mode == Mode::MAX || | |||||
| (param.mode == Mode::AVERAGE && param.filter[0] == 3)) && | |||||
| param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 && | |||||
| param.isz[1] >= 2) { | |||||
| WorkspaceBundle ws = get_bundle(param); | |||||
| arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | |||||
| } | |||||
| if (find_algo) { | |||||
| if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
| param.src_type.enumv() == DTypeEnum::Int8) && | |||||
| (param.format == param::Pooling::Format::NCHW44)) { | |||||
| WorkspaceBundle ws = get_bundle_nchw44(param); | |||||
| arm_common_workspace = ws.total_size_in_bytes() * nr_threads; | |||||
| } | |||||
| return arm_common_workspace; | return arm_common_workspace; | ||||
| } else { | } else { | ||||
| auto fallback_worksapce = | auto fallback_worksapce = | ||||
| @@ -139,14 +143,48 @@ void PoolingImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
| check_exec(src.layout, dst.layout, workspace.size); | check_exec(src.layout, dst.layout, workspace.size); | ||||
| auto param = make_pooling_kern_param(this, src, dst, workspace); | auto param = make_pooling_kern_param(this, src, dst, workspace); | ||||
| static AlgoPack m_algo_pack; | |||||
| for (auto& m_algo : m_algo_pack.all_algos) { | |||||
| if (m_algo->usable(param)) { | |||||
| m_algo->exec(param); | |||||
| return; | |||||
| auto algo = get_algorithm(this, src.layout, dst.layout); | |||||
| if (!is_fallback_algo(algo)) { | |||||
| algo->exec(param); | |||||
| } else { | |||||
| fallback::PoolingImpl::exec(src, dst, workspace); | |||||
| } | |||||
| } | |||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingImpl); | |||||
| std::vector<Algorithm*> PoolingImpl::get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& dst) { | |||||
| auto param = make_pooling_kern_szie_param(this, src, dst); | |||||
| std::vector<Algorithm*> ret; | |||||
| ret.reserve(algo_pack().all_algos.size()); | |||||
| for (auto i : algo_pack().all_algos) { | |||||
| if (i->usable(param)) { | |||||
| ret.push_back(i); | |||||
| } | |||||
| } | |||||
| megdnn_assert(!ret.empty(), "no usable pooling fwd algorithm"); | |||||
| return ret; | |||||
| } | |||||
| Algorithm* PoolingImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | |||||
| auto param = make_pooling_kern_szie_param(this, src, dst); | |||||
| for (auto&& iter : sm_algo_pack.all_algos) { | |||||
| if (iter->is_available_attribute(param, positive_attr, negative_attr)) { | |||||
| return iter; | |||||
| } | } | ||||
| } | } | ||||
| fallback::PoolingImpl::exec(src, dst, workspace); | |||||
| megdnn_throw( | |||||
| ssprintf("require algorithm with attribute(%s) and without " | |||||
| "attribute(%s), but can't get suitable algo.\n", | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| Algorithm::attribute_str(negative_attr).c_str())); | |||||
| return nullptr; | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -12,11 +12,30 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megdnn/oprs/base.h" | #include "megdnn/oprs/base.h" | ||||
| #include "src/fallback/pooling/opr_impl.h" | #include "src/fallback/pooling/opr_impl.h" | ||||
| #include <unordered_map> | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| class PoolingImpl final : public fallback::PoolingImpl { | class PoolingImpl final : public fallback::PoolingImpl { | ||||
| private: | |||||
| class AlgoFilterxModexStride1; | |||||
| class AlgoFilter2ModexStride2; | |||||
| class AlgoFilter3MaxStride2; | |||||
| class AlgoFilter3AverageStride2; | |||||
| class AlgoFilter4MaxStride2; | |||||
| class AlgoFilter5MaxStride2; | |||||
| class AlgoInt8Filter2MaxStride2; | |||||
| class AlgoInt8Filter3MaxStride2; | |||||
| class AlgoFilter2ModexStridexNCHW44; | |||||
| class AlgoFilter3ModexStridexNCHW44; | |||||
| class AlgoFilter4ModexStridexNCHW44; | |||||
| class AlgoFilter5ModexStridexNCHW44; | |||||
| class AlgoFp32ModexStridexNCHW44; | |||||
| class AlgoFallback; | |||||
| class AlgoPack; | |||||
| static AlgoPack sm_algo_pack; | |||||
| public: | public: | ||||
| using fallback::PoolingImpl::PoolingImpl; | using fallback::PoolingImpl::PoolingImpl; | ||||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | ||||
| @@ -70,28 +89,68 @@ public: | |||||
| _megdnn_workspace workspace); | _megdnn_workspace workspace); | ||||
| class AlgoBase : public detail::Algorithm { | class AlgoBase : public detail::Algorithm { | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| ARM_FilterxModexStride1, | |||||
| ARM_Filter2ModexStride2, | |||||
| ARM_Filter3MaxStride2, | |||||
| ARM_Filter3AverageStride2, | |||||
| ARM_Filter4MaxStride2, | |||||
| ARM_Filter5MaxStride2, | |||||
| ARM_Int8Filter2MaxStride2, | |||||
| ARM_Int8Filter3MaxStride2, | |||||
| ARM_Filter2ModexStridexNCHW44, | |||||
| ARM_Filter3ModexStridexNCHW44, | |||||
| ARM_Filter4ModexStridexNCHW44, | |||||
| ARM_Filter5ModexStridexNCHW44, | |||||
| ARM_Fp32ModexStridexNCHW44, | |||||
| ARM_Fallback | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ARM_COMMON; } | |||||
| virtual ~AlgoBase() = default; | virtual ~AlgoBase() = default; | ||||
| virtual bool usable(const PoolingKernSizeParam& param) const = 0; | virtual bool usable(const PoolingKernSizeParam& param) const = 0; | ||||
| virtual void exec(const PoolingKernParam& param) const = 0; | virtual void exec(const PoolingKernParam& param) const = 0; | ||||
| uint32_t type() const override { return INVALID_ALGO_TYPE; }; | uint32_t type() const override { return INVALID_ALGO_TYPE; }; | ||||
| bool is_available_attribute( | |||||
| const PoolingKernSizeParam& param, | |||||
| const AlgoAttribute& positive_attr = | |||||
| AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && usable(param); | |||||
| } | |||||
| }; | }; | ||||
| private: | |||||
| class AlgoFilterxModexStride1; | |||||
| class AlgoFilter2ModexStride2; | |||||
| class AlgoFilter3MaxStride2; | |||||
| class AlgoFilter3AverageStride2; | |||||
| class AlgoFilter4MaxStride2; | |||||
| class AlgoFilter5MaxStride2; | |||||
| class AlgoInt8Filter2MaxStride2; | |||||
| class AlgoInt8Filter3MaxStride2; | |||||
| class AlgoFilter2ModexStridexNCHW44; | |||||
| class AlgoFilter3ModexStridexNCHW44; | |||||
| class AlgoFilter4ModexStridexNCHW44; | |||||
| class AlgoFilter5ModexStridexNCHW44; | |||||
| class AlgoFp32ModexStridexNCHW44; | |||||
| class AlgoPack; | |||||
| const char* get_algorithm_set_name() const override { | |||||
| return "ARM_POOLING_FORWARD"; | |||||
| } | |||||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(src, dst, workspace_limit_in_bytes, | |||||
| positive_attr, negative_attr) | |||||
| ->info(); | |||||
| } | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| bool is_fallback_algo(Algorithm* algo) { | |||||
| return strcmp(algo->name(), "FALLBACK_POOLING") == 0; | |||||
| } | |||||
| }; | }; | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||