GitOrigin-RevId: 479718ac75
tags/v1.2.0
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -92,24 +93,72 @@ enum class AlgoDataType : uint32_t { | |||||
| /*! | /*! | ||||
| * \brief Abstract representation of an algorithm for implementing | * \brief Abstract representation of an algorithm for implementing | ||||
| * the operator | * the operator | ||||
| * | |||||
| * All Algorithm objects should be allocated globally and remain usable | |||||
| * across multiple megdnn handles; they should not be freed by | |||||
| * the caller. | |||||
| */ | */ | ||||
| class Algorithm { | class Algorithm { | ||||
| public: | public: | ||||
| static constexpr uint32_t INVALID_ALGO_TYPE = static_cast<uint32_t>(-1); | |||||
| /** | |||||
| * \brief Algorithm information; the concrete algorithm can be | |||||
| * retrieved from Info::Desc | |||||
| */ | |||||
| struct Info { | |||||
| struct Desc { | |||||
| //! backend that the algo belongs to | |||||
| Handle::HandleType handle_type; | |||||
| //! indicates the concrete algo implementation | |||||
| uint32_t type = INVALID_ALGO_TYPE; | |||||
| //! serialized param of the algo type | |||||
| std::string param; | |||||
| bool valid() const { return type != INVALID_ALGO_TYPE; } | |||||
| void reset() { type = INVALID_ALGO_TYPE; } | |||||
| bool operator==(const Desc& rhs) const { | |||||
| return handle_type == rhs.handle_type && type == rhs.type && | |||||
| param == rhs.param; | |||||
| } | |||||
| } desc; | |||||
| //! algorithm name | |||||
| std::string name; | |||||
| bool is_reproducible; | |||||
| bool valid() const { return desc.valid(); } | |||||
| void reset() { desc.reset(); } | |||||
| //! desc uniquely identifies the algo | |||||
| bool operator==(const Info& rhs) const { return desc == rhs.desc; } | |||||
| }; | |||||
| virtual ~Algorithm() = default; | |||||
| /** | /** | ||||
| * \brief whether the execution result is | * \brief whether the execution result is | ||||
| * reproducible across multiple runs. | * reproducible across multiple runs. | ||||
| */ | */ | ||||
| virtual bool is_reproducible() const = 0; | virtual bool is_reproducible() const = 0; | ||||
| virtual const char* name() const = 0; | virtual const char* name() const = 0; | ||||
| //! serialized param | |||||
| virtual std::string param() const { return {}; } | |||||
| virtual uint32_t type() const = 0; | |||||
| Handle::HandleType handle_type() const { return m_handle_type; } | Handle::HandleType handle_type() const { return m_handle_type; } | ||||
| Info info() const { | |||||
| return {{handle_type(), type(), param()}, name(), is_reproducible()}; | |||||
| } | |||||
| template <typename T> | |||||
| static void serialize_write_pod(const T& val, std::string& result) { | |||||
| result.append(reinterpret_cast<const char*>(&val), sizeof(T)); | |||||
| } | |||||
| static void serialize_write_pod(const char* val, std::string& result) { | |||||
| result.append(val, strlen(val)); | |||||
| } | |||||
| template <typename T> | |||||
| static T deserialize_read_pod(const std::string& data, size_t offset = 0) { | |||||
| T ret = *reinterpret_cast<const T*>(&data[offset]); | |||||
| return ret; | |||||
| } | |||||
| protected: | protected: | ||||
| ~Algorithm() = default; | |||||
| Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; | Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; | ||||
| }; | }; | ||||
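The Info/Desc machinery above replaces raw Algorithm pointers with a small, serializable description of an algorithm. A minimal sketch of how the POD helpers round-trip a parameter blob; the struct and values are illustrative only, not taken from this patch:

    // Illustrative only: a hypothetical POD parameter serialized into the
    // byte string that ends up in Info::Desc::param, then recovered again.
    struct HypotheticalParam {
        uint32_t tile_size;
    };

    std::string bytes;
    megdnn::detail::Algorithm::serialize_write_pod(HypotheticalParam{8}, bytes);
    // ... store `bytes` as Info::Desc::param, pass the desc around ...
    auto restored =
            megdnn::detail::Algorithm::deserialize_read_pod<HypotheticalParam>(bytes);
    // restored.tile_size == 8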
| @@ -127,6 +176,8 @@ class MultiAlgoOpr; | |||||
| template <class Opr> | template <class Opr> | ||||
| class MultiAlgoOpr<Opr, -1> { | class MultiAlgoOpr<Opr, -1> { | ||||
| public: | public: | ||||
| using AlgorithmInfo = detail::Algorithm::Info; | |||||
| using AlgorithmDesc = detail::Algorithm::Info::Desc; | |||||
| using Algorithm = detail::Algorithm; | using Algorithm = detail::Algorithm; | ||||
| /*! | /*! | ||||
| * \brief get a string representation for current algorithm set; | * \brief get a string representation for current algorithm set; | ||||
| @@ -139,8 +190,8 @@ public: | |||||
| //! policy for executing the operator | //! policy for executing the operator | ||||
| struct ExecutionPolicy { | struct ExecutionPolicy { | ||||
| //! nullptr means using heuristic | |||||
| Algorithm* algorithm = nullptr; | |||||
| //! an algo desc with INVALID_ALGO_TYPE means using the heuristic | |||||
| AlgorithmInfo algo; | |||||
| }; | }; | ||||
| ExecutionPolicy& execution_policy() { return m_execution_policy; } | ExecutionPolicy& execution_policy() { return m_execution_policy; } | ||||
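ExecutionPolicy now holds an AlgorithmInfo by value instead of an Algorithm*; a default-constructed info (desc.type == INVALID_ALGO_TYPE) keeps the old "use the heuristic" behaviour. A hedged usage sketch, in which the operator type and the chosen info are placeholders rather than names from this patch:

    // Sketch only: Opr is any operator deriving from a MultiAlgoOpr
    // specialization; `chosen` was obtained elsewhere (e.g. by profiling).
    template <typename Opr>
    void pin_algorithm(Opr* opr, const typename Opr::AlgorithmInfo& chosen) {
        auto& policy = opr->execution_policy();
        if (!policy.algo.valid()) {
            policy.algo = chosen;   // pin a concrete algorithm
        }
        // policy.algo.reset() would drop the pin and restore heuristic selection
    }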
| @@ -161,6 +212,39 @@ template <class Opr> | |||||
| class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> { | class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> { | ||||
| public: | public: | ||||
| using Algorithm = detail::Algorithm; | using Algorithm = detail::Algorithm; | ||||
| using AlgorithmInfo = detail::Algorithm::Info; | |||||
| //! get all possible algorithm descriptions for the specified layouts | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2) { | |||||
| std::vector<AlgorithmInfo> ret; | |||||
| for (auto&& algo : get_all_algorithms(p0, p1, p2)) { | |||||
| ret.emplace_back(algo->info()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| /** | |||||
| * \brief Return the information of the best algorithm, selected | |||||
| * by heuristic. | |||||
| * | |||||
| * The selected algorithm should not use workspace more than | |||||
| * \p workspace_limit_in_bytes. | |||||
| */ | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| bool reproducible = false) { | |||||
| return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, | |||||
| reproducible) | |||||
| ->info(); | |||||
| } | |||||
| protected: | |||||
| ~MultiAlgoOpr() = default; | |||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
| @@ -179,9 +263,6 @@ public: | |||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| bool reproducible = false) = 0; | bool reproducible = false) = 0; | ||||
| protected: | |||||
| ~MultiAlgoOpr() = default; | |||||
| }; | }; | ||||
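Together with the execution policy above, the info-returning wrappers let callers enumerate and select algorithms without ever touching an Algorithm*. A usage sketch for the three-layout case; the operator and layout names are hypothetical:

    // Sketch only: Opr derives from MultiAlgoOpr<Opr, 3>; src/filter/dst
    // are the operator's usual TensorLayouts.
    template <typename Opr>
    void choose_algo(Opr* opr, const megdnn::TensorLayout& src,
                     const megdnn::TensorLayout& filter,
                     const megdnn::TensorLayout& dst) {
        auto candidates = opr->get_all_algorithms_info(src, filter, dst);
        for (auto&& info : candidates) {
            if (!info.is_reproducible) {
                // e.g. skip non-reproducible algorithms when building a
                // profiling candidate list
                continue;
            }
        }
        auto best = opr->get_algorithm_info_heuristic(
                src, filter, dst,
                /*workspace_limit_in_bytes=*/1u << 20,
                /*reproducible=*/true);
        opr->execution_policy().algo = best;
    }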
| //! specialization for nargs == 4 | //! specialization for nargs == 4 | ||||
| @@ -189,6 +270,40 @@ template <class Opr> | |||||
| class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> { | class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> { | ||||
| public: | public: | ||||
| using Algorithm = detail::Algorithm; | using Algorithm = detail::Algorithm; | ||||
| using AlgorithmInfo = detail::Algorithm::Info; | |||||
| //! get all possible algorithm descriptions for the specified layouts | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| const TensorLayout& p3) { | |||||
| std::vector<AlgorithmInfo> ret; | |||||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { | |||||
| ret.emplace_back(algo->info()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| /** | |||||
| * \brief Return the information of the best algorithm, selected | |||||
| * by heuristic. | |||||
| * | |||||
| * The selected algorithm should not use workspace more than | |||||
| * \p workspace_limit_in_bytes. | |||||
| */ | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| bool reproducible = false) { | |||||
| return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, | |||||
| reproducible) | |||||
| ->info(); | |||||
| } | |||||
| protected: | |||||
| ~MultiAlgoOpr() = default; | |||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
| @@ -207,9 +322,6 @@ public: | |||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| bool reproducible = false) = 0; | bool reproducible = false) = 0; | ||||
| protected: | |||||
| ~MultiAlgoOpr() = default; | |||||
| }; | }; | ||||
| //! specialization for nargs == 5 | //! specialization for nargs == 5 | ||||
| @@ -217,6 +329,42 @@ template <class Opr> | |||||
| class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> { | class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> { | ||||
| public: | public: | ||||
| using Algorithm = detail::Algorithm; | using Algorithm = detail::Algorithm; | ||||
| using AlgorithmInfo = detail::Algorithm::Info; | |||||
| //! get all possible algorithm descriptions for the specified layouts | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| const TensorLayout& p3, | |||||
| const TensorLayout& p4) { | |||||
| std::vector<AlgorithmInfo> ret; | |||||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { | |||||
| ret.emplace_back(algo->info()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| /** | |||||
| * \brief Return the information of the best algorithm, selected | |||||
| * by heuristic. | |||||
| * | |||||
| * The selected algorithm should not use workspace more than | |||||
| * \p workspace_limit_in_bytes. | |||||
| */ | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| bool reproducible = false) { | |||||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, | |||||
| workspace_limit_in_bytes, reproducible) | |||||
| ->info(); | |||||
| } | |||||
| protected: | |||||
| ~MultiAlgoOpr() = default; | |||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
| @@ -237,9 +385,6 @@ public: | |||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| bool reproducible = false) = 0; | bool reproducible = false) = 0; | ||||
| protected: | |||||
| ~MultiAlgoOpr() = default; | |||||
| }; | }; | ||||
| //! specialization for nargs == 8 | //! specialization for nargs == 8 | ||||
| @@ -247,6 +392,42 @@ template <class Opr> | |||||
| class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> { | class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> { | ||||
| public: | public: | ||||
| using Algorithm = detail::Algorithm; | using Algorithm = detail::Algorithm; | ||||
| using AlgorithmInfo = detail::Algorithm::Info; | |||||
| //! get all possible algorithm descriptions for the specified layouts | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7) { | |||||
| std::vector<AlgorithmInfo> ret; | |||||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||||
| ret.emplace_back(algo->info()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| /** | |||||
| * \brief Return the information of the best algorithm, selected | |||||
| * by heuristic. | |||||
| * | |||||
| * The selected algorithm should not use workspace more than | |||||
| * \p workspace_limit_in_bytes. | |||||
| */ | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| bool reproducible = false) { | |||||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, | |||||
| workspace_limit_in_bytes, reproducible) | |||||
| ->info(); | |||||
| } | |||||
| protected: | |||||
| ~MultiAlgoOpr() = default; | |||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
| @@ -269,9 +450,6 @@ public: | |||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| bool reproducible = false) = 0; | bool reproducible = false) = 0; | ||||
| protected: | |||||
| ~MultiAlgoOpr() = default; | |||||
| }; | }; | ||||
| } // namespace detail | } // namespace detail | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -31,6 +31,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP16) | |||||
| }; | }; | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
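MEGDNN_DECL_ALGO_TYPE itself is not shown in this excerpt; judging from how it is attached to every algorithm class, it presumably declares the type() override returning the backend-specific enum value, roughly as below (an assumption about the macro, not its actual definition):

    // Assumed expansion, for illustration only; AlgoType stands for the
    // per-backend algorithm enum.
    #define MEGDNN_DECL_ALGO_TYPE(_type)                   \
        uint32_t type() const override {                   \
            return static_cast<uint32_t>(AlgoType::_type); \
        }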
| @@ -36,6 +36,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP32) | |||||
| }; | }; | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -48,6 +48,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_S8) | |||||
| }; | }; | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -32,28 +32,54 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoF16DirectStride2 f16_direct_stride2; | AlgoF16DirectStride2 f16_direct_stride2; | ||||
| #endif | #endif | ||||
| fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos; | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_matmul_algos; | |||||
| public: | public: | ||||
| AlgoPack() { | AlgoPack() { | ||||
| matmul_algos.emplace_back(&qu8_matrix_mul); | |||||
| matmul_algos.emplace_back(&s8_matrix_mul); | |||||
| m_matmul_algos.emplace_back(&qu8_matrix_mul); | |||||
| m_matmul_algos.emplace_back(&s8_matrix_mul); | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| direct_algos.emplace_back(&f16_direct_stride2); | |||||
| m_direct_algos.emplace_back(&f16_direct_stride2); | |||||
| #endif | #endif | ||||
| direct_algos.emplace_back(&f32_direct_stride2); | |||||
| m_direct_algos.emplace_back(&f32_direct_stride2); | |||||
| for (auto&& algo : m_direct_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| for (auto&& algo : m_matmul_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | |||||
| return m_direct_algos; | |||||
| } | |||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() | |||||
| const { | |||||
| return m_matmul_algos; | |||||
| } | } | ||||
| SmallVector<AlgoBase*> direct_algos; | |||||
| SmallVector<AlgoBase*> matmul_algos; | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
| static AlgoPack sl_algo_pack; | |||||
| auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | |||||
| algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | |||||
| sl_algo_pack.direct_algos.end()); | |||||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||||
| static AlgoPack algo_pack; | |||||
| return algo_pack; | |||||
| } | |||||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||||
| ConvBiasImpl::get_all_packed_algo() { | |||||
| auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | |||||
| algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||||
| algo_pack().direct_algos().end()); | |||||
| //! We put matmul algos at the beginning, because matmul will get privilege when | //! We put matmul algos at the beginning, because matmul will get privilege when | ||||
| //! prefer returns true. See | //! prefer returns true. See | ||||
| algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(), | |||||
| sl_algo_pack.matmul_algos.end()); | |||||
| algos.insert(algos.begin(), algo_pack().matmul_algos().begin(), | |||||
| algo_pack().matmul_algos().end()); | |||||
| return std::move(algos); | return std::move(algos); | ||||
| } | } | ||||
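Registering each algorithm in m_all_algos_map under its Info::Desc is what lets a (possibly deserialized) descriptor be resolved back to a concrete object; MEGDNN_FB_DEF_GET_ALGO_FROM_DESC presumably generates a lookup of roughly this shape. The signature and body below are assumptions, not the macro's real expansion:

    // Assumed shape of the generated lookup, treating the Mapper as a map
    // keyed by AlgorithmDesc.
    fallback::ConvBiasImpl::AlgoBase* ConvBiasImpl::get_algo_from_desc(
            const AlgorithmDesc& desc) {
        auto&& map = algo_pack().all_algos_map();
        auto it = map.find(desc);
        return it != map.end() ? it->second : nullptr;
    }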
| @@ -25,7 +25,9 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); | |||||
| protected: | protected: | ||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| @@ -38,6 +40,7 @@ private: | |||||
| class AlgoF16DirectStride2; | class AlgoF16DirectStride2; | ||||
| #endif | #endif | ||||
| class AlgoPack; | class AlgoPack; | ||||
| static const AlgoPack& algo_pack(); | |||||
| }; | }; | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -48,6 +48,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_QU8) | |||||
| }; | }; | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -27,6 +27,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K8X12X1) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | ||||
| @@ -37,6 +38,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_K8X12X1) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | ||||
| @@ -47,6 +49,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K4X16X1) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | ||||
| @@ -58,10 +61,17 @@ public: | |||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) | MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoF32Gemv final | class MatrixMulImpl::AlgoF32Gemv final | ||||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv {}; | |||||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||||
| public: | |||||
| AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | |||||
| m_handle_type = Handle::HandleType::AARCH64; | |||||
| } | |||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_GEMV) | |||||
| }; | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | ||||
| @@ -72,6 +82,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_K8X24X1) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | ||||
| @@ -83,6 +94,7 @@ public: | |||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_8X8) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| @@ -98,6 +110,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X12X4_DOTPROD) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | ||||
| @@ -110,6 +123,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD) | |||||
| }; | }; | ||||
| #else | #else | ||||
| @@ -124,6 +138,7 @@ public: | |||||
| PackMode packmode() const override { return PackMode::DEFAULT; } | PackMode packmode() const override { return PackMode::DEFAULT; } | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_4X4X16) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | ||||
| @@ -136,6 +151,7 @@ public: | |||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K4X4X16) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | ||||
| @@ -147,6 +163,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| @@ -160,6 +177,7 @@ public: | |||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K8X8X8) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | ||||
| @@ -171,6 +189,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | ||||
| @@ -186,6 +205,7 @@ public: | |||||
| PackMode packmode() const override { return PackMode::DEFAULT; } | PackMode packmode() const override { return PackMode::DEFAULT; } | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_16X12X4) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | ||||
| @@ -201,6 +221,7 @@ public: | |||||
| PackMode packmode() const override { return PackMode::DEFAULT; } | PackMode packmode() const override { return PackMode::DEFAULT; } | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_K8X8X8) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | ||||
| @@ -214,6 +235,7 @@ public: | |||||
| PackMode packmode() const override { return PackMode::DEFAULT; } | PackMode packmode() const override { return PackMode::DEFAULT; } | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_4X4X8) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | ||||
| @@ -225,6 +247,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_K12X8X1) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | ||||
| @@ -236,6 +259,7 @@ public: | |||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8) | |||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| @@ -249,6 +273,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | ||||
| @@ -262,6 +287,7 @@ public: | |||||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD) | |||||
| }; | }; | ||||
| #else | #else | ||||
| @@ -273,6 +299,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| @@ -51,49 +51,66 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoQuint8K8x8x8 quint8_k8x8x8; | AlgoQuint8K8x8x8 quint8_k8x8x8; | ||||
| #endif | #endif | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||||
| fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||||
| AlgoPack() { | AlgoPack() { | ||||
| all_algos.emplace_back(&f32_gemv); | |||||
| all_algos.emplace_back(&f32K8x12x1); | |||||
| all_algos.emplace_back(&f32_mk4_8x12x1); | |||||
| all_algos.emplace_back(&f32k4x16x1); | |||||
| all_algos.emplace_back(&f32mk4_4x16); | |||||
| m_all_algos.emplace_back(&f32_gemv); | |||||
| m_all_algos.emplace_back(&f32K8x12x1); | |||||
| m_all_algos.emplace_back(&f32_mk4_8x12x1); | |||||
| m_all_algos.emplace_back(&f32k4x16x1); | |||||
| m_all_algos.emplace_back(&f32mk4_4x16); | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| all_algos.emplace_back(&f16_k8x24x1); | |||||
| all_algos.emplace_back(&f16_mk8_8x8); | |||||
| m_all_algos.emplace_back(&f16_k8x24x1); | |||||
| m_all_algos.emplace_back(&f16_mk8_8x8); | |||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | |||||
| all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | |||||
| m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | |||||
| m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | |||||
| #else | #else | ||||
| all_algos.emplace_back(&int8x8x32_k4x4x16); | |||||
| all_algos.emplace_back(&int8x8x32_k8x8x8); | |||||
| all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | |||||
| m_all_algos.emplace_back(&int8x8x32_k4x4x16); | |||||
| m_all_algos.emplace_back(&int8x8x32_k8x8x8); | |||||
| m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | |||||
| #endif | #endif | ||||
| all_algos.emplace_back(&int8x8x16_k4x4x16); | |||||
| all_algos.emplace_back(&int8x8x16_k8x8x8); | |||||
| all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | |||||
| all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||||
| all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||||
| m_all_algos.emplace_back(&int8x8x16_k4x4x16); | |||||
| m_all_algos.emplace_back(&int8x8x16_k8x8x8); | |||||
| m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | |||||
| m_all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||||
| m_all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||||
| all_algos.emplace_back(&int16x16x32_k12x8x1); | |||||
| all_algos.emplace_back(&int16x16x32_mk8_8x8); | |||||
| m_all_algos.emplace_back(&int16x16x32_k12x8x1); | |||||
| m_all_algos.emplace_back(&int16x16x32_mk8_8x8); | |||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| all_algos.emplace_back(&quint8_gemv_dotprod); | |||||
| all_algos.emplace_back(&quint8_k8x8x4_dotprod); | |||||
| m_all_algos.emplace_back(&quint8_gemv_dotprod); | |||||
| m_all_algos.emplace_back(&quint8_k8x8x4_dotprod); | |||||
| #else | #else | ||||
| all_algos.emplace_back(&quint8_k8x8x8); | |||||
| m_all_algos.emplace_back(&quint8_k8x8x8); | |||||
| #endif | #endif | ||||
| for (auto&& algo : m_all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const { | |||||
| return m_all_algos; | |||||
| } | } | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||||
| static AlgoPack s_algo_pack; | |||||
| auto&& algos = arm_common::MatrixMulImpl::algo_pack(); | |||||
| algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||||
| s_algo_pack.all_algos.end()); | |||||
| const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { | |||||
| static AlgoPack algo_pack; | |||||
| return algo_pack; | |||||
| } | |||||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) | |||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> | |||||
| MatrixMulImpl::get_all_packed_algo() { | |||||
| auto&& algos = arm_common::MatrixMulImpl::get_all_packed_algo(); | |||||
| algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||||
| algo_pack().all_algos().end()); | |||||
| return std::move(algos); | return std::move(algos); | ||||
| } | } | ||||
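Both ConvBiasImpl::algo_pack() and MatrixMulImpl::algo_pack() now return a reference to a function-local static registry instead of rebuilding a vector around a block-scope static on every call; that gives one lazily constructed, thread-safe (since C++11) algorithm pack per backend. The shape of the pattern, with generic placeholder names rather than names from the patch:

    // Generic sketch of the design choice used by both backends.
    struct SomePack { /* owns the algorithm objects and the desc -> algo map */ };

    const SomePack& some_pack() {
        static SomePack pack;   // constructed on first call only, thread-safe
        return pack;
    }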
| @@ -25,7 +25,10 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() | |||||
| override; | |||||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); | |||||
| private: | private: | ||||
| class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 | class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 | ||||
| @@ -66,6 +69,8 @@ private: | |||||
| class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 MK4 Kernel 8x8x8 | class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 MK4 Kernel 8x8x8 | ||||
| class AlgoPack; | class AlgoPack; | ||||
| public: | |||||
| static const AlgoPack& algo_pack(); | |||||
| }; | }; | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -30,6 +30,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_FP16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { | class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { | ||||
| @@ -45,7 +46,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { | class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -61,6 +62,7 @@ public: | |||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { | class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -75,6 +77,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | ||||
| @@ -94,6 +97,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override{ | ConvAlgoTypePack get_algo_type() const override{ | ||||
| return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | ||||
| @@ -110,6 +114,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP16) | |||||
| }; | }; | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -30,6 +30,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | ||||
| @@ -45,6 +46,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | ||||
| @@ -60,6 +62,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | ||||
| @@ -75,6 +78,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | ||||
| @@ -90,6 +94,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32) | |||||
| }; | }; | ||||
| //===================== NCHW44 Winograd Support =====================// | //===================== NCHW44 Winograd Support =====================// | ||||
| @@ -107,6 +112,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | ||||
| @@ -123,6 +129,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | ||||
| @@ -139,6 +146,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) | |||||
| }; | }; | ||||
| // ================================================================= // | // ================================================================= // | ||||
| @@ -157,6 +165,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | ||||
| @@ -174,6 +183,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | ||||
| @@ -191,6 +201,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | ||||
| @@ -209,6 +220,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | ||||
| @@ -227,6 +239,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | ||||
| @@ -244,6 +257,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32) | |||||
| }; | }; | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -33,6 +33,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_S8) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | ||||
| @@ -49,6 +50,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_S8) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | ||||
| @@ -65,6 +67,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | ||||
| @@ -81,6 +84,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_S8) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { | ||||
| @@ -95,6 +99,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD1_NCHW44_S8) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { | ||||
| @@ -109,6 +114,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8) | |||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| @@ -126,6 +132,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | ||||
| @@ -142,6 +149,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_S8) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | ||||
| @@ -159,6 +167,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_S8) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { | ||||
| @@ -180,6 +189,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_DOT_S8) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| @@ -196,6 +206,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8) | |||||
| }; | }; | ||||
| //=======================input int8 compute fp32 output int8============ | //=======================input int8 compute fp32 output int8============ | ||||
| @@ -213,6 +224,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32) | |||||
| }; | }; | ||||
| //=======================input int8 compute int16 output int8============ | //=======================input int8 compute int16 output int8============ | ||||
| @@ -231,6 +243,7 @@ public: | |||||
| } | } | ||||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8) | |||||
| }; | }; | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -39,6 +39,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_INT8X8X16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | ||||
| @@ -54,6 +55,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_INT8X8X16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | ||||
| @@ -80,6 +82,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_INT8X8X16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { | class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { | ||||
| @@ -96,12 +99,16 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase { | |||||
| class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final | |||||
| : public AlgoBase { | |||||
| public: | public: | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; } | |||||
| const char* name() const override { | |||||
| return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; | |||||
| } | |||||
| bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
| size_t get_workspace( | size_t get_workspace( | ||||
| @@ -111,6 +118,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { | ||||
| @@ -129,6 +137,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16) | |||||
| }; | }; | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -88,46 +88,50 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
| #endif | #endif | ||||
| SmallVector<std::unique_ptr<AlgoBase>> refhold; | SmallVector<std::unique_ptr<AlgoBase>> refhold; | ||||
| fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos; | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_winograd_algos; | |||||
| public: | public: | ||||
| AlgoPack() { | AlgoPack() { | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| direct_algos.emplace_back(&ds8_direct_stride1); | |||||
| direct_algos.emplace_back(&ds8_direct_stride2); | |||||
| direct_algos.emplace_back(&du8_direct_stride1); | |||||
| direct_algos.emplace_back(&du8_direct_stride2); | |||||
| m_direct_algos.emplace_back(&ds8_direct_stride1); | |||||
| m_direct_algos.emplace_back(&ds8_direct_stride2); | |||||
| m_direct_algos.emplace_back(&du8_direct_stride1); | |||||
| m_direct_algos.emplace_back(&du8_direct_stride2); | |||||
| direct_algos.emplace_back(&ds8_direct_nchw44); | |||||
| direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | |||||
| m_direct_algos.emplace_back(&ds8_direct_nchw44); | |||||
| m_direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | |||||
| #endif | #endif | ||||
| direct_algos.emplace_back(&qu8_direct_stride2); | |||||
| direct_algos.emplace_back(&qu8_direct_stride1); | |||||
| direct_algos.emplace_back(&s8_direct_stride2); | |||||
| direct_algos.emplace_back(&s8_direct_nchw44); | |||||
| direct_algos.emplace_back(&s8x8x16_direct_nchw44); | |||||
| direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||||
| direct_algos.emplace_back(&s8_direct_stride1); | |||||
| direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44); | |||||
| direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | |||||
| direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | |||||
| m_direct_algos.emplace_back(&qu8_direct_stride2); | |||||
| m_direct_algos.emplace_back(&qu8_direct_stride1); | |||||
| m_direct_algos.emplace_back(&s8_direct_stride2); | |||||
| m_direct_algos.emplace_back(&s8_direct_nchw44); | |||||
| m_direct_algos.emplace_back(&s8x8x16_direct_nchw44); | |||||
| m_direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||||
| m_direct_algos.emplace_back(&s8_direct_stride1); | |||||
| m_direct_algos.emplace_back( | |||||
| &s8x8x16_channel_wise_stride1_stride2_nchw44); | |||||
| m_direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | |||||
| m_direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| direct_algos.emplace_back(&f16_direct_stride1); | |||||
| direct_algos.emplace_back(&f16_direct); | |||||
| m_direct_algos.emplace_back(&f16_direct_stride1); | |||||
| m_direct_algos.emplace_back(&f16_direct); | |||||
| #endif | #endif | ||||
| direct_algos.emplace_back(&i8x8x16_direct); | |||||
| direct_algos.emplace_back(&i8x8x16_stride2_filter2); | |||||
| direct_algos.emplace_back(&i8x8x16_stride2); | |||||
| direct_algos.emplace_back(&i8x8x16_nchw_nchw44); | |||||
| m_direct_algos.emplace_back(&i8x8x16_direct); | |||||
| m_direct_algos.emplace_back(&i8x8x16_stride2_filter2); | |||||
| m_direct_algos.emplace_back(&i8x8x16_stride2); | |||||
| m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44); | |||||
| direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||||
| direct_algos.emplace_back(&f32_chanel_wise_nchw44); | |||||
| direct_algos.emplace_back(&f32_direct_nchw44); | |||||
| m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||||
| m_direct_algos.emplace_back(&f32_chanel_wise_nchw44); | |||||
| m_direct_algos.emplace_back(&f32_direct_nchw44); | |||||
| direct_algos.emplace_back(&f32_direct_stride1); | |||||
| direct_algos.emplace_back(&f32_direct_stride2); | |||||
| direct_algos.emplace_back(&f32_direct); | |||||
| m_direct_algos.emplace_back(&f32_direct_stride1); | |||||
| m_direct_algos.emplace_back(&f32_direct_stride2); | |||||
| m_direct_algos.emplace_back(&f32_direct); | |||||
| static CpuOprDelegationStorage<2> storage; | static CpuOprDelegationStorage<2> storage; | ||||
| auto matmul_opr = storage.get<MatrixMul, 0>(); | auto matmul_opr = storage.get<MatrixMul, 0>(); | ||||
| @@ -143,31 +147,31 @@ public: | |||||
| refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| //! uncomment this when low precision mode is done | //! uncomment this when low precision mode is done | ||||
| #if 0 | #if 0 | ||||
| refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| #endif | #endif | ||||
| //! Qint8x8x32 winograd compute with fp32 | //! Qint8x8x32 winograd compute with fp32 | ||||
| refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( | refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| } | } | ||||
| } | } | ||||
| matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | ||||
| @@ -180,15 +184,15 @@ public: | |||||
| refhold.emplace_back(new AlgoFP32WinogradF63( | refhold.emplace_back(new AlgoFP32WinogradF63( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoFP32WinogradF54( | refhold.emplace_back(new AlgoFP32WinogradF54( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoFP32WinogradF45( | refhold.emplace_back(new AlgoFP32WinogradF45( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -203,15 +207,15 @@ public: | |||||
| refhold.emplace_back(new AlgoFP16WinogradF23( | refhold.emplace_back(new AlgoFP16WinogradF23( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoFP16WinogradF45( | refhold.emplace_back(new AlgoFP16WinogradF45( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoFP16WinogradF63( | refhold.emplace_back(new AlgoFP16WinogradF63( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| } | } | ||||
| } | } | ||||
| matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | ||||
| @@ -224,7 +228,7 @@ public: | |||||
| refhold.emplace_back(new AlgoFP16WinogradF23_8x8( | refhold.emplace_back(new AlgoFP16WinogradF23_8x8( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| } | } | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -238,25 +242,48 @@ public: | |||||
| refhold.emplace_back(new AlgoS8WinogradF23_8x8( | refhold.emplace_back(new AlgoS8WinogradF23_8x8( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44( | refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44( | ||||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
| tile_size)); | tile_size)); | ||||
| winograd_algos.emplace_back(refhold.back().get()); | |||||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||||
| } | } | ||||
| } | } | ||||
| for (auto&& algo : m_direct_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| for (auto&& algo : m_winograd_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| SmallVector<AlgoBase*> direct_algos; | |||||
| SmallVector<AlgoBase*> winograd_algos; | |||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() | |||||
| const { | |||||
| return m_direct_algos; | |||||
| } | |||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& winograd_algos() | |||||
| const { | |||||
| return m_winograd_algos; | |||||
| } | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
| static AlgoPack sl_algo_pack; | |||||
| auto&& algos = fallback::ConvBiasImpl::algo_pack(); | |||||
| algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | |||||
| sl_algo_pack.direct_algos.end()); | |||||
| algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(), | |||||
| sl_algo_pack.winograd_algos.end()); | |||||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||||
| static AlgoPack algo_pack; | |||||
| return algo_pack; | |||||
| } | |||||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||||
| ConvBiasImpl::get_all_packed_algo() { | |||||
| auto&& algos = fallback::ConvBiasImpl::get_all_packed_algo(); | |||||
| algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||||
| algo_pack().direct_algos().end()); | |||||
| algos.insert(algos.end(), algo_pack().winograd_algos().begin(), | |||||
| algo_pack().winograd_algos().end()); | |||||
| return std::move(algos); | return std::move(algos); | ||||
| } | } | ||||
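The refactor above replaces the public `direct_algos` / `winograd_algos` members with an `AlgoPack` that owns the algorithm objects, exposes them through const accessors, and builds an `m_all_algos_map` keyed by `Algorithm::Info::Desc`, so a serialized descriptor can later be resolved back to a concrete algorithm. The snippet below is a minimal, self-contained sketch of that mapping idea; `StubDesc`, `StubAlgo` and `StubAlgoPack` are invented stand-ins, not megdnn types.

    // Stand-in sketch of the desc -> algo mapping built in AlgoPack's constructor.
    #include <cstdint>
    #include <functional>
    #include <unordered_map>
    #include <vector>

    struct StubDesc {
        uint32_t type;
        bool operator==(const StubDesc& rhs) const { return type == rhs.type; }
    };
    struct StubDescHash {
        size_t operator()(const StubDesc& d) const {
            return std::hash<uint32_t>()(d.type);
        }
    };

    struct StubAlgo {
        virtual ~StubAlgo() = default;
        virtual StubDesc desc() const = 0;
    };
    struct StubAlgoDirect final : StubAlgo {
        StubDesc desc() const override { return {1}; }
    };

    class StubAlgoPack {
        StubAlgoDirect m_direct;  // algorithm objects live inside the pack
        std::vector<StubAlgo*> m_all_algos;
        std::unordered_map<StubDesc, StubAlgo*, StubDescHash> m_all_algos_map;

    public:
        StubAlgoPack() {
            m_all_algos.push_back(&m_direct);
            for (auto* algo : m_all_algos)  // same loop shape as the patch above
                m_all_algos_map.emplace(algo->desc(), algo);
        }
        StubAlgo* from_desc(const StubDesc& d) const { return m_all_algos_map.at(d); }
    };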
| @@ -12,6 +12,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
| #include "src/common/algo_base.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -27,7 +28,7 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||||
| bool is_matmul_quantized_prefer( | bool is_matmul_quantized_prefer( | ||||
| const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) | const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) | ||||
| @@ -35,7 +36,8 @@ public: | |||||
| SmallVector<AlgoCategory> suggest_algo_category_order( | SmallVector<AlgoCategory> suggest_algo_category_order( | ||||
| const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
| class AlgoPack; | |||||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); | |||||
| protected: | protected: | ||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| @@ -95,6 +97,9 @@ private: | |||||
| class AlgoF16Direct; | class AlgoF16Direct; | ||||
| class AlgoF16DirectStride1; | class AlgoF16DirectStride1; | ||||
| #endif | #endif | ||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack(); | |||||
| }; | }; | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -32,6 +32,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_QU8) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | ||||
| @@ -48,6 +49,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8) | |||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | ||||
| @@ -65,6 +67,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | ||||
| @@ -81,6 +84,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -36,6 +36,7 @@ public: | |||||
| ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | ||||
| const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32) | |||||
| }; | }; | ||||
| class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final | class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final | ||||
| @@ -54,6 +55,7 @@ public: | |||||
| ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | ||||
| const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_INT8X8X32) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| @@ -1086,6 +1086,10 @@ bool deconv::can_stride1_int8x8x32_dot(const NCBKernSizeParam& param) { | |||||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | ||||
| FH >= PH + 1 && FW >= PW + 1; | FH >= PH + 1 && FW >= PW + 1; | ||||
| avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
| param.filter_type.enumv() == DTypeEnum::Int8) && | |||||
| (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || | |||||
| param.grad_type.enumv() == DTypeEnum::Int32); | |||||
| return avaiable && | return avaiable && | ||||
| ((FH == 2 && OC <= 8) || | ((FH == 2 && OC <= 8) || | ||||
| ((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16))); | ((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16))); | ||||
| @@ -1180,6 +1180,10 @@ bool deconv::can_stride2_int8x8x32_dot(const NCBKernSizeParam& param) { | |||||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | ||||
| FH >= PH + 1 && FW >= PW + 1; | FH >= PH + 1 && FW >= PW + 1; | ||||
| avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
| param.filter_type.enumv() == DTypeEnum::Int8) && | |||||
| (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || | |||||
| param.grad_type.enumv() == DTypeEnum::Int32); | |||||
| return avaiable && ((FH == 2 && OC <= 4) || (FH == 3 && OC <= 8) || | return avaiable && ((FH == 2 && OC <= 4) || (FH == 3 && OC <= 8) || | ||||
| (FH == 5 && OC <= 16) || (FH == 7 && OC < 32)); | (FH == 5 && OC <= 16) || (FH == 7 && OC < 32)); | ||||
| } | } | ||||
| @@ -23,15 +23,54 @@ using namespace arm_common; | |||||
| /* ===================== ConvolutionBackwardData ===================== */ | /* ===================== ConvolutionBackwardData ===================== */ | ||||
| struct ConvolutionBackwardDataImpl::AlgoPack { | |||||
| class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; | AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; | ||||
| AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; | AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; | ||||
| AlgoUdot8DirectStride1 quint8_direct_stride1_udot; | AlgoUdot8DirectStride1 quint8_direct_stride1_udot; | ||||
| AlgoUdot8DirectStride2 quint8_direct_stride2_udot; | AlgoUdot8DirectStride2 quint8_direct_stride2_udot; | ||||
| #endif | #endif | ||||
| fallback::ConvolutionBackwardDataImpl::AlgoBase::Mapper m_all_algos_map; | |||||
| SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*> | |||||
| m_all_algos; | |||||
| public: | |||||
| AlgoPack() { | |||||
| #if __ARM_FEATURE_DOTPROD | |||||
| m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot); | |||||
| m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot); | |||||
| m_all_algos.emplace_back(&quint8_direct_stride1_udot); | |||||
| m_all_algos.emplace_back(&quint8_direct_stride2_udot); | |||||
| #endif | |||||
| for (auto&& algo : m_all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| const SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>& | |||||
| all_algos() const { | |||||
| return m_all_algos; | |||||
| } | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; | |||||
| const ConvolutionBackwardDataImpl::AlgoPack& | |||||
| ConvolutionBackwardDataImpl::algo_pack() { | |||||
| static AlgoPack algo_pack; | |||||
| return algo_pack; | |||||
| } | |||||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) | |||||
| SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*> | |||||
| ConvolutionBackwardDataImpl::get_all_packed_algo() { | |||||
| auto&& algos = fallback::ConvolutionBackwardDataImpl::get_all_packed_algo(); | |||||
| algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||||
| algo_pack().all_algos().end()); | |||||
| return std::move(algos); | |||||
| } | |||||
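A side effect of replacing the `static AlgoPack sm_algo_pack;` member with the `algo_pack()` accessor above is that the pack becomes a function-local static: it is constructed lazily on first use (thread-safely since C++11), which avoids static-initialization-order problems between translation units. A minimal illustration of that accessor shape, using a made-up `Pack` type:

    #include <cstdio>

    struct Pack {
        Pack() { std::puts("Pack constructed on first use"); }
    };

    const Pack& pack() {
        static Pack p;  // constructed lazily; initialization is thread-safe in C++11
        return p;
    }

    int main() {
        pack();  // prints once
        pack();  // already constructed, no output
    }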
| ConvolutionBackwardDataImpl::ncb_kern_t | ConvolutionBackwardDataImpl::ncb_kern_t | ||||
| ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( | ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( | ||||
| @@ -52,35 +91,6 @@ size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( | |||||
| param); | param); | ||||
| } | } | ||||
| std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||||
| ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( | |||||
| const NCBKernSizeParam& param) { | |||||
| auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( | |||||
| param); | |||||
| #if __ARM_FEATURE_DOTPROD | |||||
| if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
| param.filter_type.enumv() == DTypeEnum::Int8) && | |||||
| (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || | |||||
| param.grad_type.enumv() == DTypeEnum::Int32)) { | |||||
| if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) { | |||||
| ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot); | |||||
| } | |||||
| if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) { | |||||
| ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot); | |||||
| } | |||||
| } else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && | |||||
| param.grad_type.enumv() == DTypeEnum::QuantizedS32) { | |||||
| if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) { | |||||
| ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot); | |||||
| } | |||||
| if (sm_algo_pack.quint8_direct_stride2_udot.usable(this, param)) { | |||||
| ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride2_udot); | |||||
| } | |||||
| } | |||||
| #endif | |||||
| return ret; | |||||
| } | |||||
| const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { | const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { | ||||
| // arm common version 0 | // arm common version 0 | ||||
| return "DeconvAC0"; | return "DeconvAC0"; | ||||
| @@ -47,11 +47,14 @@ protected: | |||||
| size_t ncb_1g_get_workspace(Algorithm* algo, | size_t ncb_1g_get_workspace(Algorithm* algo, | ||||
| const NCBKernSizeParam& param) override; | const NCBKernSizeParam& param) override; | ||||
| std::vector<Algorithm*> ncb_1g_get_all_algorithms( | |||||
| const NCBKernSizeParam& param) override; | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*> | |||||
| get_all_packed_algo() override; | |||||
| public: | |||||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl); | |||||
| private: | private: | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| class AlgoSdot8DirectStride1; | class AlgoSdot8DirectStride1; | ||||
| @@ -59,8 +62,8 @@ private: | |||||
| class AlgoUdot8DirectStride1; | class AlgoUdot8DirectStride1; | ||||
| class AlgoUdot8DirectStride2; | class AlgoUdot8DirectStride2; | ||||
| #endif | #endif | ||||
| struct AlgoPack; | |||||
| static AlgoPack sm_algo_pack; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack(); | |||||
| }; | }; | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -36,6 +36,7 @@ public: | |||||
| ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | ||||
| const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8) | |||||
| }; | }; | ||||
| class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final | class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final | ||||
| @@ -55,6 +56,7 @@ public: | |||||
| ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | ||||
| const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -1236,6 +1236,9 @@ bool deconv::can_stride1_quint8_dot(const NCBKernSizeParam& param) { | |||||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | ||||
| FH >= PH + 1 && FW >= PW + 1; | FH >= PH + 1 && FW >= PW + 1; | ||||
| avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || | |||||
| param.grad_type.enumv() == DTypeEnum::Int32); | |||||
| /** | /** | ||||
| * \note In the kernel, we use int32_t to calc the value, in order | * \note In the kernel, we use int32_t to calc the value, in order | ||||
| * not to generate negative numbers, we first initialize SHIFT and sub | * not to generate negative numbers, we first initialize SHIFT and sub | ||||
| @@ -1337,6 +1337,9 @@ bool deconv::can_stride2_quint8_dot(const NCBKernSizeParam& param) { | |||||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | ||||
| FH >= PH + 1 && FW >= PW + 1; | FH >= PH + 1 && FW >= PW + 1; | ||||
| avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || | |||||
| param.grad_type.enumv() == DTypeEnum::Int32); | |||||
| /** | /** | ||||
| * \note In the kernel, we use uint32_t to calc the value, in order | * \note In the kernel, we use uint32_t to calc the value, in order | ||||
| * not to generate negative numbers, we first initialize SHIFT and sub | * not to generate negative numbers, we first initialize SHIFT and sub | ||||
| @@ -59,6 +59,7 @@ public: | |||||
| virtual bool is_available(const KernParam&) const = 0; | virtual bool is_available(const KernParam&) const = 0; | ||||
| virtual void exec(const KernParam&) const = 0; | virtual void exec(const KernParam&) const = 0; | ||||
| virtual ~AlgoBase() = default; | virtual ~AlgoBase() = default; | ||||
| uint32_t type() const override { return INVALID_ALGO_TYPE; }; | |||||
| }; | }; | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| @@ -26,6 +26,7 @@ public: | |||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X16) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | ||||
| @@ -39,6 +40,7 @@ public: | |||||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | ||||
| @@ -52,6 +54,7 @@ public: | |||||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) | |||||
| }; | }; | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| @@ -66,6 +69,7 @@ public: | |||||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4_DOT) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| @@ -96,6 +100,7 @@ public: | |||||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4) | |||||
| }; | }; | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| @@ -110,6 +115,7 @@ public: | |||||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F16_GEMV) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| @@ -130,6 +136,7 @@ public: | |||||
| static_cast<uint32_t>(AlgoDataType::FLOAT32) | | static_cast<uint32_t>(AlgoDataType::FLOAT32) | | ||||
| static_cast<uint32_t>(AlgoDataType::QINT8X8X32)), | static_cast<uint32_t>(AlgoDataType::QINT8X8X32)), | ||||
| DEFAULT) | DEFAULT) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_GEVM) | |||||
| }; | }; | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -28,28 +28,47 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoGevm gevm; | AlgoGevm gevm; | ||||
| AlgoF32GemvMK4 f32_gemv_mk4; | AlgoF32GemvMK4 f32_gemv_mk4; | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||||
| fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack() { | AlgoPack() { | ||||
| all_algos.emplace_back(&int8x8x16); | |||||
| m_all_algos.emplace_back(&int8x8x16); | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| all_algos.emplace_back(&f16gemv); | |||||
| m_all_algos.emplace_back(&f16gemv); | |||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | |||||
| m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | |||||
| #endif | #endif | ||||
| all_algos.emplace_back(&int8x8x32_gemv); | |||||
| all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||||
| all_algos.emplace_back(&f32_gemv_mk4); | |||||
| all_algos.emplace_back(&gevm); | |||||
| m_all_algos.emplace_back(&int8x8x32_gemv); | |||||
| m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||||
| m_all_algos.emplace_back(&f32_gemv_mk4); | |||||
| m_all_algos.emplace_back(&gevm); | |||||
| for (auto&& algo : m_all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const { | |||||
| return m_all_algos; | |||||
| } | } | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||||
| const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { | |||||
| static AlgoPack algo_pack; | |||||
| return algo_pack; | |||||
| } | |||||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) | |||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> | |||||
| MatrixMulImpl::get_all_packed_algo() { | |||||
| static AlgoPack s_algo_pack; | static AlgoPack s_algo_pack; | ||||
| auto&& algos = fallback::MatrixMulImpl::algo_pack(); | |||||
| algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||||
| s_algo_pack.all_algos.end()); | |||||
| auto&& algos = fallback::MatrixMulImpl::get_all_packed_algo(); | |||||
| algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||||
| algo_pack().all_algos().end()); | |||||
| return std::move(algos); | return std::move(algos); | ||||
| } | } | ||||
| @@ -11,6 +11,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/matrix_mul/opr_impl.h" | #include "src/fallback/matrix_mul/opr_impl.h" | ||||
| #include "src/common/algo_base.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -27,7 +28,10 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() | |||||
| override; | |||||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); | |||||
| protected: | protected: | ||||
| class AlgoF32Gemv; // Arm_common F32 Gemv | class AlgoF32Gemv; // Arm_common F32 Gemv | ||||
| @@ -43,6 +47,9 @@ protected: | |||||
| #endif | #endif | ||||
| class AlgoInt8x8x16; // Arm_common Int 8x8x16 | class AlgoInt8x8x16; // Arm_common Int 8x8x16 | ||||
| class AlgoPack; | class AlgoPack; | ||||
| public: | |||||
| static const AlgoPack& algo_pack(); | |||||
| }; | }; | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -10,6 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megdnn/oprs/base.h" | |||||
| #include "src/fallback/pooling/opr_impl.h" | #include "src/fallback/pooling/opr_impl.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -72,6 +73,8 @@ public: | |||||
| virtual ~AlgoBase() = default; | virtual ~AlgoBase() = default; | ||||
| virtual bool usable(const PoolingKernSizeParam& param) const = 0; | virtual bool usable(const PoolingKernSizeParam& param) const = 0; | ||||
| virtual void exec(const PoolingKernParam& param) const = 0; | virtual void exec(const PoolingKernParam& param) const = 0; | ||||
| uint32_t type() const override { return INVALID_ALGO_TYPE; }; | |||||
| }; | }; | ||||
| private: | private: | ||||
| @@ -40,6 +40,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_S8) | |||||
| }; | }; | ||||
| } // namespace armv7 | } // namespace armv7 | ||||
| @@ -24,22 +24,40 @@ using namespace armv7; | |||||
| class ConvBiasImpl::AlgoPack : NonCopyableObj { | class ConvBiasImpl::AlgoPack : NonCopyableObj { | ||||
| AlgoS8MatrixMul s8_matrix_mul; | AlgoS8MatrixMul s8_matrix_mul; | ||||
| AlgoQU8MatrixMul qu8_matrix_mul; | AlgoQU8MatrixMul qu8_matrix_mul; | ||||
| fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_all_algos; | |||||
| public: | public: | ||||
| AlgoPack() { | AlgoPack() { | ||||
| all_algos.emplace_back(&qu8_matrix_mul); | |||||
| all_algos.emplace_back(&s8_matrix_mul); | |||||
| m_all_algos.emplace_back(&qu8_matrix_mul); | |||||
| m_all_algos.emplace_back(&s8_matrix_mul); | |||||
| for (auto&& algo : m_all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& all_algos() | |||||
| const { | |||||
| return m_all_algos; | |||||
| } | } | ||||
| SmallVector<AlgoBase*> all_algos; | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
| static AlgoPack sl_algo_pack; | |||||
| auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | |||||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||||
| static AlgoPack algo_pack; | |||||
| return algo_pack; | |||||
| } | |||||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||||
| ConvBiasImpl::get_all_packed_algo() { | |||||
| auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | |||||
| //! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, | //! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, | ||||
| //! and nearly equal in aarch64, because of the waste of register in | //! and nearly equal in aarch64, because of the waste of register in | ||||
| //! postprocess | //! postprocess | ||||
| algos.insert(algos.end(), sl_algo_pack.all_algos.begin(), | |||||
| sl_algo_pack.all_algos.end()); | |||||
| algos.insert(algos.end(), algo_pack().all_algos().begin(), | |||||
| algo_pack().all_algos().end()); | |||||
| return std::move(algos); | return std::move(algos); | ||||
| } | } | ||||
| @@ -25,7 +25,9 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); | |||||
| protected: | protected: | ||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| @@ -34,6 +36,7 @@ private: | |||||
| class AlgoS8MatrixMul; | class AlgoS8MatrixMul; | ||||
| class AlgoQU8MatrixMul; | class AlgoQU8MatrixMul; | ||||
| class AlgoPack; | class AlgoPack; | ||||
| static const AlgoPack& algo_pack(); | |||||
| }; | }; | ||||
| } // namespace armv7 | } // namespace armv7 | ||||
| @@ -42,6 +42,7 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_QU8) | |||||
| }; | }; | ||||
| } // namespace armv7 | } // namespace armv7 | ||||
| @@ -27,6 +27,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F32) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | ||||
| @@ -37,6 +38,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_PACK_4X12) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { | ||||
| @@ -48,6 +50,7 @@ public: | |||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_4x8) | |||||
| }; | }; | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| @@ -59,6 +62,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_K4X16X1) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -69,6 +73,7 @@ public: | |||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| @@ -80,6 +85,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_K6X8X4) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { | ||||
| @@ -90,6 +96,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X4) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { | ||||
| @@ -102,11 +109,18 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_MK4_8X4X4_DOTPROD) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| class MatrixMulImpl::AlgoF32Gemv final | class MatrixMulImpl::AlgoF32Gemv final | ||||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv {}; | |||||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||||
| public: | |||||
| AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | |||||
| m_handle_type = Handle::HandleType::ARMV7; | |||||
| } | |||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_GEMV) | |||||
| }; | |||||
| class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { | ||||
| public: | public: | ||||
| @@ -117,6 +131,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X2X16) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { | ||||
| @@ -128,6 +143,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X8X8) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { | ||||
| @@ -138,6 +154,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X8) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { | ||||
| @@ -149,6 +166,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X2X16) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { | ||||
| @@ -160,6 +178,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X8X8) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | ||||
| @@ -171,6 +190,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_MK4_K8X8X4) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { | ||||
| @@ -182,6 +202,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_K12X4X1) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { | ||||
| @@ -193,6 +214,7 @@ public: | |||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_MK8_4X8) | |||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | ||||
| @@ -204,6 +226,7 @@ public: | |||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_MK4_4X2X16) | |||||
| }; | }; | ||||
| } // namespace armv7 | } // namespace armv7 | ||||
| @@ -43,42 +43,60 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; | AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; | ||||
| AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; | AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||||
| fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||||
| AlgoPack() { | AlgoPack() { | ||||
| all_algos.emplace_back(&f32_gemv); | |||||
| all_algos.emplace_back(&f32); | |||||
| all_algos.emplace_back(&f32_mk4_pack_4x12); | |||||
| all_algos.emplace_back(&f32_mk4_4x8); | |||||
| m_all_algos.emplace_back(&f32_gemv); | |||||
| m_all_algos.emplace_back(&f32); | |||||
| m_all_algos.emplace_back(&f32_mk4_pack_4x12); | |||||
| m_all_algos.emplace_back(&f32_mk4_4x8); | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| all_algos.emplace_back(&f16_k4x16x1); | |||||
| all_algos.emplace_back(&f16_mk8_4x8); | |||||
| m_all_algos.emplace_back(&f16_k4x16x1); | |||||
| m_all_algos.emplace_back(&f16_mk8_4x8); | |||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | |||||
| all_algos.emplace_back(&int8_k6x8x4); | |||||
| all_algos.emplace_back(&quint8_k4x8x4); | |||||
| m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | |||||
| m_all_algos.emplace_back(&int8_k6x8x4); | |||||
| m_all_algos.emplace_back(&quint8_k4x8x4); | |||||
| #endif | #endif | ||||
| all_algos.emplace_back(&int8x8x32_mk4_4x2x16); | |||||
| all_algos.emplace_back(&int8x8x32_k4x2x16); | |||||
| all_algos.emplace_back(&int8x8x32_k4x8x8); | |||||
| all_algos.emplace_back(&quint8_k4x8x8); | |||||
| all_algos.emplace_back(&int8x8x16_mk4_8x8x4); | |||||
| all_algos.emplace_back(&int8x8x16_k4x2x16); | |||||
| all_algos.emplace_back(&int8x8x16_k4x8x8); | |||||
| m_all_algos.emplace_back(&int8x8x32_mk4_4x2x16); | |||||
| m_all_algos.emplace_back(&int8x8x32_k4x2x16); | |||||
| m_all_algos.emplace_back(&int8x8x32_k4x8x8); | |||||
| m_all_algos.emplace_back(&quint8_k4x8x8); | |||||
| m_all_algos.emplace_back(&int8x8x16_mk4_8x8x4); | |||||
| m_all_algos.emplace_back(&int8x8x16_k4x2x16); | |||||
| m_all_algos.emplace_back(&int8x8x16_k4x8x8); | |||||
| m_all_algos.emplace_back(&int16x16x32_k12x4x1); | |||||
| m_all_algos.emplace_back(&int16x16x32_mk8_4x8); | |||||
| for (auto&& algo : m_all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| all_algos.emplace_back(&int16x16x32_k12x4x1); | |||||
| all_algos.emplace_back(&int16x16x32_mk8_4x8); | |||||
| const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const { | |||||
| return m_all_algos; | |||||
| } | } | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||||
| static AlgoPack s_algo_pack; | |||||
| auto algos = arm_common::MatrixMulImpl::algo_pack(); | |||||
| algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||||
| s_algo_pack.all_algos.end()); | |||||
| const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { | |||||
| static AlgoPack algo_pack; | |||||
| return algo_pack; | |||||
| } | |||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> | |||||
| MatrixMulImpl::get_all_packed_algo() { | |||||
| auto algos = arm_common::MatrixMulImpl::get_all_packed_algo(); | |||||
| algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||||
| algo_pack().all_algos().end()); | |||||
| return algos; | return algos; | ||||
| } | } | ||||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -25,7 +25,10 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() | |||||
| override; | |||||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); | |||||
| private: | private: | ||||
| class AlgoF32; // Armv7 F32 | class AlgoF32; // Armv7 F32 | ||||
| @@ -52,6 +55,9 @@ private: | |||||
| // DotProduct | // DotProduct | ||||
| #endif | #endif | ||||
| class AlgoPack; | class AlgoPack; | ||||
| public: | |||||
| static const AlgoPack& algo_pack(); | |||||
| }; | }; | ||||
| } // namespace armv7 | } // namespace armv7 | ||||
| @@ -0,0 +1,101 @@ | |||||
| /** | |||||
| * \file dnn/src/common/algo_base.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <functional> | |||||
| #include <string> | |||||
| #include "megdnn/oprs/base.h" | |||||
| #include "src/common/utils.h" | |||||
| namespace megdnn { | |||||
| #define MEGDNN_DECL_ALGO_TYPE(_type) \ | |||||
| uint32_t type() const override { \ | |||||
| return static_cast<std::underlying_type<AlgoType>::type>( \ | |||||
| AlgoType::_type); \ | |||||
| } | |||||
| #define MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr) \ | |||||
| static fallback::_opr::AlgoBase* get_algo_from_desc( \ | |||||
| const AlgorithmDesc& desc) | |||||
| #define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||||
| fallback::_opr::AlgoBase* _opr::get_algo_from_desc( \ | |||||
| const AlgorithmDesc& desc) { \ | |||||
| megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||||
| algo_pack().all_algos_map().end()); \ | |||||
| return algo_pack().all_algos_map().at(desc); \ | |||||
| } | |||||
| #define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||||
| _opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \ | |||||
| megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||||
| algo_pack().all_algos_map().end()); \ | |||||
| return algo_pack().all_algos_map().at(desc); \ | |||||
| } | |||||
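Editor's note: for readers skimming the macros above, this is roughly what MEGDNN_DECL_ALGO_TYPE expands to inside a concrete algo class. AlgoBaseExample and its AlgoType enum are illustrative stand-ins for the real per-backend classes shown later in this diff.

#include <cstdint>
#include <type_traits>

struct AlgoBaseExample {
    enum class AlgoType : uint32_t { CUDA_CUBLAS };
    virtual ~AlgoBaseExample() = default;
    virtual uint32_t type() const = 0;
};

struct AlgoCublasExample final : AlgoBaseExample {
    // MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) expands to this override:
    uint32_t type() const override {
        return static_cast<std::underlying_type<AlgoType>::type>(
                AlgoType::CUDA_CUBLAS);
    }
};

int main() {
    AlgoCublasExample algo;
    return static_cast<int>(algo.type());  // 0, the enum's underlying value
}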
| /** | |||||
| * \brief construct algo from AlgorithmDesc | |||||
| */ | |||||
| template <typename AlgoBase> | |||||
| class AlgoConstructMixin { | |||||
| private: | |||||
| std::vector<std::unique_ptr<AlgoBase>> m_refhold; | |||||
| protected: | |||||
| typename AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| //! construct the algo described by desc, and return the instance | |||||
| AlgoBase* construct_and_get_algo( | |||||
| const detail::Algorithm::Info::Desc& desc) { | |||||
| auto iter = m_all_algos_map.find(desc); | |||||
| if (iter != m_all_algos_map.end()) { | |||||
| return m_all_algos_map.at(desc); | |||||
| } | |||||
| std::string serialized_bin; | |||||
| AlgoBase::serialize_write_pod(desc.type, serialized_bin); | |||||
| serialized_bin += desc.param; | |||||
| m_refhold.emplace_back(AlgoBase::deserialize(serialized_bin)); | |||||
| m_all_algos_map.emplace(desc, m_refhold.back().get()); | |||||
| return m_refhold.back().get(); | |||||
| } | |||||
| void clear() { | |||||
| m_all_algos_map.clear(); | |||||
| m_refhold.clear(); | |||||
| } | |||||
| const typename AlgoBase::Mapper& all_algos_map() const { | |||||
| return m_all_algos_map; | |||||
| } | |||||
| }; | |||||
| } // namespace megdnn | |||||
| namespace std { | |||||
| template <> | |||||
| struct hash<megdnn::detail::Algorithm::Info::Desc> { | |||||
| std::size_t operator()( | |||||
| const megdnn::detail::Algorithm::Info::Desc& desc) const { | |||||
| return megdnn::hash_combine<size_t>( | |||||
| megdnn::hash_combine<size_t>( | |||||
| std::hash<std::string>()(desc.param), | |||||
| std::hash<uint32_t>()(desc.type)), | |||||
| std::hash<uint32_t>()(static_cast<uint32_t>(desc.handle_type))); | |||||
| } | |||||
| }; | |||||
| } // namespace std | |||||
| // vim: syntax=cpp.doxygen | |||||
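Editor's note: AlgoConstructMixin above implements a look-up-or-construct cache keyed by the algorithm descriptor. The sketch below strips that flow down to something self-contained, with std::string standing in for Info::Desc and a trivial stub in place of AlgoBase::deserialize; all names here are illustrative, not megdnn APIs.

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct AlgoStub {
    std::string desc;
    explicit AlgoStub(std::string d) : desc(std::move(d)) {}
};

class AlgoCacheExample {
    std::vector<std::unique_ptr<AlgoStub>> m_refhold;   // owns constructed algos
    std::unordered_map<std::string, AlgoStub*> m_map;   // desc -> algo

public:
    AlgoStub* construct_and_get_algo(const std::string& desc) {
        auto iter = m_map.find(desc);
        if (iter != m_map.end())
            return iter->second;                         // already constructed
        m_refhold.emplace_back(std::make_unique<AlgoStub>(desc));
        m_map.emplace(desc, m_refhold.back().get());     // cache for next time
        return m_refhold.back().get();
    }
};

int main() {
    AlgoCacheExample cache;
    auto* a = cache.construct_and_get_algo("matmul:naive");
    auto* b = cache.construct_and_get_algo("matmul:naive");
    assert(a == b);  // second request hits the cache instead of reconstructing
}

Keeping ownership in the unique_ptr vector while the map hands out raw pointers mirrors how the mixin keeps its constructed algos alive.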
| @@ -25,15 +25,34 @@ namespace megdnn { | |||||
| */ | */ | ||||
| template <class Opr, typename... Args> | template <class Opr, typename... Args> | ||||
| typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | ||||
| typename Opr::Algorithm* ret; | |||||
| if (auto set = opr->execution_policy().algorithm) { | |||||
| typename Opr::AlgorithmInfo ret; | |||||
| auto set = opr->execution_policy().algo; | |||||
| if (set.valid()) { | |||||
| ret = set; | ret = set; | ||||
| } else { | } else { | ||||
| ret = opr->get_algorithm_heuristic(std::forward<Args>(args)..., | |||||
| std::numeric_limits<size_t>::max(), | |||||
| false); | |||||
| ret = opr->get_algorithm_info_heuristic( | |||||
| std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | |||||
| false); | |||||
| } | |||||
| return opr->get_algo_from_desc(ret.desc); | |||||
| } | |||||
| /*! | |||||
| * \brief get the user-configured algorithm, or a heuristic algorithm; used by the | |||||
| * OpenCL backend, whose algos need to be constructed each time. | |||||
| */ | |||||
| template <class Opr, typename... Args> | |||||
| typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { | |||||
| typename Opr::AlgorithmInfo ret; | |||||
| auto set = opr->execution_policy().algo; | |||||
| if (set.valid()) { | |||||
| return opr->algo_pack().construct_and_get_algo(set.desc); | |||||
| } else { | |||||
| ret = opr->get_algorithm_info_heuristic( | |||||
| std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | |||||
| false); | |||||
| return opr->get_algo_from_desc(ret.desc); | |||||
| } | } | ||||
| return static_cast<typename Opr::AlgoBase*>(ret); | |||||
| } | } | ||||
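Editor's note: both helpers above share the same selection policy: use the algorithm recorded in execution_policy().algo when it is valid, otherwise fall back to the heuristic with an unlimited workspace. The stub below illustrates only that decision; none of these types are megdnn classes.

#include <cstddef>
#include <limits>

struct InfoStub {
    int type;
    bool valid() const { return type != -1; }
};

struct OprStub {
    InfoStub configured{-1};                     // stand-in for execution_policy().algo
    InfoStub heuristic(size_t /*workspace_limit*/) const { return InfoStub{42}; }
};

InfoStub select_algo(const OprStub& opr) {
    if (opr.configured.valid())
        return opr.configured;                   // explicit user choice wins
    return opr.heuristic(std::numeric_limits<size_t>::max());
}

int main() {
    OprStub opr;
    InfoStub picked = select_algo(opr);          // heuristic path: type == 42
    opr.configured = InfoStub{7};
    picked = select_algo(opr);                   // configured path: type == 7
    return picked.type == 7 ? 0 : 1;
}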
| /*! | /*! | ||||
| @@ -9,6 +9,32 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| /** | |||||
| * Boost Software License - Version 1.0 - August 17th, 2003 | |||||
| * | |||||
| * Permission is hereby granted, free of charge, to any person or organization | |||||
| * obtaining a copy of the software and accompanying documentation covered by | |||||
| * this license (the "Software") to use, reproduce, display, distribute, | |||||
| * execute, and transmit the Software, and to prepare derivative works of the | |||||
| * Software, and to permit third-parties to whom the Software is furnished to | |||||
| * do so, all subject to the following: | |||||
| * | |||||
| * The copyright notices in the Software and this entire statement, including | |||||
| * the above license grant, this restriction and the following disclaimer, | |||||
| * must be included in all copies of the Software, in whole or in part, and | |||||
| * all derivative works of the Software, unless such copies or derivative | |||||
| * works are solely in the form of machine-executable object code generated by | |||||
| * a source language processor. | |||||
| * | |||||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||||
| * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT | |||||
| * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE | |||||
| * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, | |||||
| * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | |||||
| * DEALINGS IN THE SOFTWARE. | |||||
| */ | |||||
| #pragma once | #pragma once | ||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| @@ -263,6 +289,13 @@ constexpr uint32_t operator"" _hash(char const* str, size_t count) { | |||||
| return XXHash64CT::hash(str, count, 20160701); | return XXHash64CT::hash(str, count, 20160701); | ||||
| } | } | ||||
| // refer to https://www.boost.org/doc/libs/1_64_0/boost/functional/hash/hash.hpp | |||||
| template <typename T> | |||||
| inline T hash_combine(T seed, T value) { | |||||
| seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); | |||||
| return seed; | |||||
| } | |||||
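Editor's note: the std::hash<Algorithm::Info::Desc> specialization earlier in this diff chains this combiner over param, type, and handle_type. The standalone example below repeats the same chaining with arbitrary sample values.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

// Boost-style combiner, as defined above.
template <typename T>
T hash_combine(T seed, T value) {
    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
}

int main() {
    std::string param = "gemm-param";  // sample values only
    uint32_t type = 3, handle_type = 1;
    size_t h = hash_combine<size_t>(
            hash_combine<size_t>(std::hash<std::string>()(param),
                                 std::hash<uint32_t>()(type)),
            std::hash<uint32_t>()(handle_type));
    std::cout << h << '\n';  // one combined hash for the whole descriptor
}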
| template <typename Vec> | template <typename Vec> | ||||
| std::string vec2str(Vec&& vec) { | std::string vec2str(Vec&& vec) { | ||||
| std::string res; | std::string res; | ||||
| @@ -18,8 +18,14 @@ using namespace cuda; | |||||
| BatchConvBiasForwardImpl::AlgoPack::AlgoPack() { | BatchConvBiasForwardImpl::AlgoPack::AlgoPack() { | ||||
| all_algos.push_back(&int8_nchw4_gemm_dotprod); | all_algos.push_back(&int8_nchw4_gemm_dotprod); | ||||
| all_algos.push_back(&int8_nchw4_implicit_gemm_dotprod); | all_algos.push_back(&int8_nchw4_implicit_gemm_dotprod); | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchConvBiasForwardImpl) | |||||
| BatchConvBiasForwardImpl::AlgoPack BatchConvBiasForwardImpl::sm_algo_pack; | BatchConvBiasForwardImpl::AlgoPack BatchConvBiasForwardImpl::sm_algo_pack; | ||||
| BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( | BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( | ||||
| @@ -11,13 +11,16 @@ | |||||
| #pragma once | #pragma once | ||||
| #include <csetjmp> | |||||
| #include <unordered_map> | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/cuda/batch_conv_bias/opr_impl.h" | #include "src/cuda/batch_conv_bias/opr_impl.h" | ||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -26,6 +29,12 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_GEMM_NCHW4_DOTPROD_INT8, | |||||
| CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs { | struct SizeArgs { | ||||
| BatchConvBiasForwardImpl* opr; | BatchConvBiasForwardImpl* opr; | ||||
| @@ -85,6 +94,7 @@ public: | |||||
| const char* name() const override { | const char* name() const override { | ||||
| return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD"; | return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD"; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GEMM_NCHW4_DOTPROD_INT8) | |||||
| }; | }; | ||||
| class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp final | class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp final | ||||
| @@ -99,15 +109,16 @@ public: | |||||
| const char* name() const override { | const char* name() const override { | ||||
| return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD"; | return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD"; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8) | |||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
| const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
| }; | }; | ||||
| class BatchConvBiasForwardImpl::AlgoPack { | |||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| class BatchConvBiasForwardImpl::AlgoPack : NonCopyableObj { | |||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| @@ -116,6 +127,8 @@ public: | |||||
| AlgoInt8NCHW4DotProdImplicitGemmPrecomp int8_nchw4_implicit_gemm_dotprod; | AlgoInt8NCHW4DotProdImplicitGemmPrecomp int8_nchw4_implicit_gemm_dotprod; | ||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -26,6 +26,18 @@ public: | |||||
| const TensorLayout& bias, | const TensorLayout& bias, | ||||
| const TensorLayout& z, | const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoInt8NCHW4DotProdGemm; | |||||
| class AlgoInt8NCHW4DotProdImplicitGemmPrecomp; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| @@ -37,15 +49,6 @@ public: | |||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| bool reproducible) override; | bool reproducible) override; | ||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoInt8NCHW4DotProdGemm; | |||||
| class AlgoInt8NCHW4DotProdImplicitGemmPrecomp; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -60,4 +60,12 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||||
| for (auto& algo : brute_force_algos) { | for (auto& algo : brute_force_algos) { | ||||
| all_algos.push_back(&algo); | all_algos.push_back(&algo); | ||||
| } | } | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -16,6 +16,8 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/cuda/batched_matrix_mul/opr_impl.h" | #include "src/cuda/batched_matrix_mul/opr_impl.h" | ||||
| #include "src/cuda/matrix_mul/cublasLt_wrapper.h" | #include "src/cuda/matrix_mul/cublasLt_wrapper.h" | ||||
| #include "src/common/metahelper.h" | |||||
| #if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
| #include <cublasLt.h> | #include <cublasLt.h> | ||||
| #endif | #endif | ||||
| @@ -28,6 +30,14 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_BRUTE_FORCE, | |||||
| CUDA_CUBLAS, | |||||
| CUDA_CUBLASLT, | |||||
| CUDA_INT8X8X32, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs { | struct SizeArgs { | ||||
| BatchedMatrixMulForwardImpl* opr; | BatchedMatrixMulForwardImpl* opr; | ||||
| @@ -90,6 +100,13 @@ public: | |||||
| void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_algorithm, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
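Editor's note: the param() overrides introduced in this change encode trivially copyable fields as raw bytes via serialize_write_pod, and deserialize_read_pod on the reading side presumably mirrors that. The round-trip below uses local stand-in helpers rather than the megdnn ones, and the enum value is arbitrary.

#include <cassert>
#include <cstddef>
#include <cstring>
#include <string>

enum class FwdAlgoStub : int { GEMM = 2 };

template <typename T>
void write_pod(const T& val, std::string& out) {
    // Append the raw object representation, as serialize_write_pod does.
    out.append(reinterpret_cast<const char*>(&val), sizeof(T));
}

template <typename T>
T read_pod(const std::string& data, size_t offset = 0) {
    // Copy the bytes back into a value of the expected type.
    T val;
    std::memcpy(&val, data.data() + offset, sizeof(T));
    return val;
}

int main() {
    std::string param;
    write_pod(FwdAlgoStub::GEMM, param);               // what a param() would return
    assert(read_pod<FwdAlgoStub>(param) == FwdAlgoStub::GEMM);
}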
| class BatchedMatrixMulForwardImpl::AlgoCublas final | class BatchedMatrixMulForwardImpl::AlgoCublas final | ||||
| : public BatchedMatrixMulForwardImpl::AlgoBase { | : public BatchedMatrixMulForwardImpl::AlgoBase { | ||||
| @@ -100,6 +117,7 @@ public: | |||||
| void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "CUBLAS"; } | const char* name() const override { return "CUBLAS"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | |||||
| }; | }; | ||||
| #if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
| class BatchedMatrixMulForwardImpl::AlgoCublasLt final : public AlgoBase { | class BatchedMatrixMulForwardImpl::AlgoCublasLt final : public AlgoBase { | ||||
| @@ -110,6 +128,7 @@ public: | |||||
| void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "CUBLAS_LT"; } | const char* name() const override { return "CUBLAS_LT"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| class BatchedMatrixMulForwardImpl::AlgoInt8x8x32 final | class BatchedMatrixMulForwardImpl::AlgoInt8x8x32 final | ||||
| @@ -121,11 +140,13 @@ public: | |||||
| void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "INT8x8x32"; } | const char* name() const override { return "INT8x8x32"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32) | |||||
| }; | }; | ||||
| class BatchedMatrixMulForwardImpl::AlgoPack { | |||||
| class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| MatrixMulForwardImpl::AlgoPack mm_pack; | MatrixMulForwardImpl::AlgoPack mm_pack; | ||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| @@ -137,6 +158,8 @@ public: | |||||
| AlgoInt8x8x32 int8x8x32; | AlgoInt8x8x32 int8x8x32; | ||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| std::vector<AlgoBruteForce> brute_force_algos; | std::vector<AlgoBruteForce> brute_force_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -24,7 +24,7 @@ bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available( | |||||
| const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
| MatrixMulForwardImpl mm{args.opr->handle()}; | MatrixMulForwardImpl mm{args.opr->handle()}; | ||||
| mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB}; | mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB}; | ||||
| mm.execution_policy() = {m_algorithm}; | |||||
| mm.execution_policy() = {m_algorithm->info()}; | |||||
| auto mm_layout_a = args.layout_a.remove_axis(0); | auto mm_layout_a = args.layout_a.remove_axis(0); | ||||
| auto mm_layout_b = args.layout_b.remove_axis(0); | auto mm_layout_b = args.layout_b.remove_axis(0); | ||||
| @@ -39,7 +39,7 @@ size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( | |||||
| auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | ||||
| mm_opr->param() = {args.opr->param().transposeA, | mm_opr->param() = {args.opr->param().transposeA, | ||||
| args.opr->param().transposeB}; | args.opr->param().transposeB}; | ||||
| mm_opr->execution_policy() = {m_algorithm}; | |||||
| mm_opr->execution_policy() = {m_algorithm->info()}; | |||||
| return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b, | return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b, | ||||
| args.layout_c); | args.layout_c); | ||||
| @@ -50,7 +50,7 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( | |||||
| auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | ||||
| mm_opr->param() = {args.opr->param().transposeA, | mm_opr->param() = {args.opr->param().transposeA, | ||||
| args.opr->param().transposeB}; | args.opr->param().transposeB}; | ||||
| mm_opr->execution_policy() = {m_algorithm}; | |||||
| mm_opr->execution_policy() = {m_algorithm->info()}; | |||||
| rep(n, N) { | rep(n, N) { | ||||
| TensorND A_, B_, C_; | TensorND A_, B_, C_; | ||||
| auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | ||||
| @@ -32,6 +32,16 @@ public: | |||||
| _megdnn_workspace workspace) override; | _megdnn_workspace workspace) override; | ||||
| size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B, | size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B, | ||||
| const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
| const char* get_algorithm_set_name() const override { | |||||
| return "BATCHED_MATMUL"; | |||||
| } | |||||
| bool is_thread_safe() const override { return true; } | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | ||||
| const TensorLayout& B, | const TensorLayout& B, | ||||
| const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
| @@ -40,12 +50,6 @@ public: | |||||
| const TensorLayout& C, | const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| bool reproducible) override; | bool reproducible) override; | ||||
| const char* get_algorithm_set_name() const override { | |||||
| return "BATCHED_MATMUL"; | |||||
| } | |||||
| bool is_thread_safe() const override { return true; } | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -100,10 +100,16 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||||
| for (size_t i = all_algo_size; i < all_algos.size(); ++i) { | for (size_t i = all_algo_size; i < all_algos.size(); ++i) { | ||||
| non_cudnn_algos.push_back(all_algos[i]); | non_cudnn_algos.push_back(all_algos[i]); | ||||
| } | } | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack; | ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack; | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl) | |||||
| ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( | ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( | ||||
| ConvBiasForwardImpl* o, const TensorLayout& src, | ConvBiasForwardImpl* o, const TensorLayout& src, | ||||
| const TensorLayout& filter, const TensorLayout& bias, | const TensorLayout& filter, const TensorLayout& bias, | ||||
| @@ -172,43 +178,10 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
| } | } | ||||
| void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() { | void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() { | ||||
| #define V1(v) #v | |||||
| #define V(v) V1(v) | |||||
| #define DEF_ALGO(NAME, REPROD) \ | |||||
| cudnn_conv_bias_activations.push_back( \ | |||||
| {REPROD, \ | |||||
| "CUDNN:ConvBiasActivation:" #NAME \ | |||||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL), \ | |||||
| NAME}); \ | |||||
| cudnn_convs.push_back( \ | |||||
| {REPROD, \ | |||||
| "CUDNN:Convolution:" #NAME \ | |||||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL), \ | |||||
| NAME}) | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true); | |||||
| #if CUDNN_MAJOR >= 5 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true); | |||||
| #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true); | |||||
| #endif | |||||
| #endif | |||||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||||
| #pragma message "not latest cudnn" | |||||
| #endif | |||||
| #undef DEF_ALGO | |||||
| #undef V | |||||
| #undef V1 | |||||
| for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) { | |||||
| cudnn_conv_bias_activations.push_back(algo.first); | |||||
| cudnn_convs.push_back(algo.first); | |||||
| } | |||||
| } | } | ||||
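Editor's note: fill_cudnn_algos() now iterates CudnnAlgoPack::conv_fwd_algos() instead of the hand-written DEF_ALGO list. The wrapper's definition is not part of this hunk; judging only from its uses in this diff, it behaves like a map from the cudnn enum to an Attr carrying at least a name and an is_reproducible flag, roughly as sketched below with stub types.

#include <map>
#include <string>
#include <vector>

struct AttrStub {
    std::string name;
    bool is_reproducible;
};

using FwdAlgoKey = int;  // stand-in for cudnnConvolutionFwdAlgo_t

const std::map<FwdAlgoKey, AttrStub>& conv_fwd_algos_stub() {
    static const std::map<FwdAlgoKey, AttrStub> algos = {
            {0, AttrStub{"IMPLICIT_GEMM", true}},
            {1, AttrStub{"IMPLICIT_PRECOMP_GEMM", true}},
    };
    return algos;
}

int main() {
    // Mirrors the new fill_cudnn_algos(): one algo entry per cudnn enum value.
    std::vector<FwdAlgoKey> cudnn_convs;
    for (auto&& algo : conv_fwd_algos_stub())
        cudnn_convs.push_back(algo.first);
}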
| #if CUDA_VERSION >= 10000 | #if CUDA_VERSION >= 10000 | ||||
| @@ -6,19 +6,23 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/common/metahelper.h" | |||||
| #include "src/cuda/conv_bias/conv_bias_int8.cuh" | #include "src/cuda/conv_bias/conv_bias_int8.cuh" | ||||
| #include "src/cuda/conv_bias/helper.h" | #include "src/cuda/conv_bias/helper.h" | ||||
| #include "src/cuda/conv_bias/opr_impl.h" | #include "src/cuda/conv_bias/opr_impl.h" | ||||
| #include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/cudnn_wrapper.h" | |||||
| #include <cuda.h> | #include <cuda.h> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -38,11 +42,39 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_CUDNN_CONVBIAS, | |||||
| CUDA_CHANWISE, | |||||
| CUDA_CHANWISE_SMALL, | |||||
| CUDA_CHANWISE_INT8X8X32, | |||||
| CUDA_CUDNN_CONV, | |||||
| CUDA_INPLACE_MATMUL, | |||||
| CUDA_MATMUL, | |||||
| CUDA_MATMUL_INT8X8X32, | |||||
| CUDA_1X1, | |||||
| CUDA_BATCHED_MATMUL, | |||||
| CUDA_GROUP_CONV_GENERAL, | |||||
| CUDA_WMMA_UINT4X4X32, | |||||
| CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8, | |||||
| CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, | |||||
| CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8, | |||||
| CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8, | |||||
| CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8, | |||||
| CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8, | |||||
| CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8, | |||||
| CUDA_BFLOAT16, | |||||
| CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8, | |||||
| CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, | |||||
| CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8, | |||||
| CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs : public conv_bias::BiasForwardSizeArgs { | struct SizeArgs : public conv_bias::BiasForwardSizeArgs { | ||||
| ConvBiasForwardImpl* opr; | ConvBiasForwardImpl* opr; | ||||
| const PreprocessedFilter* preprocessed_filter; | const PreprocessedFilter* preprocessed_filter; | ||||
| std::string to_string() const; | std::string to_string() const; | ||||
| SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src, | SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src, | ||||
| const TensorLayout& filter, const TensorLayout& bias, | const TensorLayout& filter, const TensorLayout& bias, | ||||
| @@ -80,13 +112,17 @@ public: | |||||
| virtual void exec(const ExecArgs& args) const = 0; | virtual void exec(const ExecArgs& args) const = 0; | ||||
| virtual size_t get_preprocess_workspace_in_bytes( | virtual size_t get_preprocess_workspace_in_bytes( | ||||
| const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
| MEGDNN_MARK_USED_VAR(args); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | ||||
| const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
| MEGDNN_MARK_USED_VAR(args); | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| virtual void exec_preprocess(const ExecArgs& args) const {} | |||||
| virtual void exec_preprocess(const ExecArgs& args) const { | |||||
| MEGDNN_MARK_USED_VAR(args); | |||||
| } | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | bool is_available_wk(const SizeArgs& args, size_t limit) { | ||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
| @@ -114,11 +150,14 @@ public: | |||||
| class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase { | class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoCUDNNConvBiasActivation(bool is_reproducible, const char* name, | |||||
| cudnnConvolutionFwdAlgo_t cudnn_enum) | |||||
| : m_is_reproducible(is_reproducible), | |||||
| m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})), | |||||
| m_cudnn_enum(cudnn_enum) {} | |||||
| AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum) | |||||
| : m_cudnn_enum(cudnn_enum) { | |||||
| megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != | |||||
| CudnnAlgoPack::conv_fwd_algos().end()); | |||||
| m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); | |||||
| m_name = ConvBiasForward::algo_name<DefaultParam>( | |||||
| "CUDNN:ConvBiasActivation:" + m_attr.name, {}); | |||||
| } | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| @@ -127,16 +166,24 @@ public: | |||||
| const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
| bool is_reproducible() const override { return m_is_reproducible; } | |||||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
| cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; } | cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; } | ||||
| bool is_cudnn() const override { return true; } | bool is_cudnn() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_cudnn_enum, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| bool m_is_reproducible; | |||||
| std::string m_name; | std::string m_name; | ||||
| cudnnConvolutionFwdAlgo_t m_cudnn_enum; | cudnnConvolutionFwdAlgo_t m_cudnn_enum; | ||||
| CudnnAlgoPack::Attr m_attr; | |||||
| }; | }; | ||||
| class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase { | class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase { | ||||
| @@ -154,6 +201,8 @@ public: | |||||
| } | } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||||
| private: | private: | ||||
| mutable std::string m_name; | mutable std::string m_name; | ||||
| }; | }; | ||||
| @@ -172,6 +221,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | |||||
| private: | private: | ||||
| mutable std::string m_name; | mutable std::string m_name; | ||||
| @@ -190,6 +240,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32) | |||||
| private: | private: | ||||
| mutable std::string m_name; | mutable std::string m_name; | ||||
| @@ -197,27 +248,39 @@ private: | |||||
| class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase { | class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoCUDNNConv(bool is_reproducible, const char* name, | |||||
| cudnnConvolutionFwdAlgo_t cudnn_enum) | |||||
| : m_is_reproducible(is_reproducible), | |||||
| m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})), | |||||
| m_cudnn_enum(cudnn_enum) {} | |||||
| AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) | |||||
| : m_cudnn_enum(cudnn_enum) { | |||||
| megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != | |||||
| CudnnAlgoPack::conv_fwd_algos().end()); | |||||
| m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); | |||||
| m_name = ConvBiasForward::algo_name<DefaultParam>( | |||||
| "CUDNN:Convolution:" + m_attr.name, {}); | |||||
| } | |||||
| bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| bool is_reproducible() const override { return m_is_reproducible; } | |||||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
| const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
| cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } | cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } | ||||
| bool is_cudnn() const override { return true; } | bool is_cudnn() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_cudnn_enum, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| bool m_is_reproducible; | |||||
| std::string m_name; | std::string m_name; | ||||
| cudnnConvolutionFwdAlgo_t m_cudnn_enum; | cudnnConvolutionFwdAlgo_t m_cudnn_enum; | ||||
| CudnnAlgoPack::Attr m_attr; | |||||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
| }; | }; | ||||
| @@ -237,6 +300,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | |||||
| private: | private: | ||||
| mutable std::string m_name; | mutable std::string m_name; | ||||
| @@ -261,6 +325,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
| @@ -281,6 +346,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32) | |||||
| private: | private: | ||||
| bool need_src_unroll(const SizeArgs& args) const; | bool need_src_unroll(const SizeArgs& args) const; | ||||
| @@ -310,6 +376,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_1X1) | |||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
| @@ -333,6 +400,7 @@ public: | |||||
| return m_name.c_str(); | return m_name.c_str(); | ||||
| } | } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
| @@ -354,6 +422,13 @@ public: | |||||
| static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | ||||
| TensorLayout& dst_pg, TensorLayout& bias_pg); | TensorLayout& dst_pg, TensorLayout& bias_pg); | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_impl, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
| @@ -370,10 +445,13 @@ public: | |||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| const char* name() const override { return "QUINT4x4x32_WMMA"; } | const char* name() const override { return "QUINT4x4x32_WMMA"; } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; | |||||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||||
| const SizeArgs& args) const; | |||||
| bool use_kernel_fhxfw(const SizeArgs& args) const; | bool use_kernel_fhxfw(const SizeArgs& args) const; | ||||
| size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const; | size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const; | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| @@ -395,6 +473,7 @@ public: | |||||
| const convolution::ConvParam& param, float alpha, float beta, | const convolution::ConvParam& param, float alpha, float beta, | ||||
| float gamma, float scale, cudaStream_t stream, | float gamma, float scale, cudaStream_t stream, | ||||
| param::ConvBias::NonlineMode nonlinear_mode); | param::ConvBias::NonlineMode nonlinear_mode); | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8) | |||||
| }; | }; | ||||
| class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final | class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final | ||||
| @@ -415,8 +494,9 @@ public: | |||||
| warp_k == 32 && stage == 2) { | warp_k == 32 && stage == 2) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, | |||||
| threadblock_k, warp_m, warp_n, warp_k, stage); | |||||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, | |||||
| threadblock_n, threadblock_k, warp_m, warp_n, | |||||
| warp_k, stage); | |||||
| } | } | ||||
| }; | }; | ||||
| AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) | AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) | ||||
| @@ -433,6 +513,13 @@ public: | |||||
| SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | ||||
| const SizeArgs& args) const override; | const SizeArgs& args) const override; | ||||
| void exec_preprocess(const ExecArgs& args) const override; | void exec_preprocess(const ExecArgs& args) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_algo_param, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
| @@ -457,9 +544,7 @@ public: | |||||
| bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| const char* name() const override { | |||||
| return m_name.c_str(); | |||||
| } | |||||
| const char* name() const override { return m_name.c_str(); } | |||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| template <typename BiasVisitor> | template <typename BiasVisitor> | ||||
| static void dispatch_nonlinear_mode( | static void dispatch_nonlinear_mode( | ||||
| @@ -471,6 +556,14 @@ public: | |||||
| MMATileSize mma_tile_size); | MMATileSize mma_tile_size); | ||||
| static std::string to_string(MMATileSize mma_tile_size); | static std::string to_string(MMATileSize mma_tile_size); | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_mma_tile_size, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| MMATileSize m_mma_tile_size; | MMATileSize m_mma_tile_size; | ||||
| std::string m_name; | std::string m_name; | ||||
| @@ -488,10 +581,16 @@ public: | |||||
| bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| const char* name() const override { | |||||
| return m_name.c_str(); | |||||
| } | |||||
| const char* name() const override { return m_name.c_str(); } | |||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_mma_tile_size, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
| const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
| @@ -513,6 +612,13 @@ public: | |||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_mma_tile_size, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| MMATileSize m_mma_tile_size; | MMATileSize m_mma_tile_size; | ||||
| @@ -533,6 +639,13 @@ public: | |||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_mma_tile_size, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| MMATileSize m_mma_tile_size; | MMATileSize m_mma_tile_size; | ||||
| @@ -570,6 +683,13 @@ public: | |||||
| SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | ||||
| const SizeArgs& args) const override; | const SizeArgs& args) const override; | ||||
| void exec_preprocess(const ExecArgs& args) const override; | void exec_preprocess(const ExecArgs& args) const override; | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_algo_param, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
| @@ -592,6 +712,14 @@ public: | |||||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | bool is_reproducible() const override { return m_impl->is_reproducible(); } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_impl, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr, | SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr, | ||||
| TensorLayout& fsrc, TensorLayout& ffilter, | TensorLayout& fsrc, TensorLayout& ffilter, | ||||
| @@ -603,17 +731,16 @@ private: | |||||
| }; | }; | ||||
| class ConvBiasForwardImpl::AlgoPack { | |||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| class ConvBiasForwardImpl::AlgoPack : NonCopyableObj { | |||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| std::vector<AlgoBase*> all_algos, | std::vector<AlgoBase*> all_algos, | ||||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | //! non-cudnn algos, used for heuristic if cudnn is not supported | ||||
| non_cudnn_algos, | |||||
| bfloat16_algos; | |||||
| non_cudnn_algos, bfloat16_algos; | |||||
| std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations; | std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations; | ||||
| std::vector<AlgoCUDNNConv> cudnn_convs; | std::vector<AlgoCUDNNConv> cudnn_convs; | ||||
| AlgoChanwise chanwise; | AlgoChanwise chanwise; | ||||
| @@ -646,6 +773,8 @@ public: | |||||
| AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo); | AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo); | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| private: | private: | ||||
| #if CUDA_VERSION >= 10000 | #if CUDA_VERSION >= 10000 | ||||
| void fill_imma_algos(); | void fill_imma_algos(); | ||||
| @@ -47,7 +47,7 @@ ConvBiasForwardImpl::AlgoBFloat16::float_args( | |||||
| change_dtype(fdst); | change_dtype(fdst); | ||||
| opr->param() = args.opr->param(); | opr->param() = args.opr->param(); | ||||
| opr->param().compute_mode = Param::ComputeMode::DEFAULT; | opr->param().compute_mode = Param::ComputeMode::DEFAULT; | ||||
| opr->execution_policy() = {m_impl}; | |||||
| opr->execution_policy() = {m_impl->info()}; | |||||
| return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst); | return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst); | ||||
| } | } | ||||
| @@ -110,7 +110,7 @@ void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { | |||||
| auto convbias_opr = args.handle->create_operator<ConvBias>(); | auto convbias_opr = args.handle->create_operator<ConvBias>(); | ||||
| convbias_opr->param() = args.opr->param(); | convbias_opr->param() = args.opr->param(); | ||||
| convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | ||||
| convbias_opr->execution_policy() = {m_impl}; | |||||
| convbias_opr->execution_policy() = {m_impl->info()}; | |||||
| convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, | convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, | ||||
| fdst_tensor, nullptr, cvter.workspace()); | fdst_tensor, nullptr, cvter.workspace()); | ||||
| } | } | ||||
| @@ -63,12 +63,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| auto conv_args = args; | auto conv_args = args; | ||||
| auto cudnn_conv_bias_act_from_enum_wrapper = | auto cudnn_conv_bias_act_from_enum_wrapper = | ||||
| [this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||||
| [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||||
| return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo); | return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo); | ||||
| }; | }; | ||||
| auto cudnn_conv_from_enum_wrapper = | auto cudnn_conv_from_enum_wrapper = | ||||
| [this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||||
| [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||||
| return sm_algo_pack.cudnn_conv_from_enum(algo); | return sm_algo_pack.cudnn_conv_from_enum(algo); | ||||
| }; | }; | ||||
| @@ -24,17 +24,6 @@ public: | |||||
| _megdnn_tensor_out dst, | _megdnn_tensor_out dst, | ||||
| const PreprocessedFilter* preprocessed_filter, | const PreprocessedFilter* preprocessed_filter, | ||||
| _megdnn_workspace workspace) override; | _megdnn_workspace workspace) override; | ||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& bias, | |||||
| const TensorLayout& z, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&, | const TensorLayout&, | ||||
| @@ -80,6 +69,20 @@ public: | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& bias, | |||||
| const TensorLayout& z, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -52,8 +52,14 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(bfloat16_refhold.back().get()); | all_algos.push_back(bfloat16_refhold.back().get()); | ||||
| bfloat16_algos.push_back(bfloat16_refhold.back().get()); | bfloat16_algos.push_back(bfloat16_refhold.back().get()); | ||||
| } | } | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) | |||||
| ConvolutionBackwardDataImpl::AlgoCUDNN* | ConvolutionBackwardDataImpl::AlgoCUDNN* | ||||
| ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum( | ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum( | ||||
| cudnnConvolutionBwdDataAlgo_t algo) { | cudnnConvolutionBwdDataAlgo_t algo) { | ||||
| @@ -11,8 +11,11 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/cuda/convolution/helper.h" | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/cuda/convolution/helper.h" | |||||
| #include "src/cuda/cudnn_wrapper.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -23,154 +26,146 @@ namespace cuda { | |||||
| * All the algo impls should try to support non-contiguous batch dim, for group | * All the algo impls should try to support non-contiguous batch dim, for group | ||||
| * conv execution. | * conv execution. | ||||
| */ | */ | ||||
| class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| struct SizeArgs { | |||||
| HandleImpl *handle; | |||||
| CanonizedFilterMeta filter_meta; | |||||
| const TensorLayout *diff_layout, *grad_layout, *filter_layout; | |||||
| ConvolutionBackwardDataImpl *opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution::CUDNNBwdDataDescs &desc) const { | |||||
| desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||||
| } | |||||
| SizeArgs(ConvolutionBackwardDataImpl* opr, | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad); | |||||
| SizeArgs(ConvolutionBackwardDataImpl* opr, | |||||
| const TensorLayout& filter, | |||||
| const CanonizedFilterMeta& filter_meta, | |||||
| const TensorLayout& diff, const TensorLayout& grad); | |||||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||||
| return {handle, grad_layout, filter_layout, filter_meta, | |||||
| diff_layout}; | |||||
| } | |||||
| }; | |||||
| struct ExecArgs: public SizeArgs { | |||||
| const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(ConvolutionBackwardDataImpl *opr, | |||||
| _megdnn_tensor_in filter, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs &args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||||
| virtual void exec(const ExecArgs &args) const = 0; | |||||
| bool is_available_wk(const SizeArgs &args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace( | |||||
| const SizeArgs &args, const Workspace &workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert(req <= workspace.size, | |||||
| "conv bwd data algo %s: " | |||||
| "required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| public: | |||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_CUDNN, | |||||
| CUDA_MATMUL, | |||||
| CUDA_CHANWISE, | |||||
| CUDA_CHANWISE_SMALL, | |||||
| CUDA_BFLOAT16, | |||||
| CUDA_GROUP_CONV_GENERAL, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| struct SizeArgs { | |||||
| HandleImpl* handle; | |||||
| CanonizedFilterMeta filter_meta; | |||||
| const TensorLayout *diff_layout, *grad_layout, *filter_layout; | |||||
| ConvolutionBackwardDataImpl* opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution::CUDNNBwdDataDescs& desc) const { | |||||
| desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||||
| } | } | ||||
| virtual bool is_cudnn() const { | |||||
| return false; | |||||
| SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, | |||||
| const TensorLayout& diff, const TensorLayout& grad); | |||||
| SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, | |||||
| const CanonizedFilterMeta& filter_meta, | |||||
| const TensorLayout& diff, const TensorLayout& grad); | |||||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||||
| return {handle, grad_layout, filter_layout, filter_meta, | |||||
| diff_layout}; | |||||
| } | } | ||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs& args) const = 0; | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert(req <= workspace.size, | |||||
| "conv bwd data algo %s: " | |||||
| "required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| virtual bool is_cudnn() const { return false; } | |||||
| }; | }; | ||||
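The reworked AlgoBase pairs a per-backend AlgoType enum with a Mapper keyed by the serialized algorithm descriptor, so an algorithm object can be recovered from a plain description instead of being passed around as a raw pointer. Below is a minimal, self-contained sketch of that registry pattern; Desc, DescHash, Algo and lookup() are illustrative stand-ins, not the actual megdnn types.

#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

// Illustrative stand-ins for the real megdnn types.
struct Desc {
    uint32_t type;      // value of the per-backend AlgoType enum
    std::string param;  // serialized POD parameters of the algorithm
    bool operator==(const Desc& rhs) const {
        return type == rhs.type && param == rhs.param;
    }
};
struct DescHash {
    std::size_t operator()(const Desc& d) const {
        return std::hash<std::string>()(d.param) ^ d.type;
    }
};

struct Algo {
    virtual ~Algo() = default;
    virtual Desc desc() const = 0;
};

// Registry analogous to AlgoBase::Mapper: serialized desc -> algorithm object.
using Mapper = std::unordered_map<Desc, Algo*, DescHash>;

inline Algo* lookup(const Mapper& map, const Desc& desc) {
    auto it = map.find(desc);
    return it == map.end() ? nullptr : it->second;
}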
| class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase { | ||||
| bool m_is_reproducible; | |||||
| const char *m_name; | |||||
| cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; | cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; | ||||
| CudnnAlgoPack::Attr m_attr; | |||||
| public: | |||||
| public: | |||||
| AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) | |||||
| : m_cudnn_enum(cudnn_enum) { | |||||
| megdnn_assert(CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) != | |||||
| CudnnAlgoPack::conv_bwd_data_algos().end()); | |||||
| m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum); | |||||
| } | |||||
| AlgoCUDNN(bool is_reproducible, const char *name, | |||||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum): | |||||
| m_is_reproducible(is_reproducible), | |||||
| m_name(name), | |||||
| m_cudnn_enum(cudnn_enum) | |||||
| {} | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
| bool is_reproducible() const override { | |||||
| return m_is_reproducible; | |||||
| } | |||||
| const char* name() const override { return m_attr.name.c_str(); } | |||||
| const char* name() const override { | |||||
| return m_name; | |||||
| } | |||||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { | |||||
| return m_cudnn_enum; | |||||
| } | |||||
| bool is_cudnn() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||||
| bool is_cudnn() const override { | |||||
| return true; | |||||
| } | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_cudnn_enum, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
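AlgoCUDNN::param() serializes the cuDNN enum value byte-for-byte into a string via serialize_write_pod. The sketch below shows how such a POD value round-trips through a std::string; the helper names and the FakeCudnnAlgo enum are illustrative assumptions, not megdnn's own API.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>

enum class FakeCudnnAlgo : uint32_t { ALGO_0 = 0, ALGO_1 = 1 };  // stand-in enum

template <typename T>
void write_pod(const T& val, std::string& out) {
    // Append the raw bytes of a trivially copyable value.
    out.append(reinterpret_cast<const char*>(&val), sizeof(T));
}

template <typename T>
T read_pod(const std::string& data, std::size_t offset = 0) {
    T val;
    std::memcpy(&val, data.data() + offset, sizeof(T));
    return val;
}

int main() {
    std::string param;
    write_pod(FakeCudnnAlgo::ALGO_1, param);     // roughly what param() does
    auto back = read_pod<FakeCudnnAlgo>(param);  // what a consumer of the desc does
    assert(back == FakeCudnnAlgo::ALGO_1);
}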
| //! im2col and matmul, with dilation | //! im2col and matmul, with dilation | ||||
| class ConvolutionBackwardDataImpl::AlgoMatmul final: public AlgoBase { | |||||
| template<typename T> | |||||
| static void exec_internal(const ExecArgs &args); | |||||
| class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase { | |||||
| template <typename T> | |||||
| static void exec_internal(const ExecArgs& args); | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return "MATMUL"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| const char* name() const override { return "MATMUL"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||||
| }; | }; | ||||
| class ConvolutionBackwardDataImpl::AlgoChanwise final: public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return "CHANNEL_WISE"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| const char* name() const override { return "CHANNEL_WISE"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||||
| }; | }; | ||||
| class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return "CHANNEL_WISE_SMALL"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| const char* name() const override { return "CHANNEL_WISE_SMALL"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | |||||
| }; | }; | ||||
| class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { | ||||
| @@ -190,61 +185,72 @@ private: | |||||
| TensorLayout& fsrc, TensorLayout& ffilter, | TensorLayout& fsrc, TensorLayout& ffilter, | ||||
| TensorLayout& fdst) const; | TensorLayout& fdst) const; | ||||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_algorithm, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
| //! implement group conv by another algo | //! implement group conv by another algo | ||||
| class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||||
| AlgoBase *m_impl; | |||||
| class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final | |||||
| : public AlgoBase { | |||||
| AlgoBase* m_impl; | |||||
| std::string m_name; | std::string m_name; | ||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase *impl); | |||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return m_name.c_str(); | |||||
| } | |||||
| const char* name() const override { return m_name.c_str(); } | |||||
| bool is_reproducible() const override { | |||||
| return m_impl->is_reproducible(); | |||||
| } | |||||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
| static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | |||||
| TensorLayout& grad_pg); | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||||
| static void modify_size_args(SizeArgs &args, | |||||
| TensorLayout &diff_pg, TensorLayout &grad_pg); | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_impl, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
| class ConvolutionBackwardDataImpl::AlgoPack { | |||||
| class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||||
| // defined in cudnn.cpp | // defined in cudnn.cpp | ||||
| void fill_cudnn_algos(); | void fill_cudnn_algos(); | ||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator = (const AlgoPack &) = delete; | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| public: | |||||
| AlgoPack(); | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| AlgoMatmul matmul; | |||||
| AlgoChanwise chanwise; | |||||
| AlgoChanwiseSmall chanwise_small; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| AlgoMatmul matmul; | |||||
| AlgoChanwise chanwise; | |||||
| AlgoChanwiseSmall chanwise_small; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||||
| std::vector<AlgoBase*> | |||||
| std::vector<AlgoBase*> | |||||
| //! all algorithms | //! all algorithms | ||||
| all_algos, | all_algos, | ||||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | //! non-cudnn algos, used for heuristic if cudnn is not supported | ||||
| non_cudnn_algos, | |||||
| bfloat16_algos; | |||||
| non_cudnn_algos, bfloat16_algos; | |||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -42,7 +42,7 @@ ConvolutionBackwardDataImpl::AlgoBFloat16::float_args( | |||||
| change_dtype(fgrad); | change_dtype(fgrad); | ||||
| opr->param() = args.opr->param(); | opr->param() = args.opr->param(); | ||||
| opr->param().compute_mode = Param::ComputeMode::DEFAULT; | opr->param().compute_mode = Param::ComputeMode::DEFAULT; | ||||
| opr->execution_policy() = {m_algorithm}; | |||||
| opr->execution_policy() = {m_algorithm->info()}; | |||||
| return SizeArgs(opr, ffilter, fdiff, fgrad); | return SizeArgs(opr, ffilter, fdiff, fgrad); | ||||
| } | } | ||||
| @@ -105,7 +105,7 @@ void ConvolutionBackwardDataImpl::AlgoBFloat16::exec( | |||||
| args.handle->create_operator<ConvolutionBackwardData>(); | args.handle->create_operator<ConvolutionBackwardData>(); | ||||
| conv_back_data_opr->param() = args.opr->param(); | conv_back_data_opr->param() = args.opr->param(); | ||||
| conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | ||||
| conv_back_data_opr->execution_policy() = {m_algorithm}; | |||||
| conv_back_data_opr->execution_policy() = {m_algorithm->info()}; | |||||
| conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor, | conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor, | ||||
| cvter.workspace()); | cvter.workspace()); | ||||
| } | } | ||||
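AlgoBFloat16 now hands the chosen sub-algorithm to the inner operator as serializable info ({m_algorithm->info()}) rather than as a raw Algorithm pointer. A simplified sketch of that "wrap another operator, configure it by info" pattern; every type here (Info, Policy, InnerOpr, Wrapper) is a stand-in, not the real megdnn API.

#include <memory>
#include <string>

struct Info { std::string desc; };   // stand-in for AlgorithmInfo

struct Policy { Info algo; };        // stand-in execution policy carrying info

class InnerOpr {
public:
    Policy& execution_policy() { return m_policy; }
    void exec() { /* run with the algorithm described by m_policy.algo */ }
private:
    Policy m_policy;
};

class Wrapper {
public:
    explicit Wrapper(Info inner_algo) : m_inner_algo(std::move(inner_algo)) {}
    void exec() {
        auto opr = std::make_unique<InnerOpr>();
        opr->execution_policy() = {m_inner_algo};  // pass info, not a pointer
        opr->exec();
    }
private:
    Info m_inner_algo;
};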
| @@ -98,35 +98,9 @@ void ConvolutionBackwardDataImpl::AlgoCUDNN::exec( | |||||
| } | } | ||||
| void ConvolutionBackwardDataImpl::AlgoPack::fill_cudnn_algos() { | void ConvolutionBackwardDataImpl::AlgoPack::fill_cudnn_algos() { | ||||
| #define V1(v) #v | |||||
| #define V(v) V1(v) | |||||
| #define DEF_ALGO(NAME, REPROD) \ | |||||
| cudnn.push_back({ \ | |||||
| REPROD, #NAME \ | |||||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||||
| "." V(CUDNN_PATCHLEVEL), \ | |||||
| NAME}) | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true); | |||||
| #if CUDNN_MAJOR >= 5 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true); | |||||
| #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true); | |||||
| #endif | |||||
| #endif | |||||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||||
| #pragma message "not latest cudnn" | |||||
| #endif | |||||
| #undef DEF_ALGO | |||||
| #undef V | |||||
| #undef V1 | |||||
| for (auto&& algo : CudnnAlgoPack::conv_bwd_data_algos()) { | |||||
| cudnn.push_back(algo.first); | |||||
| } | |||||
| } | } | ||||
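fill_cudnn_algos() no longer maintains its own DEF_ALGO macro list; it iterates a central attribute table and constructs one AlgoCUDNN per entry. A minimal sketch of that table-driven registration, with a stand-in attribute map in place of CudnnAlgoPack::conv_bwd_data_algos() (names and entries are illustrative only):

#include <map>
#include <string>
#include <vector>

enum class BwdDataAlgo { ALGO_0, ALGO_1, FFT };

struct Attr {                // stand-in per-algorithm attribute record
    std::string name;
    bool is_reproducible;
};

// Stand-in for the central table that replaces the DEF_ALGO macro list.
const std::map<BwdDataAlgo, Attr>& bwd_data_algos() {
    static const std::map<BwdDataAlgo, Attr> table = {
            {BwdDataAlgo::ALGO_0, {"ALGO_0", false}},
            {BwdDataAlgo::ALGO_1, {"ALGO_1", true}},
            {BwdDataAlgo::FFT, {"FFT", true}},
    };
    return table;
}

struct AlgoCudnnLike {
    // Constructed from the enum; attributes are looked up once from the table.
    explicit AlgoCudnnLike(BwdDataAlgo e) : attr(bwd_data_algos().at(e)), algo(e) {}
    Attr attr;
    BwdDataAlgo algo;
};

void fill(std::vector<AlgoCudnnLike>& cudnn) {
    for (auto&& kv : bwd_data_algos())
        cudnn.push_back(AlgoCudnnLike{kv.first});  // same shape as the new loop
}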
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -49,8 +49,14 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(bfloat16_refhold.back().get()); | all_algos.push_back(bfloat16_refhold.back().get()); | ||||
| bfloat16_algos.push_back(bfloat16_refhold.back().get()); | bfloat16_algos.push_back(bfloat16_refhold.back().get()); | ||||
| } | } | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardFilterImpl) | |||||
| ConvolutionBackwardFilterImpl::AlgoCUDNN* | ConvolutionBackwardFilterImpl::AlgoCUDNN* | ||||
| ConvolutionBackwardFilterImpl::AlgoPack::cudnn_from_enum( | ConvolutionBackwardFilterImpl::AlgoPack::cudnn_from_enum( | ||||
| cudnnConvolutionBwdFilterAlgo_t algo) { | cudnnConvolutionBwdFilterAlgo_t algo) { | ||||
| @@ -6,13 +6,16 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "src/cuda/convolution/helper.h" | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/cuda/convolution/helper.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -23,141 +26,134 @@ namespace cuda { | |||||
| * All the algo impls should try to support non-contiguous batch dim, for group | * All the algo impls should try to support non-contiguous batch dim, for group | ||||
| * conv execution. | * conv execution. | ||||
| */ | */ | ||||
| class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| struct SizeArgs { | |||||
| HandleImpl *handle; | |||||
| const TensorLayout *src_layout, *diff_layout, *grad_layout; | |||||
| CanonizedFilterMeta grad_filter_meta; | |||||
| ConvolutionBackwardFilterImpl *opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution::CUDNNBwdFilterDescs &desc) const { | |||||
| desc.set(*src_layout, *diff_layout, grad_filter_meta, | |||||
| opr->param()); | |||||
| } | |||||
| SizeArgs(ConvolutionBackwardFilterImpl *opr, | |||||
| const TensorLayout &src, const TensorLayout &diff, | |||||
| const TensorLayout &grad); | |||||
| SizeArgs(ConvolutionBackwardFilterImpl* opr, | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| const CanonizedFilterMeta& grad_meta); | |||||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||||
| return {handle, src_layout, grad_layout, grad_filter_meta, | |||||
| diff_layout}; | |||||
| } | |||||
| }; | |||||
| struct ExecArgs: public SizeArgs { | |||||
| const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(ConvolutionBackwardFilterImpl *opr, | |||||
| _megdnn_tensor_in src, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs &args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||||
| virtual void exec(const ExecArgs &args) const = 0; | |||||
| bool is_available_wk(const SizeArgs &args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| class ConvolutionBackwardFilterImpl::AlgoBase : public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| public: | |||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_CUDNN, | |||||
| CUDA_MATMUL, | |||||
| CUDA_CHANWISE, | |||||
| CUDA_BFLOAT16, | |||||
| CUDA_GROUP_CONV_GENERAL, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| struct SizeArgs { | |||||
| HandleImpl* handle; | |||||
| const TensorLayout *src_layout, *diff_layout, *grad_layout; | |||||
| CanonizedFilterMeta grad_filter_meta; | |||||
| ConvolutionBackwardFilterImpl* opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution::CUDNNBwdFilterDescs& desc) const { | |||||
| desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); | |||||
| } | } | ||||
| AlgoBase& check_workspace( | |||||
| const SizeArgs &args, const Workspace &workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert(req <= workspace.size, | |||||
| "conv bwd filter algo %s: " | |||||
| "required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| virtual bool is_cudnn() const { | |||||
| return false; | |||||
| SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, | |||||
| const TensorLayout& diff, const TensorLayout& grad); | |||||
| SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| const CanonizedFilterMeta& grad_meta); | |||||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||||
| return {handle, src_layout, grad_layout, grad_filter_meta, | |||||
| diff_layout}; | |||||
| } | } | ||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs& args) const = 0; | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert(req <= workspace.size, | |||||
| "conv bwd filter algo %s: " | |||||
| "required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| virtual bool is_cudnn() const { return false; } | |||||
| }; | }; | ||||
| class ConvolutionBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { | class ConvolutionBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { | ||||
| bool m_is_reproducible; | |||||
| const char *m_name; | |||||
| cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; | cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; | ||||
| CudnnAlgoPack::Attr m_attr; | |||||
| public: | |||||
| public: | |||||
| AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) | |||||
| : m_cudnn_enum(cudnn_enum) { | |||||
| megdnn_assert(CudnnAlgoPack::conv_bwd_flt_algos().find(cudnn_enum) != | |||||
| CudnnAlgoPack::conv_bwd_flt_algos().end()); | |||||
| m_attr = CudnnAlgoPack::conv_bwd_flt_algos().at(cudnn_enum); | |||||
| } | |||||
| AlgoCUDNN(bool is_reproducible, const char *name, | |||||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum): | |||||
| m_is_reproducible(is_reproducible), | |||||
| m_name(name), | |||||
| m_cudnn_enum(cudnn_enum) | |||||
| {} | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
| bool is_reproducible() const override { | |||||
| return m_is_reproducible; | |||||
| } | |||||
| const char* name() const override { return m_attr.name.c_str(); } | |||||
| const char* name() const override { | |||||
| return m_name; | |||||
| } | |||||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { | |||||
| return m_cudnn_enum; | |||||
| } | |||||
| bool is_cudnn() const override { return true; } | |||||
| bool is_cudnn() const override { | |||||
| return true; | |||||
| } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_cudnn_enum, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
| //! im2col and matmul, with dilation | //! im2col and matmul, with dilation | ||||
| class ConvolutionBackwardFilterImpl::AlgoMatmul final: public AlgoBase { | |||||
| template<typename T> | |||||
| static void exec_internal(const ExecArgs &args); | |||||
| class ConvolutionBackwardFilterImpl::AlgoMatmul final : public AlgoBase { | |||||
| template <typename T> | |||||
| static void exec_internal(const ExecArgs& args); | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return "MATMUL"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| const char* name() const override { return "MATMUL"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||||
| }; | }; | ||||
| class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return "CHANNEL_WISE"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| const char* name() const override { return "CHANNEL_WISE"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||||
| }; | }; | ||||
| class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { | class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { | ||||
| @@ -169,6 +165,13 @@ public: | |||||
| const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_algorithm, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| std::string m_name; | std::string m_name; | ||||
| @@ -180,57 +183,62 @@ private: | |||||
| }; | }; | ||||
| //! implement group conv by another algo | //! implement group conv by another algo | ||||
| class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||||
| AlgoBase *m_impl; | |||||
| class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final | |||||
| : public AlgoBase { | |||||
| AlgoBase* m_impl; | |||||
| std::string m_name; | std::string m_name; | ||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase *impl); | |||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { return m_name.c_str(); } | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
| const char* name() const override { | |||||
| return m_name.c_str(); | |||||
| } | |||||
| static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||||
| TensorLayout& diff_pg); | |||||
| bool is_reproducible() const override { | |||||
| return m_impl->is_reproducible(); | |||||
| } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||||
| static void modify_size_args(SizeArgs &args, | |||||
| TensorLayout &src_pg, TensorLayout &diff_pg); | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_impl, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
| class ConvolutionBackwardFilterImpl::AlgoPack { | |||||
| class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||||
| // defined in cudnn.cpp | // defined in cudnn.cpp | ||||
| void fill_cudnn_algos(); | void fill_cudnn_algos(); | ||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator = (const AlgoPack &) = delete; | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| public: | |||||
| AlgoPack(); | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| AlgoMatmul matmul; | |||||
| AlgoChanwise chanwise; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| AlgoMatmul matmul; | |||||
| AlgoChanwise chanwise; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||||
| std::vector<AlgoBase*> | |||||
| std::vector<AlgoBase*> | |||||
| //! all algorithms | //! all algorithms | ||||
| all_algos, | all_algos, | ||||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | //! non-cudnn algos, used for heuristic if cudnn is not supported | ||||
| non_cudnn_algos, | |||||
| bfloat16_algos; | |||||
| non_cudnn_algos, bfloat16_algos; | |||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -42,7 +42,7 @@ ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args( | |||||
| change_dtype(fgrad); | change_dtype(fgrad); | ||||
| opr->param() = args.opr->param(); | opr->param() = args.opr->param(); | ||||
| opr->param().compute_mode = Param::ComputeMode::DEFAULT; | opr->param().compute_mode = Param::ComputeMode::DEFAULT; | ||||
| opr->execution_policy() = {m_algorithm}; | |||||
| opr->execution_policy() = {m_algorithm->info()}; | |||||
| return SizeArgs(opr, fsrc, fdiff, fgrad); | return SizeArgs(opr, fsrc, fdiff, fgrad); | ||||
| } | } | ||||
| @@ -107,7 +107,7 @@ void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec( | |||||
| conv_back_filter_opr->param() = args.opr->param(); | conv_back_filter_opr->param() = args.opr->param(); | ||||
| conv_back_filter_opr->param().compute_mode = | conv_back_filter_opr->param().compute_mode = | ||||
| Param::ComputeMode::DEFAULT; | Param::ComputeMode::DEFAULT; | ||||
| conv_back_filter_opr->execution_policy() = {m_algorithm}; | |||||
| conv_back_filter_opr->execution_policy() = {m_algorithm->info()}; | |||||
| conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, | conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, | ||||
| cvter.workspace()); | cvter.workspace()); | ||||
| } | } | ||||
| @@ -80,35 +80,9 @@ void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec( | |||||
| } | } | ||||
| void ConvolutionBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { | void ConvolutionBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { | ||||
| #define V1(v) #v | |||||
| #define V(v) V1(v) | |||||
| #define DEF_ALGO(NAME, REPROD) \ | |||||
| cudnn.push_back({ \ | |||||
| REPROD, #NAME \ | |||||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||||
| "." V(CUDNN_PATCHLEVEL), \ | |||||
| NAME}) | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false); | |||||
| #if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, true); | |||||
| #if CUDNN_MAJOR >= 6 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true); | |||||
| #endif | |||||
| #endif | |||||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||||
| #pragma message "not latest cudnn" | |||||
| #endif | |||||
| #undef DEF_ALGO | |||||
| #undef V | |||||
| #undef V1 | |||||
| for(auto&& algo : CudnnAlgoPack::conv_bwd_flt_algos()) { | |||||
| cudnn.push_back(algo.first); | |||||
| } | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -70,7 +70,7 @@ ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src, | |||||
| conv_param.dilate_w, | conv_param.dilate_w, | ||||
| 0, | 0, | ||||
| conv_param.compute_mode}; | conv_param.compute_mode}; | ||||
| ret.convbias_opr->execution_policy() = {this->execution_policy().algorithm}; | |||||
| ret.convbias_opr->execution_policy() = {this->execution_policy().algo}; | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -183,15 +183,6 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||||
| CUDNNBwdDataDescs desc; | CUDNNBwdDataDescs desc; | ||||
| args.init_desc(desc); | args.init_desc(desc); | ||||
| // disabled: segfault in megbrain, needs further investigation. | |||||
| #if 0 | |||||
| bool is_heuristic_success= convolution:: | |||||
| PerformanceModelBackwardData::get_algo_backward_data_success( | |||||
| args, desc, workspace_limit_in_bytes, &algo); | |||||
| if (is_heuristic_success) { | |||||
| return sm_algo_pack.cudnn_from_enum(algo); | |||||
| } | |||||
| #endif | |||||
| #if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
| int max_count = 0; | int max_count = 0; | ||||
| cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | ||||
| @@ -24,14 +24,6 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||||
| const PreprocessedFilter* preprocessed_filter, | const PreprocessedFilter* preprocessed_filter, | ||||
| _megdnn_workspace workspace) override; | _megdnn_workspace workspace) override; | ||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||||
| const TensorLayout &filter, | |||||
| const TensorLayout &dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| size_t get_workspace_in_bytes( | size_t get_workspace_in_bytes( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| @@ -60,99 +52,129 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||||
| TensorLayout bias_layout; | TensorLayout bias_layout; | ||||
| TensorLayout z_layout; | TensorLayout z_layout; | ||||
| }; | }; | ||||
| private: | |||||
| ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&); | |||||
| }; | |||||
| class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||||
| public: | |||||
| using ConvolutionBackwardData::ConvolutionBackwardData; | |||||
| void exec(_megdnn_tensor_in filter, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) override; | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| bool reproducible) override; | bool reproducible) override; | ||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& filter, | |||||
| const CanonizedFilterMeta& filter_meta, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, bool reproducible); | |||||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class AlgoMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoChanwiseSmall; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoBFloat16; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { | |||||
| return sm_algo_pack; | |||||
| } | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | |||||
| ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&); | |||||
| }; | }; | ||||
| class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||||
| public: | |||||
| using ConvolutionBackwardFilter::ConvolutionBackwardFilter; | |||||
| void exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) override; | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& gradk, | |||||
| const CanonizedFilterMeta& grad_meta, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible); | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class AlgoMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoBFloat16; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { | |||||
| return sm_algo_pack; | |||||
| } | |||||
| class ConvolutionBackwardDataImpl : public ConvolutionBackwardData { | |||||
| public: | |||||
| using ConvolutionBackwardData::ConvolutionBackwardData; | |||||
| void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||||
| return get_algorithm_heuristic(filter, filter_meta, diff, grad, | |||||
| workspace_limit_in_bytes, reproducible) | |||||
| ->info(); | |||||
| } | |||||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class AlgoMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoChanwiseSmall; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoBFloat16; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const CanonizedFilterMeta& filter_meta, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible); | |||||
| static AlgoPack sm_algo_pack; | |||||
| }; | |||||
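ConvolutionBackwardDataImpl (and, further below, ConvolutionBackwardFilterImpl) keeps the pointer-returning get_algorithm_heuristic() as a protected/private detail and exposes get_algorithm_info_heuristic(), which returns the algorithm's copyable info. A hedged sketch of that wrapper pattern with simplified stand-in types (not the real operator interfaces):

#include <cstddef>
#include <string>

struct AlgorithmInfoLike { std::string name; bool reproducible; };  // stand-in

class Algo {                                    // stand-in for Algorithm
public:
    virtual ~Algo() = default;
    virtual AlgorithmInfoLike info() const = 0;
};

class OprLike {
public:
    // Public wrapper: callers only see copyable, serializable info.
    AlgorithmInfoLike get_algorithm_info_heuristic(std::size_t workspace_limit,
                                                   bool reproducible) {
        return get_algorithm_heuristic(workspace_limit, reproducible)->info();
    }

protected:
    // The pointer-returning heuristic stays an implementation detail.
    virtual Algo* get_algorithm_heuristic(std::size_t workspace_limit,
                                          bool reproducible) = 0;
};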
| private: | |||||
| static AlgoPack sm_algo_pack; | |||||
| class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter { | |||||
| public: | |||||
| using ConvolutionBackwardFilter::ConvolutionBackwardFilter; | |||||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| const CanonizedFilterMeta& grad_meta, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) { | |||||
| return get_algorithm_heuristic(src, diff, grad, grad_meta, | |||||
| workspace_limit_in_bytes, reproducible) | |||||
| ->info(); | |||||
| } | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class AlgoMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoBFloat16; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| const CanonizedFilterMeta& grad_meta, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible); | |||||
| static AlgoPack sm_algo_pack; | |||||
| }; | }; | ||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -39,8 +39,14 @@ Convolution3DBackwardDataImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(&i); | all_algos.push_back(&i); | ||||
| } | } | ||||
| megdnn_assert(all_algos_data == all_algos.data()); | megdnn_assert(all_algos_data == all_algos.data()); | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardDataImpl) | |||||
| Convolution3DBackwardDataImpl::AlgoCUDNN* | Convolution3DBackwardDataImpl::AlgoCUDNN* | ||||
| Convolution3DBackwardDataImpl::AlgoPack::cudnn_from_enum( | Convolution3DBackwardDataImpl::AlgoPack::cudnn_from_enum( | ||||
| cudnnConvolutionBwdDataAlgo_t algo) { | cudnnConvolutionBwdDataAlgo_t algo) { | ||||
| @@ -96,7 +102,7 @@ std::string Convolution3DBackwardDataImpl::AlgoBase::SizeArgs::to_string() const | |||||
| fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], | fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], | ||||
| diff_layout->to_string().c_str(), | diff_layout->to_string().c_str(), | ||||
| grad_layout->to_string().c_str(), | grad_layout->to_string().c_str(), | ||||
| fm.padding[0], fm.padding[1], fm.padding[2], | |||||
| fm.padding[0], fm.padding[1], fm.padding[2], | |||||
| fm.stride[0], fm.stride[1], fm.stride[2], | fm.stride[0], fm.stride[1], fm.stride[2], | ||||
| fm.dilation[0], fm.dilation[1] ,fm.dilation[2], | fm.dilation[0], fm.dilation[1] ,fm.dilation[2], | ||||
| !fm.should_flip, | !fm.should_flip, | ||||
| @@ -6,13 +6,16 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "src/cuda/convolution3d/helper.h" | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "src/cuda/convolution3d/helper.h" | |||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -23,170 +26,174 @@ namespace cuda { | |||||
| * All the algo impls should try to support non-contiguous batch dim, for group | * All the algo impls should try to support non-contiguous batch dim, for group | ||||
| * conv execution. | * conv execution. | ||||
| */ | */ | ||||
| class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| struct SizeArgs { | |||||
| HandleImpl *handle; | |||||
| CanonizedFilterMeta filter_meta; | |||||
| const TensorLayout *diff_layout, *grad_layout; | |||||
| Convolution3DBackwardDataImpl *opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution3d::CUDNNBwdDataDescs &desc) const { | |||||
| desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||||
| } | |||||
| SizeArgs(Convolution3DBackwardDataImpl *opr, | |||||
| const TensorLayout &filter, const TensorLayout &diff, | |||||
| const TensorLayout &grad); | |||||
| SizeArgs(Convolution3DBackwardDataImpl *opr, | |||||
| const CanonizedFilterMeta &filter, const TensorLayout &diff, | |||||
| const TensorLayout &grad); | |||||
| convolution3d::ForwardSizeArgs as_fwd_args() const { | |||||
| return {handle, grad_layout, filter_meta, diff_layout, | |||||
| opr->param().data_type}; | |||||
| } | |||||
| }; | |||||
| struct ExecArgs: public SizeArgs { | |||||
| const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(Convolution3DBackwardDataImpl *opr, | |||||
| _megdnn_tensor_in filter, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs &args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||||
| virtual void exec(const ExecArgs &args) const = 0; | |||||
| bool is_available_wk(const SizeArgs &args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace( | |||||
| const SizeArgs &args, const Workspace &workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert(req <= workspace.size, | |||||
| "conv bwd data algo %s: " | |||||
| "required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| class Convolution3DBackwardDataImpl::AlgoBase : public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_GROUP_CONV_GENERAL, | |||||
| CUDA_CUDNN, | |||||
| CUDA_CHANWISE, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| struct SizeArgs { | |||||
| HandleImpl* handle; | |||||
| CanonizedFilterMeta filter_meta; | |||||
| const TensorLayout *diff_layout, *grad_layout; | |||||
| Convolution3DBackwardDataImpl* opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution3d::CUDNNBwdDataDescs& desc) const { | |||||
| desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||||
| } | } | ||||
| virtual bool is_cudnn() const { | |||||
| return false; | |||||
| SizeArgs(Convolution3DBackwardDataImpl* opr, const TensorLayout& filter, | |||||
| const TensorLayout& diff, const TensorLayout& grad); | |||||
| SizeArgs(Convolution3DBackwardDataImpl* opr, | |||||
| const CanonizedFilterMeta& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad); | |||||
| convolution3d::ForwardSizeArgs as_fwd_args() const { | |||||
| return {handle, grad_layout, filter_meta, diff_layout, | |||||
| opr->param().data_type}; | |||||
| } | } | ||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(Convolution3DBackwardDataImpl* opr, _megdnn_tensor_in filter, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs& args) const = 0; | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert(req <= workspace.size, | |||||
| "conv bwd data algo %s: " | |||||
| "required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| virtual bool is_cudnn() const { return false; } | |||||
| }; | }; | ||||
| class Convolution3DBackwardDataImpl::AlgoCUDNN final : public AlgoBase { | class Convolution3DBackwardDataImpl::AlgoCUDNN final : public AlgoBase { | ||||
| bool m_is_reproducible; | |||||
| const char *m_name; | |||||
| cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; | cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; | ||||
| CudnnAlgoPack::Attr m_attr; | |||||
| public: | |||||
| public: | |||||
| AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) | |||||
| : m_cudnn_enum(cudnn_enum) { | |||||
| megdnn_assert(CudnnAlgoPack::conv3d_bwd_data_algos().find(cudnn_enum) != | |||||
| CudnnAlgoPack::conv3d_bwd_data_algos().end()); | |||||
| m_attr = CudnnAlgoPack::conv3d_bwd_data_algos().at(cudnn_enum); | |||||
| } | |||||
| AlgoCUDNN(bool is_reproducible, const char *name, | |||||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum): | |||||
| m_is_reproducible(is_reproducible), | |||||
| m_name(name), | |||||
| m_cudnn_enum(cudnn_enum) | |||||
| {} | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
| bool is_reproducible() const override { | |||||
| return m_is_reproducible; | |||||
| } | |||||
| const char* name() const override { return m_attr.name.c_str(); } | |||||
| const char* name() const override { | |||||
| return m_name; | |||||
| } | |||||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { | |||||
| return m_cudnn_enum; | |||||
| } | |||||
| bool is_cudnn() const override { return true; } | |||||
| bool is_cudnn() const override { | |||||
| return true; | |||||
| } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_cudnn_enum, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
| class Convolution3DBackwardDataImpl::AlgoChanwise final: public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| class Convolution3DBackwardDataImpl::AlgoChanwise final : public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return "CHANNEL_WISE"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| const char* name() const override { return "CHANNEL_WISE"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||||
| }; | }; | ||||
| //! implement group conv by another algo | //! implement group conv by another algo | ||||
| class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||||
| AlgoBase *m_impl; | |||||
| class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final | |||||
| : public AlgoBase { | |||||
| AlgoBase* m_impl; | |||||
| std::string m_name; | std::string m_name; | ||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase *impl); | |||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return m_name.c_str(); | |||||
| } | |||||
| const char* name() const override { return m_name.c_str(); } | |||||
| bool is_reproducible() const override { | |||||
| return m_impl->is_reproducible(); | |||||
| } | |||||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
| static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | |||||
| TensorLayout& grad_pg); | |||||
| static void modify_size_args(SizeArgs &args, | |||||
| TensorLayout &diff_pg, TensorLayout &grad_pg); | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_impl, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
| class Convolution3DBackwardDataImpl::AlgoPack { | |||||
| class Convolution3DBackwardDataImpl::AlgoPack : NonCopyableObj { | |||||
| // defined in cudnn.cpp | // defined in cudnn.cpp | ||||
| void fill_cudnn_algos(); | void fill_cudnn_algos(); | ||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator = (const AlgoPack &) = delete; | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| public: | |||||
| AlgoPack(); | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| AlgoChanwise chanwise; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| AlgoChanwise chanwise; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<AlgoBase*> | |||||
| std::vector<AlgoBase*> | |||||
| //! all algorithms | //! all algorithms | ||||
| all_algos, | all_algos, | ||||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | //! non-cudnn algos, used for heuristic if cudnn is not supported | ||||
| non_cudnn_algos; | non_cudnn_algos; | ||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -80,27 +80,9 @@ void Convolution3DBackwardDataImpl::AlgoCUDNN::exec( | |||||
| } | } | ||||
| void Convolution3DBackwardDataImpl::AlgoPack::fill_cudnn_algos() { | void Convolution3DBackwardDataImpl::AlgoPack::fill_cudnn_algos() { | ||||
| #define V1(v) #v | |||||
| #define V(v) V1(v) | |||||
| #define DEF_ALGO(NAME, REPROD) \ | |||||
| cudnn.push_back({ \ | |||||
| REPROD, #NAME \ | |||||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||||
| "." V(CUDNN_PATCHLEVEL), \ | |||||
| NAME}) | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true); | |||||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||||
| #pragma message "not latest cudnn" | |||||
| #endif | |||||
| #undef DEF_ALGO | |||||
| #undef V | |||||
| #undef V1 | |||||
| for (auto&& algo : CudnnAlgoPack::conv3d_bwd_data_algos()) { | |||||
| cudnn.push_back(algo.first); | |||||
| } | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -17,7 +17,7 @@ using namespace cuda; | |||||
| Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { | Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { | ||||
| non_cudnn_algos.push_back(&chanwise); | non_cudnn_algos.push_back(&chanwise); | ||||
| non_cudnn_algos.push_back(&inplace_matmul); | |||||
| non_cudnn_algos.push_back(&inplace_matmul); | |||||
| all_algos.push_back(&chanwise); // prefer chanwise | all_algos.push_back(&chanwise); // prefer chanwise | ||||
| fill_cudnn_algos(); | fill_cudnn_algos(); | ||||
| @@ -41,8 +41,14 @@ Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { | |||||
| } | } | ||||
| megdnn_assert(all_algos_data == all_algos.data()); | megdnn_assert(all_algos_data == all_algos.data()); | ||||
| non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group inplace_matmul | non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group inplace_matmul | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardFilterImpl) | |||||
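
MEGDNN_DEF_GET_ALGO_FROM_DESC is not expanded in this diff; presumably it generates the get_algo_from_desc() declared in opr_impl.h by consulting the m_all_algos_map filled in the loop above. Below is a hedged, self-contained sketch of that lookup; MyDesc, MyDescHash, MyAlgo and the free get_algo_from_desc function are illustrative stand-ins, not the macro's actual expansion.

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>

struct MyDesc {
    uint32_t type;      // per-backend algo type tag
    std::string param;  // serialized extra state, e.g. a cuDNN enum
    bool operator==(const MyDesc& rhs) const {
        return type == rhs.type && param == rhs.param;
    }
};
struct MyDescHash {
    std::size_t operator()(const MyDesc& d) const {
        return std::hash<uint32_t>()(d.type) ^
               (std::hash<std::string>()(d.param) << 1);
    }
};
struct MyAlgo {};

using Mapper = std::unordered_map<MyDesc, MyAlgo*, MyDescHash>;

// One plausible shape for the generated lookup: an unknown descriptor
// simply yields nullptr.
MyAlgo* get_algo_from_desc(const Mapper& map, const MyDesc& desc) {
    auto it = map.find(desc);
    return it == map.end() ? nullptr : it->second;
}
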
| Convolution3DBackwardFilterImpl::AlgoCUDNN* | Convolution3DBackwardFilterImpl::AlgoCUDNN* | ||||
| Convolution3DBackwardFilterImpl::AlgoPack::cudnn_from_enum( | Convolution3DBackwardFilterImpl::AlgoPack::cudnn_from_enum( | ||||
| cudnnConvolutionBwdFilterAlgo_t algo) { | cudnnConvolutionBwdFilterAlgo_t algo) { | ||||
| @@ -99,9 +105,9 @@ Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const { | |||||
| "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", | "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", | ||||
| src_layout->to_string().c_str(), | src_layout->to_string().c_str(), | ||||
| diff_layout->to_string().c_str(), | diff_layout->to_string().c_str(), | ||||
| fm.group, fm.ocpg, fm.icpg, | |||||
| fm.group, fm.ocpg, fm.icpg, | |||||
| fm.spatial[0], fm.spatial[1], fm.spatial[2], | fm.spatial[0], fm.spatial[1], fm.spatial[2], | ||||
| fm.padding[0], fm.padding[1], fm.padding[2], | |||||
| fm.padding[0], fm.padding[1], fm.padding[2], | |||||
| fm.stride[0], fm.stride[1], fm.stride[2], | fm.stride[0], fm.stride[1], fm.stride[2], | ||||
| fm.dilation[0], fm.dilation[1], fm.dilation[2], | fm.dilation[0], fm.dilation[1], fm.dilation[2], | ||||
| !fm.should_flip, | !fm.should_flip, | ||||
| @@ -6,198 +6,198 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "src/cuda/convolution3d/helper.h" | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "src/cuda/convolution3d/helper.h" | |||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| class Convolution3DBackwardFilterImpl::AlgoBase: public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| struct SizeArgs { | |||||
| HandleImpl *handle; | |||||
| const TensorLayout *src_layout, *diff_layout; | |||||
| CanonizedFilterMeta grad_filter_meta; | |||||
| Convolution3DBackwardFilterImpl *opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution3d::CUDNNBwdFilterDescs &desc) const { | |||||
| desc.set(*src_layout, *diff_layout, grad_filter_meta, | |||||
| opr->param()); | |||||
| } | |||||
| SizeArgs(Convolution3DBackwardFilterImpl *opr, | |||||
| const TensorLayout &src, const TensorLayout &diff, | |||||
| const TensorLayout &grad); | |||||
| SizeArgs(Convolution3DBackwardFilterImpl *opr, | |||||
| const TensorLayout &src, const TensorLayout &diff, | |||||
| const CanonizedFilterMeta &grad); | |||||
| convolution3d::ForwardSizeArgs as_fwd_args() const { | |||||
| return {handle, src_layout, grad_filter_meta, diff_layout, | |||||
| opr->param().data_type}; | |||||
| } | |||||
| }; | |||||
| struct ExecArgs: public SizeArgs { | |||||
| const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(Convolution3DBackwardFilterImpl *opr, | |||||
| _megdnn_tensor_in src, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs &args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||||
| virtual void exec(const ExecArgs &args) const = 0; | |||||
| bool is_available_wk(const SizeArgs &args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert(req <= workspace.size, | |||||
| "conv bwd filter algo %s: " | |||||
| "required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| class Convolution3DBackwardFilterImpl::AlgoBase : public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_GROUP_CONV_GENERAL, | |||||
| CUDA_CUDNN, | |||||
| CUDA_INPLACE_MATMUL, | |||||
| CUDA_CHANWISE, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| struct SizeArgs { | |||||
| HandleImpl* handle; | |||||
| const TensorLayout *src_layout, *diff_layout; | |||||
| CanonizedFilterMeta grad_filter_meta; | |||||
| Convolution3DBackwardFilterImpl* opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution3d::CUDNNBwdFilterDescs& desc) const { | |||||
| desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); | |||||
| } | } | ||||
| SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, | |||||
| const TensorLayout& diff, const TensorLayout& grad); | |||||
| SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, | |||||
| const TensorLayout& diff, const CanonizedFilterMeta& grad); | |||||
| virtual bool is_cudnn() const { | |||||
| return false; | |||||
| convolution3d::ForwardSizeArgs as_fwd_args() const { | |||||
| return {handle, src_layout, grad_filter_meta, diff_layout, | |||||
| opr->param().data_type}; | |||||
| } | } | ||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(Convolution3DBackwardFilterImpl* opr, _megdnn_tensor_in src, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs& args) const = 0; | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert(req <= workspace.size, | |||||
| "conv bwd filter algo %s: " | |||||
| "required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| virtual bool is_cudnn() const { return false; } | |||||
| }; | }; | ||||
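
The is_available_wk()/is_available_reproducible() pair above is the filter that heuristic selection runs over each candidate: available, within the workspace limit, and reproducible when the caller demands it. A compilable sketch of that selection loop, under simplified assumptions (plain Candidate structs instead of AlgoBase, first match wins):

#include <cstddef>
#include <limits>
#include <vector>

// Illustrative candidate record; the real code queries AlgoBase virtuals.
struct Candidate {
    bool available;
    bool reproducible;
    std::size_t workspace;
};

const Candidate* pick_first(const std::vector<Candidate>& algos,
                            bool require_reproducible,
                            std::size_t limit =
                                    std::numeric_limits<std::size_t>::max()) {
    for (const auto& a : algos) {
        // Mirrors is_available_reproducible(): reproducibility is only
        // checked when the caller asks for it.
        bool ok = (!require_reproducible || a.reproducible) &&
                  a.available && a.workspace <= limit;
        if (ok)
            return &a;
    }
    return nullptr;  // caller falls back to another strategy
}
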
| class Convolution3DBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { | class Convolution3DBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { | ||||
| bool m_is_reproducible; | |||||
| const char *m_name; | |||||
| cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; | cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; | ||||
| CudnnAlgoPack::Attr m_attr; | |||||
| public: | |||||
| public: | |||||
| AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) | |||||
| : m_cudnn_enum(cudnn_enum) { | |||||
| megdnn_assert(CudnnAlgoPack::conv3d_bwd_flt_algos().find(cudnn_enum) != | |||||
| CudnnAlgoPack::conv3d_bwd_flt_algos().end()); | |||||
| m_attr = CudnnAlgoPack::conv3d_bwd_flt_algos().at(cudnn_enum); | |||||
| } | |||||
| AlgoCUDNN(bool is_reproducible, const char *name, | |||||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum): | |||||
| m_is_reproducible(is_reproducible), | |||||
| m_name(name), | |||||
| m_cudnn_enum(cudnn_enum) | |||||
| {} | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
| bool is_reproducible() const override { | |||||
| return m_is_reproducible; | |||||
| } | |||||
| const char* name() const override { return m_attr.name.c_str(); } | |||||
| const char* name() const override { | |||||
| return m_name; | |||||
| } | |||||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { | |||||
| return m_cudnn_enum; | |||||
| } | |||||
| bool is_cudnn() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||||
| bool is_cudnn() const override { | |||||
| return true; | |||||
| } | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_cudnn_enum, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
| class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final | |||||
| : public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final: public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| const char* name() const override { | |||||
| return "INPLACE_MATMUL"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return false; | |||||
| } | |||||
| const char* name() const override { return "INPLACE_MATMUL"; } | |||||
| bool is_reproducible() const override { return false; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | |||||
| }; | }; | ||||
| class Convolution3DBackwardFilterImpl::AlgoChanwise final: public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| class Convolution3DBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return "CHANNEL_WISE"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| const char* name() const override { return "CHANNEL_WISE"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||||
| }; | }; | ||||
| //! implement group conv by another algo | //! implement group conv by another algo | ||||
| class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||||
| AlgoBase *m_impl; | |||||
| class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final | |||||
| : public AlgoBase { | |||||
| AlgoBase* m_impl; | |||||
| std::string m_name; | std::string m_name; | ||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase *impl); | |||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return m_name.c_str(); | |||||
| } | |||||
| const char* name() const override { return m_name.c_str(); } | |||||
| bool is_reproducible() const override { | |||||
| return m_impl->is_reproducible(); | |||||
| } | |||||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
| static void modify_size_args(SizeArgs &args, | |||||
| TensorLayout &src_pg, TensorLayout &diff_pg); | |||||
| static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||||
| TensorLayout& diff_pg); | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_impl, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
| class Convolution3DBackwardFilterImpl::AlgoPack { | |||||
| class Convolution3DBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||||
| // defined in cudnn.cpp | // defined in cudnn.cpp | ||||
| void fill_cudnn_algos(); | void fill_cudnn_algos(); | ||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator = (const AlgoPack &) = delete; | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| public: | |||||
| AlgoPack(); | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| AlgoInplaceMatmul inplace_matmul; | |||||
| AlgoChanwise chanwise; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| AlgoInplaceMatmul inplace_matmul; | |||||
| AlgoChanwise chanwise; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<AlgoBase*> | |||||
| std::vector<AlgoBase*> | |||||
| //! all algorithms | //! all algorithms | ||||
| all_algos, | all_algos, | ||||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | //! non-cudnn algos, used for heuristic if cudnn is not supported | ||||
| non_cudnn_algos; | non_cudnn_algos; | ||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -66,29 +66,9 @@ void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec( | |||||
| } | } | ||||
| void Convolution3DBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { | void Convolution3DBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { | ||||
| #define V1(v) #v | |||||
| #define V(v) V1(v) | |||||
| #define DEF_ALGO(NAME, REPROD) \ | |||||
| cudnn.push_back({REPROD, \ | |||||
| #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V( \ | |||||
| CUDNN_PATCHLEVEL), \ | |||||
| NAME}) | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false); | |||||
| #pragma message \ | |||||
| "fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc" | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false); | |||||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||||
| #pragma message "not latest cudnn" | |||||
| #endif | |||||
| #undef DEF_ALGO | |||||
| #undef V | |||||
| #undef V1 | |||||
| for (auto&& algo : CudnnAlgoPack::conv3d_bwd_flt_algos()) { | |||||
| cudnn.push_back(algo.first); | |||||
| } | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -21,13 +21,13 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { | |||||
| non_cudnn_algos.push_back(&a1x1x1); | non_cudnn_algos.push_back(&a1x1x1); | ||||
| all_algos.push_back(&chanwise); | all_algos.push_back(&chanwise); | ||||
| fill_cudnn_algos(); | fill_cudnn_algos(); | ||||
| for (auto &&i: cudnn) { | for (auto &&i: cudnn) { | ||||
| all_algos.push_back(&i); | |||||
| all_algos.push_back(&i); | |||||
| } | } | ||||
| all_algos.push_back(&inplace_matmul); | all_algos.push_back(&inplace_matmul); | ||||
| all_algos.push_back(&a1x1x1); | |||||
| all_algos.push_back(&a1x1x1); | |||||
| all_algos.reserve(all_algos.size() * 2); | all_algos.reserve(all_algos.size() * 2); | ||||
| // add gconv algos by AlgoGroupConvGeneral | // add gconv algos by AlgoGroupConvGeneral | ||||
| @@ -42,10 +42,16 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(&i); | all_algos.push_back(&i); | ||||
| } | } | ||||
| megdnn_assert(all_algos_data == all_algos.data()); | megdnn_assert(all_algos_data == all_algos.data()); | ||||
| non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul | |||||
| non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul | |||||
| non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1x1 | non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1x1 | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DForwardImpl) | |||||
| Convolution3DForwardImpl::AlgoCUDNN* | Convolution3DForwardImpl::AlgoCUDNN* | ||||
| Convolution3DForwardImpl::AlgoPack::cudnn_from_enum( | Convolution3DForwardImpl::AlgoPack::cudnn_from_enum( | ||||
| cudnnConvolutionFwdAlgo_t algo) { | cudnnConvolutionFwdAlgo_t algo) { | ||||
| @@ -99,7 +105,7 @@ std::string Convolution3DForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
| "src=%s, filter=%u{%u,%u,%u,%u,%u}, dst=%s, " | "src=%s, filter=%u{%u,%u,%u,%u,%u}, dst=%s, " | ||||
| "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", | "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", | ||||
| src_layout->to_string().c_str(), | src_layout->to_string().c_str(), | ||||
| fm.group, fm.ocpg, fm.icpg, | |||||
| fm.group, fm.ocpg, fm.icpg, | |||||
| fm.spatial[0], fm.spatial[1], fm.spatial[2], | fm.spatial[0], fm.spatial[1], fm.spatial[2], | ||||
| dst_layout->to_string().c_str(), | dst_layout->to_string().c_str(), | ||||
| fm.padding[0], fm.padding[1], fm.padding[2], | fm.padding[0], fm.padding[1], fm.padding[2], | ||||
| @@ -6,17 +6,20 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/utils.h" | |||||
| #include "src/cuda/convolution3d/helper.h" | #include "src/cuda/convolution3d/helper.h" | ||||
| #include "src/cuda/handle.h" | |||||
| #include "src/cuda/convolution3d/opr_impl.h" | #include "src/cuda/convolution3d/opr_impl.h" | ||||
| #include "src/common/utils.h" | |||||
| #include "src/cuda/handle.h" | |||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| @@ -29,195 +32,189 @@ namespace cuda { | |||||
| * All the algo impls should try to support non-contiguous batch dim, for group | * All the algo impls should try to support non-contiguous batch dim, for group | ||||
| * conv execution. | * conv execution. | ||||
| */ | */ | ||||
| class Convolution3DForwardImpl::AlgoBase: public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| struct SizeArgs: public convolution3d::ForwardSizeArgs { | |||||
| Convolution3DForwardImpl *opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution3d::CUDNNForwardDescs &desc) const { | |||||
| desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); | |||||
| } | |||||
| SizeArgs(Convolution3DForwardImpl *opr, | |||||
| const TensorLayout &src, | |||||
| const TensorLayout &filter, | |||||
| const TensorLayout &dst); | |||||
| SizeArgs(Convolution3DForwardImpl *opr, | |||||
| const TensorLayout &src, | |||||
| const CanonizedFilterMeta &filter, | |||||
| const TensorLayout &dst); | |||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| const TensorND *src_tensor, *filter_tensor, *dst_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(Convolution3DForwardImpl *opr, | |||||
| _megdnn_tensor_in src, | |||||
| _megdnn_tensor_in filter, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs &args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||||
| virtual void exec(const ExecArgs &args) const = 0; | |||||
| bool is_available_wk(const SizeArgs &args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert(req <= workspace.size, | |||||
| "conv3d fwd algo %s: required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| virtual bool is_cudnn() const { | |||||
| return false; | |||||
| } | |||||
| class Convolution3DForwardImpl::AlgoBase : public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_1X1X1, | |||||
| CUDA_GROUP_CONV_GENERAL, | |||||
| CUDA_CUDNN, | |||||
| CUDA_INPLACE_MATMUL, | |||||
| CUDA_CHANWISE, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
| struct SizeArgs : public convolution3d::ForwardSizeArgs { | |||||
| Convolution3DForwardImpl* opr; | |||||
| std::string to_string() const; | |||||
| void init_desc(convolution3d::CUDNNForwardDescs& desc) const { | |||||
| desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); | |||||
| } | |||||
| SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src, | |||||
| const TensorLayout& filter, const TensorLayout& dst); | |||||
| SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src, | |||||
| const CanonizedFilterMeta& filter, const TensorLayout& dst); | |||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| const TensorND *src_tensor, *filter_tensor, *dst_tensor; | |||||
| Workspace workspace; | |||||
| ExecArgs(Convolution3DForwardImpl* opr, _megdnn_tensor_in src, | |||||
| _megdnn_tensor_in filter, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs& args) const = 0; | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert( | |||||
| req <= workspace.size, | |||||
| "conv3d fwd algo %s: required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| virtual bool is_cudnn() const { return false; } | |||||
| }; | }; | ||||
| class Convolution3DForwardImpl::Algo1x1x1 final: public AlgoBase { | |||||
| static void extract_matmul_layouts(const SizeArgs &args, | |||||
| TensorLayout &A, TensorLayout &B, TensorLayout &C); | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| const char* name() const override { | |||||
| return "1x1x1"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| class Convolution3DForwardImpl::Algo1x1x1 final : public AlgoBase { | |||||
| static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A, | |||||
| TensorLayout& B, TensorLayout& C); | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { return "1x1x1"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1) | |||||
| }; | }; | ||||
| //! implement group conv by another algo | //! implement group conv by another algo | ||||
| class Convolution3DForwardImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||||
| AlgoBase *m_impl; | |||||
| class Convolution3DForwardImpl::AlgoGroupConvGeneral final : public AlgoBase { | |||||
| AlgoBase* m_impl; | |||||
| std::string m_name; | std::string m_name; | ||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase *impl); | |||||
| public: | |||||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { | |||||
| return m_name.c_str(); | |||||
| } | |||||
| const char* name() const override { return m_name.c_str(); } | |||||
| bool is_reproducible() const override { | |||||
| return m_impl->is_reproducible(); | |||||
| } | |||||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
| static void modify_size_args(SizeArgs &args, | |||||
| TensorLayout &src_pg, TensorLayout &dst_pg); | |||||
| static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||||
| TensorLayout& dst_pg); | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_impl, ret); | |||||
| return ret; | |||||
| } | |||||
| }; | }; | ||||
| class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase { | class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase { | ||||
| bool m_is_reproducible; | |||||
| const char *m_name; | |||||
| cudnnConvolutionFwdAlgo_t m_cudnn_enum; | cudnnConvolutionFwdAlgo_t m_cudnn_enum; | ||||
| CudnnAlgoPack::Attr m_attr; | |||||
| public: | |||||
| public: | |||||
| AlgoCUDNN(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { | |||||
| megdnn_assert(CudnnAlgoPack::conv3d_fwd_algos().find(cudnn_enum) != | |||||
| CudnnAlgoPack::conv3d_fwd_algos().end()); | |||||
| m_attr = CudnnAlgoPack::conv3d_fwd_algos().at(cudnn_enum); | |||||
| } | |||||
| AlgoCUDNN(bool is_reproducible, const char *name, | |||||
| cudnnConvolutionFwdAlgo_t cudnn_enum): | |||||
| m_is_reproducible(is_reproducible), | |||||
| m_name(name), | |||||
| m_cudnn_enum(cudnn_enum) | |||||
| {} | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
| bool is_reproducible() const override { | |||||
| return m_is_reproducible; | |||||
| } | |||||
| const char* name() const override { return m_attr.name.c_str(); } | |||||
| const char* name() const override { | |||||
| return m_name; | |||||
| } | |||||
| cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||||
| cudnnConvolutionFwdAlgo_t cudnn_enum() const { | |||||
| return m_cudnn_enum; | |||||
| } | |||||
| bool is_cudnn() const override { return true; } | |||||
| bool is_cudnn() const override { | |||||
| return true; | |||||
| } | |||||
| }; | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||||
| class Convolution3DForwardImpl::AlgoInplaceMatmul final: public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_cudnn_enum, ret); | |||||
| return ret; | |||||
| } | |||||
| const char* name() const override { | |||||
| return "INPLACE_MATMUL"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| }; | }; | ||||
| class Convolution3DForwardImpl::AlgoInplaceMatmul final : public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| class Convolution3DForwardImpl::AlgoChanwise final: public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs &args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||||
| void exec(const ExecArgs &args) const override; | |||||
| const char* name() const override { return "INPLACE_MATMUL"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | |||||
| }; | |||||
| const char* name() const override { | |||||
| return "CHANNEL_WISE"; | |||||
| } | |||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| class Convolution3DForwardImpl::AlgoChanwise final : public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { return "CHANNEL_WISE"; } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||||
| }; | }; | ||||
| class Convolution3DForwardImpl::AlgoPack { | |||||
| class Convolution3DForwardImpl::AlgoPack : NonCopyableObj { | |||||
| // defined in cudnn.cpp | // defined in cudnn.cpp | ||||
| void fill_cudnn_algos(); | void fill_cudnn_algos(); | ||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator = (const AlgoPack &) = delete; | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| public: | |||||
| AlgoPack(); | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| Algo1x1x1 a1x1x1; | |||||
| AlgoInplaceMatmul inplace_matmul; | |||||
| AlgoChanwise chanwise; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<AlgoCUDNN> cudnn; | |||||
| Algo1x1x1 a1x1x1; | |||||
| AlgoInplaceMatmul inplace_matmul; | |||||
| AlgoChanwise chanwise; | |||||
| std::vector<AlgoGroupConvGeneral> gconv; | |||||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||||
| std::vector<AlgoBase*> | |||||
| std::vector<AlgoBase*> | |||||
| //! all algorithms | //! all algorithms | ||||
| all_algos, | all_algos, | ||||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | //! non-cudnn algos, used for heuristic if cudnn is not supported | ||||
| non_cudnn_algos; | non_cudnn_algos; | ||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo); | |||||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo); | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -78,30 +78,10 @@ void Convolution3DForwardImpl::AlgoCUDNN::exec( | |||||
| cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
| } | } | ||||
| void Convolution3DForwardImpl::AlgoPack::fill_cudnn_algos() { | void Convolution3DForwardImpl::AlgoPack::fill_cudnn_algos() { | ||||
| #define V1(v) #v | |||||
| #define V(v) V1(v) | |||||
| #define DEF_ALGO(NAME, REPROD) \ | |||||
| cudnn.push_back({ \ | |||||
| REPROD, #NAME \ | |||||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||||
| "." V(CUDNN_PATCHLEVEL), \ | |||||
| NAME}) | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true); | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true); | |||||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||||
| #pragma message "not latest cudnn" | |||||
| #endif | |||||
| #undef DEF_ALGO | |||||
| #undef V | |||||
| #undef V1 | |||||
| for (auto&& algo : CudnnAlgoPack::conv3d_fwd_algos()) { | |||||
| cudnn.push_back(algo.first); | |||||
| } | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -15,126 +16,155 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| class Convolution3DForwardImpl: public Convolution3DForward { | |||||
| public: | |||||
| using Convolution3DForward::Convolution3DForward; | |||||
| void exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_in filter, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) override; | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||||
| const TensorLayout &filter, | |||||
| const TensorLayout &dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const CanonizedFilterMeta& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible); | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class Algo1x1x1; | |||||
| class AlgoInplaceMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { | |||||
| return sm_algo_pack; | |||||
| } | |||||
| private: | |||||
| static AlgoPack sm_algo_pack; | |||||
| class Convolution3DForwardImpl : public Convolution3DForward { | |||||
| public: | |||||
| using Convolution3DForward::Convolution3DForward; | |||||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||||
| const CanonizedFilterMeta& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) { | |||||
| return get_algorithm_heuristic(src, filter, dst, | |||||
| workspace_limit_in_bytes, reproducible) | |||||
| ->info(); | |||||
| } | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class Algo1x1x1; | |||||
| class AlgoInplaceMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const CanonizedFilterMeta& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible); | |||||
| static AlgoPack sm_algo_pack; | |||||
| }; | }; | ||||
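
get_algorithm_info_heuristic() above is a thin wrapper: it runs the pointer-returning heuristic (now protected/private) and hands back only a copyable AlgorithmInfo snapshot. A rough stand-alone illustration of that split, with hypothetical Picker/Info/Algo names rather than the megdnn declarations:

#include <cstddef>
#include <string>

struct Info {
    std::string name;
    bool reproducible;
};

class Picker {
public:
    // Public entry point exposes only a value-type snapshot of the choice.
    Info pick_info(std::size_t limit, bool reproducible) {
        return pick(limit, reproducible)->info();
    }

protected:
    struct Algo {
        Info info() const { return {"stub", true}; }
    };
    // The pointer-returning heuristic stays out of the public interface.
    Algo* pick(std::size_t /*limit*/, bool /*reproducible*/) {
        static Algo chosen;
        return &chosen;  // a real picker would consult its algorithm pack
    }
};
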
| class Convolution3DBackwardDataImpl: public Convolution3DBackwardData { | |||||
| public: | |||||
| using Convolution3DBackwardData::Convolution3DBackwardData; | |||||
| void exec(_megdnn_tensor_in filter, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) override; | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible); | |||||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class AlgoInplaceMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { | |||||
| return sm_algo_pack; | |||||
| } | |||||
| private: | |||||
| static AlgoPack sm_algo_pack; | |||||
| class Convolution3DBackwardDataImpl : public Convolution3DBackwardData { | |||||
| public: | |||||
| using Convolution3DBackwardData::Convolution3DBackwardData; | |||||
| void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const CanonizedFilterMeta& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| bool reproducible) { | |||||
| return get_algorithm_heuristic(filter, diff, grad, | |||||
| workspace_limit_in_bytes, reproducible) | |||||
| ->info(); | |||||
| } | |||||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class AlgoInplaceMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | |||||
| Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible); | |||||
| static AlgoPack sm_algo_pack; | |||||
| }; | }; | ||||
| class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter { | |||||
| public: | |||||
| using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||||
| void exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) override; | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const CanonizedFilterMeta& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible); | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class AlgoInplaceMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { | |||||
| return sm_algo_pack; | |||||
| } | |||||
| private: | |||||
| static AlgoPack sm_algo_pack; | |||||
| class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter { | |||||
| public: | |||||
| using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const CanonizedFilterMeta& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) { | |||||
| return get_algorithm_heuristic(src, diff, grad, | |||||
| workspace_limit_in_bytes, reproducible) | |||||
| ->info(); | |||||
| } | |||||
| const char* get_algorithm_set_name() const override; | |||||
| class AlgoBase; | |||||
| class AlgoCUDNN; | |||||
| class AlgoInplaceMatmul; | |||||
| class AlgoChanwise; | |||||
| class AlgoGroupConvGeneral; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const CanonizedFilterMeta& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible); | |||||
| static AlgoPack sm_algo_pack; | |||||
| }; | }; | ||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -433,6 +433,137 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { | |||||
| desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT)); | desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT)); | ||||
| } | } | ||||
| ////////////////////////// CudnnAlgoPack ////////////////////////// | |||||
| #define V1(v) #v | |||||
| #define V(v) V1(v) | |||||
| #define DEF_NAME(NAME) \ | |||||
| #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | |||||
| #define DEF_ALGO(NAME, PROD) \ | |||||
| { \ | |||||
| NAME, { DEF_NAME(NAME), PROD } \ | |||||
| } | |||||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||||
| #pragma message "not latest cudnn" | |||||
| #endif | |||||
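
The DEF_NAME/DEF_ALGO helpers above use two-level stringification so the cuDNN version macros are expanded before being pasted into the algorithm name. A small, compilable illustration with stand-in version macros (MY_MAJOR/MY_MINOR/MY_PATCH are made up, not the real CUDNN_* values):

#include <cstdio>

#define MY_MAJOR 7
#define MY_MINOR 6
#define MY_PATCH 5
#define STR1(v) #v
#define STR(v) STR1(v)  // extra level forces macro expansion before #
#define MY_NAME(NAME) \
    #NAME "v" STR(MY_MAJOR) "." STR(MY_MINOR) "." STR(MY_PATCH)

int main() {
    // Prints: CUDNN_CONVOLUTION_FWD_ALGO_GEMMv7.6.5
    std::printf("%s\n", MY_NAME(CUDNN_CONVOLUTION_FWD_ALGO_GEMM));
    return 0;
}

So, with a hypothetical cuDNN 7.6.5, DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true) would contribute the pair { CUDNN_CONVOLUTION_FWD_ALGO_GEMM, { "CUDNN_CONVOLUTION_FWD_ALGO_GEMMv7.6.5", true } } to the table.
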
| const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr> | |||||
| CudnnAlgoPack::conv_bwd_data_algos() { | |||||
| static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | |||||
| CudnnAlgoPack::Attr> | |||||
| algos = { | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), | |||||
| #if CUDNN_MAJOR >= 5 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true), | |||||
| #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, | |||||
| true), | |||||
| #endif | |||||
| #endif | |||||
| }; | |||||
| return algos; | |||||
| } | |||||
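
Each CudnnAlgoPack accessor builds its table in a function-local static, so the map is constructed once on first use (thread-safe since C++11) and shared by every caller. A minimal sketch of the same pattern with a stand-in AttrLike type; unlike the accessors above it returns a reference, purely for brevity:

#include <string>
#include <unordered_map>

struct AttrLike {
    std::string name;
    bool is_reproducible;
};

// The table is built exactly once, on the first call; later calls reuse it.
const std::unordered_map<int, AttrLike>& demo_algos() {
    static const std::unordered_map<int, AttrLike> algos = {
            {0, {"algo0", false}},
            {1, {"algo1", true}},
    };
    return algos;
}
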
| const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr> | |||||
| CudnnAlgoPack::conv_bwd_flt_algos() { | |||||
| static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | |||||
| CudnnAlgoPack::Attr> | |||||
| algos = { | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), | |||||
| #if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, | |||||
| true), | |||||
| #if CUDNN_MAJOR >= 6 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true), | |||||
| #endif | |||||
| #endif | |||||
| }; | |||||
| return algos; | |||||
| } | |||||
| const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> | |||||
| CudnnAlgoPack::conv_fwd_algos() { | |||||
| static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | |||||
| CudnnAlgoPack::Attr> | |||||
| algos = { | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, | |||||
| true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), | |||||
| #if CUDNN_MAJOR >= 5 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true), | |||||
| #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true), | |||||
| #endif | |||||
| #endif | |||||
| }; | |||||
| return algos; | |||||
| } | |||||
| const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr> | |||||
| CudnnAlgoPack::conv3d_bwd_data_algos() { | |||||
| static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | |||||
| CudnnAlgoPack::Attr> | |||||
| algos = { | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), | |||||
| }; | |||||
| return algos; | |||||
| } | |||||
| const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr> | |||||
| CudnnAlgoPack::conv3d_bwd_flt_algos() { | |||||
| #pragma message \ | |||||
| "fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc" | |||||
| static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | |||||
| CudnnAlgoPack::Attr> | |||||
| algos = { | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), | |||||
| }; | |||||
| return algos; | |||||
| } | |||||
| const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> | |||||
| CudnnAlgoPack::conv3d_fwd_algos() { | |||||
| static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | |||||
| CudnnAlgoPack::Attr> | |||||
| algos = { | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, | |||||
| true), | |||||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), | |||||
| }; | |||||
| return algos; | |||||
| } | |||||
| #undef DEF_ALGO | |||||
| #undef DEF_NAME | |||||
| #undef V | |||||
| #undef V1 | |||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -10,6 +10,7 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include <unordered_map> | |||||
| #include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
| #include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
| #include "src/cuda/cudnn_with_check.h" | #include "src/cuda/cudnn_with_check.h" | ||||
| @@ -27,7 +28,7 @@ class TensorDesc { | |||||
| public: | public: | ||||
| TensorDesc(); | TensorDesc(); | ||||
| //! default layout is nchw | //! default layout is nchw | ||||
| void set(const TensorLayout& layout, const param::Convolution::Format = | |||||
| void set(const TensorLayout& layout, const param::Convolution::Format = | |||||
| param::Convolution::Format::NCHW); | param::Convolution::Format::NCHW); | ||||
| ~TensorDesc(); | ~TensorDesc(); | ||||
| cudnnTensorDescriptor_t desc; | cudnnTensorDescriptor_t desc; | ||||
| @@ -103,9 +104,52 @@ class Conv3DDesc { | |||||
| cudnnConvolutionDescriptor_t desc; | cudnnConvolutionDescriptor_t desc; | ||||
| }; | }; | ||||
| class CudnnAlgoPack { | |||||
| public: | |||||
| //! algorithm attr | |||||
| struct Attr { | |||||
| std::string name; | |||||
| bool is_reproducible; | |||||
| }; | |||||
| static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr> | |||||
| conv_bwd_data_algos(); | |||||
| static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr> | |||||
| conv_bwd_flt_algos(); | |||||
| static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> | |||||
| conv_fwd_algos(); | |||||
| static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr> | |||||
| conv3d_bwd_data_algos(); | |||||
| static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr> | |||||
| conv3d_bwd_flt_algos(); | |||||
| static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> | |||||
| conv3d_fwd_algos(); | |||||
| }; | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| namespace std { | |||||
| #define DEF_HASH(_type) \ | |||||
| template <> \ | |||||
| struct hash<_type> { \ | |||||
| std::size_t operator()(const _type& algo) const { \ | |||||
| return std::hash<uint32_t>()(static_cast<uint32_t>(algo)); \ | |||||
| } \ | |||||
| } | |||||
| DEF_HASH(cudnnConvolutionBwdDataAlgo_t); | |||||
| DEF_HASH(cudnnConvolutionBwdFilterAlgo_t); | |||||
| DEF_HASH(cudnnConvolutionFwdAlgo_t); | |||||
| #undef DEF_HASH | |||||
| } // namespace std | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
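Editor's note: the DEF_HASH specializations above exist so the cuDNN algorithm enums can key the std::unordered_map<..., CudnnAlgoPack::Attr> tables declared in CudnnAlgoPack. Below is a minimal, self-contained sketch of that idiom using a stand-in enum and attribute struct rather than the real cuDNN/megdnn types; it illustrates the pattern and is not code from this patch.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

// Stand-in for a cuDNN algorithm enum such as cudnnConvolutionFwdAlgo_t.
enum FakeFwdAlgo : uint32_t { FAKE_IMPLICIT_GEMM, FAKE_FFT, FAKE_WINOGRAD };

namespace std {
// Same idea as DEF_HASH: hash the enum through its integral value so the
// enum can key an unordered_map even on standard libraries that predate
// the C++14 std::hash-for-enums guarantee.
template <>
struct hash<FakeFwdAlgo> {
    std::size_t operator()(const FakeFwdAlgo& algo) const {
        return hash<uint32_t>()(static_cast<uint32_t>(algo));
    }
};
}  // namespace std

// Analogue of CudnnAlgoPack::Attr: an algorithm name plus a reproducibility flag.
struct Attr {
    std::string name;
    bool is_reproducible;
};

int main() {
    const std::unordered_map<FakeFwdAlgo, Attr> algos = {
            {FAKE_IMPLICIT_GEMM, {"IMPLICIT_GEMM", true}},
            {FAKE_FFT, {"FFT", true}},
    };
    return algos.at(FAKE_FFT).is_reproducible ? 0 : 1;
}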
| @@ -19,7 +19,12 @@ using OprImpl = DeformableConvBackwardDataImpl; | |||||
| OprImpl::AlgoPack::AlgoPack() { | OprImpl::AlgoPack::AlgoPack() { | ||||
| all_algos.push_back(&algo_matmul); | all_algos.push_back(&algo_matmul); | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardDataImpl) | |||||
| OprImpl::AlgoPack OprImpl::sm_algo_pack; | OprImpl::AlgoPack OprImpl::sm_algo_pack; | ||||
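Editor's note: the constructor change above fills a desc-to-AlgoBase* map from each algorithm's info().desc, and MEGDNN_DEF_GET_ALGO_FROM_DESC(...) presumably stamps out the matching lookup (the macro body lives in src/common/algo_base.h and is not part of this diff). The standalone sketch below, with simplified stand-in types, shows the registry pattern this enables: a flat list for enumeration plus a descriptor-keyed map for O(1) recovery of a concrete algorithm.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Simplified stand-ins for Algorithm::Info::Desc and AlgoBase.
struct Desc {
    uint32_t type;
    std::string param;
    bool operator==(const Desc& rhs) const {
        return type == rhs.type && param == rhs.param;
    }
};
struct DescHash {
    std::size_t operator()(const Desc& d) const {
        return std::hash<uint32_t>()(d.type) ^
               (std::hash<std::string>()(d.param) << 1);
    }
};
struct Algo {
    Desc desc;
    const char* name;
};

// The AlgoPack pattern from this patch: a flat list for enumeration plus a
// descriptor-keyed map, filled once in the constructor, for fast lookup.
class AlgoPack {
    std::unordered_map<Desc, Algo*, DescHash> m_map;

public:
    std::vector<Algo*> all_algos;
    explicit AlgoPack(std::vector<Algo*> algos) : all_algos(std::move(algos)) {
        for (auto* a : all_algos)
            m_map.emplace(a->desc, a);
    }
    Algo* get_algo_from_desc(const Desc& d) const {
        auto it = m_map.find(d);
        return it == m_map.end() ? nullptr : it->second;
    }
};

int main() {
    Algo matmul{{0, ""}, "AlgoMatmul"};
    AlgoPack pack({&matmul});
    assert(pack.get_algo_from_desc({0, ""}) == &matmul);
    return pack.all_algos.size() == 1 ? 0 : 1;
}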
| @@ -13,11 +13,15 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/deformable_conv/opr_impl.h" | #include "src/cuda/deformable_conv/opr_impl.h" | ||||
| #include <unordered_map> | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -26,6 +30,10 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_MATMUL, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs { | struct SizeArgs { | ||||
| DeformableConvBackwardDataImpl* opr; | DeformableConvBackwardDataImpl* opr; | ||||
| @@ -107,17 +115,18 @@ public: | |||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "AlgoMatmul"; } | const char* name() const override { return "AlgoMatmul"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||||
| }; | }; | ||||
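Editor's note: MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) is new in this patch but its definition is not shown here; judging from how AlgoType and Algorithm::type() are used, it most likely overrides type() with the per-backend enum value, which then feeds Info::Desc. The following is only a guess at the shape of the macro, wrapped in a small compilable example with stand-in names.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for MEGDNN_DECL_ALGO_TYPE (the real definition is in
// src/common/algo_base.h): override type() with the enum value.
#define MY_DECL_ALGO_TYPE(_type)                  \
    std::uint32_t type() const override {         \
        return static_cast<std::uint32_t>(AlgoType::_type); \
    }

struct AlgoBase {
    enum class AlgoType : std::uint32_t { CUDA_MATMUL };
    virtual ~AlgoBase() = default;
    virtual std::uint32_t type() const = 0;
};

struct AlgoMatmul final : AlgoBase {
    MY_DECL_ALGO_TYPE(CUDA_MATMUL)
};

int main() {
    AlgoMatmul a;
    std::printf("type = %u\n", static_cast<unsigned>(a.type()));
    return 0;
}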
| class DeformableConvBackwardDataImpl::AlgoPack { | |||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| class DeformableConvBackwardDataImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| AlgoMatmul algo_matmul; | AlgoMatmul algo_matmul; | ||||
| //! all algorithms | //! all algorithms | ||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -20,7 +20,11 @@ using OprImpl = DeformableConvBackwardFilterImpl; | |||||
| OprImpl::AlgoPack::AlgoPack() { | OprImpl::AlgoPack::AlgoPack() { | ||||
| all_algos.push_back(&algo_matmul); | all_algos.push_back(&algo_matmul); | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardFilterImpl) | |||||
| OprImpl::AlgoPack OprImpl::sm_algo_pack; | OprImpl::AlgoPack OprImpl::sm_algo_pack; | ||||
| @@ -13,11 +13,15 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/deformable_conv/opr_impl.h" | #include "src/cuda/deformable_conv/opr_impl.h" | ||||
| #include <unordered_map> | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -26,6 +30,11 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_MATMUL, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs { | struct SizeArgs { | ||||
| DeformableConvBackwardFilterImpl* opr; | DeformableConvBackwardFilterImpl* opr; | ||||
| @@ -97,18 +106,18 @@ public: | |||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "AlgoMatmul"; } | const char* name() const override { return "AlgoMatmul"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||||
| }; | }; | ||||
| class DeformableConvBackwardFilterImpl::AlgoPack { | |||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| class DeformableConvBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| AlgoMatmul algo_matmul; | AlgoMatmul algo_matmul; | ||||
| //! all algorithms | //! all algorithms | ||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -22,8 +22,14 @@ using OprImpl = DeformableConvForwardImpl; | |||||
| OprImpl::AlgoPack::AlgoPack() { | OprImpl::AlgoPack::AlgoPack() { | ||||
| all_algos.push_back(&algo_matmul); | all_algos.push_back(&algo_matmul); | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvForwardImpl) | |||||
| OprImpl::AlgoPack OprImpl::sm_algo_pack; | OprImpl::AlgoPack OprImpl::sm_algo_pack; | ||||
| OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, | OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, | ||||
| @@ -13,9 +13,13 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/cuda/deformable_conv/opr_impl.h" | #include "src/cuda/deformable_conv/opr_impl.h" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| #include <unordered_map> | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -24,6 +28,11 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_MATMUL, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs { | struct SizeArgs { | ||||
| DeformableConvForwardImpl* opr; | DeformableConvForwardImpl* opr; | ||||
| @@ -92,17 +101,17 @@ public: | |||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "AlgoMatmul"; } | const char* name() const override { return "AlgoMatmul"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||||
| }; | }; | ||||
| class DeformableConvForwardImpl::AlgoPack { | |||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| class DeformableConvForwardImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| AlgoMatmul algo_matmul; | AlgoMatmul algo_matmul; | ||||
| //! all algorithms | //! all algorithms | ||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -29,19 +30,6 @@ public: | |||||
| const TensorLayout& mask, | const TensorLayout& mask, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& offset, | |||||
| const TensorLayout& mask, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | Algorithm* get_algorithm_heuristic(const TensorLayout& im, | ||||
| const CanonizedFilterMeta& filter, | const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& offset, | const TensorLayout& offset, | ||||
| @@ -58,31 +46,35 @@ public: | |||||
| class AlgoPack; | class AlgoPack; | ||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& offset, | |||||
| const TensorLayout& mask, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| class DeformableConvBackwardFilterImpl: public DeformableConvBackwardFilter { | |||||
| class DeformableConvBackwardFilterImpl : public DeformableConvBackwardFilter { | |||||
| public: | public: | ||||
| using DeformableConvBackwardFilter::DeformableConvBackwardFilter; | using DeformableConvBackwardFilter::DeformableConvBackwardFilter; | ||||
| void exec(_megdnn_tensor_in im,_megdnn_tensor_in offset, _megdnn_tensor_in mask, | |||||
| _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, | |||||
| void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, | |||||
| _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, | |||||
| _megdnn_tensor_out filter_grad, | |||||
| _megdnn_workspace workspace) override; | _megdnn_workspace workspace) override; | ||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, const TensorLayout& filter_grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||||
| const TensorLayout& offset, | |||||
| const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, | |||||
| const TensorLayout& filter_grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | Algorithm* get_algorithm_heuristic(const TensorLayout& im, | ||||
| const TensorLayout& offset, | const TensorLayout& offset, | ||||
| const TensorLayout& mask, | const TensorLayout& mask, | ||||
| @@ -91,9 +83,11 @@ public: | |||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| bool reproducible); | bool reproducible); | ||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, const TensorLayout& filter_grad) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout& im, | |||||
| const TensorLayout& offset, | |||||
| const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, | |||||
| const TensorLayout& filter_grad) override; | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| @@ -103,6 +97,21 @@ public: | |||||
| class AlgoPack; | class AlgoPack; | ||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& im, const TensorLayout& offset, | |||||
| const TensorLayout& mask, const TensorLayout& out_grad, | |||||
| const TensorLayout& filter_grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||||
| const TensorLayout& offset, | |||||
| const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, | |||||
| const TensorLayout& filter_grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -118,19 +127,6 @@ public: | |||||
| _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, | _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, | ||||
| _megdnn_workspace workspace) override; | _megdnn_workspace workspace) override; | ||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | |||||
| size_t workspace_limit_in_bytes, bool reproducible) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& im, const CanonizedFilterMeta& filter, | const TensorLayout& im, const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
| @@ -138,11 +134,14 @@ public: | |||||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | const TensorLayout& offset_grad, const TensorLayout& mask_grad, | ||||
| size_t workspace_limit_in_bytes, bool reproducible); | size_t workspace_limit_in_bytes, bool reproducible); | ||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout& im, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& offset, | |||||
| const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, | |||||
| const TensorLayout& im_grad, | |||||
| const TensorLayout& offset_grad, | |||||
| const TensorLayout& mask_grad) override; | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| @@ -152,6 +151,22 @@ public: | |||||
| class AlgoPack; | class AlgoPack; | ||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||||
| const TensorLayout& offset_grad, | |||||
| const TensorLayout& mask_grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | |||||
| size_t workspace_limit_in_bytes, bool reproducible) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -18,8 +18,14 @@ using namespace cuda; | |||||
| LocalShareBackwardDataImpl::AlgoPack::AlgoPack() { | LocalShareBackwardDataImpl::AlgoPack::AlgoPack() { | ||||
| all_algos.push_back(&implicit_gemm); | all_algos.push_back(&implicit_gemm); | ||||
| all_algos.push_back(&batched_matmul); | all_algos.push_back(&batched_matmul); | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardDataImpl) | |||||
| LocalShareBackwardDataImpl::AlgoPack LocalShareBackwardDataImpl::sm_algo_pack; | LocalShareBackwardDataImpl::AlgoPack LocalShareBackwardDataImpl::sm_algo_pack; | ||||
| LocalShareBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( | LocalShareBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( | ||||
| @@ -13,10 +13,14 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/local_share/opr_impl.h" | #include "src/cuda/local_share/opr_impl.h" | ||||
| #include <unordered_map> | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -25,6 +29,13 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_IMPLICIT_GEMM, | |||||
| CUDA_BATCHED_MATMUL, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs { | struct SizeArgs { | ||||
| LocalShareBackwardDataImpl* opr; | LocalShareBackwardDataImpl* opr; | ||||
| @@ -77,6 +88,7 @@ public: | |||||
| const char* name() const override { | const char* name() const override { | ||||
| return "LOCAL_SHARE_IMPLICIT_GEMM"; | return "LOCAL_SHARE_IMPLICIT_GEMM"; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) | |||||
| }; | }; | ||||
| class LocalShareBackwardDataImpl::AlgoBatchedMatMul final | class LocalShareBackwardDataImpl::AlgoBatchedMatMul final | ||||
| @@ -93,11 +105,11 @@ public: | |||||
| const char* name() const override { | const char* name() const override { | ||||
| return "LOCAL_SHARE_BATCHED_MATMUL"; | return "LOCAL_SHARE_BATCHED_MATMUL"; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||||
| }; | }; | ||||
| class LocalShareBackwardDataImpl::AlgoPack { | |||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| class LocalShareBackwardDataImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| @@ -106,6 +118,7 @@ public: | |||||
| AlgoBatchedMatMul batched_matmul; | AlgoBatchedMatMul batched_matmul; | ||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -18,8 +18,14 @@ using namespace cuda; | |||||
| LocalShareBackwardFilterImpl::AlgoPack::AlgoPack() { | LocalShareBackwardFilterImpl::AlgoPack::AlgoPack() { | ||||
| all_algos.push_back(&implicit_gemm); | all_algos.push_back(&implicit_gemm); | ||||
| all_algos.push_back(&batched_matmul); | all_algos.push_back(&batched_matmul); | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardFilterImpl) | |||||
| LocalShareBackwardFilterImpl::AlgoPack LocalShareBackwardFilterImpl::sm_algo_pack; | LocalShareBackwardFilterImpl::AlgoPack LocalShareBackwardFilterImpl::sm_algo_pack; | ||||
| LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | ||||
| @@ -13,10 +13,14 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/local_share/opr_impl.h" | #include "src/cuda/local_share/opr_impl.h" | ||||
| #include <unordered_map> | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -25,6 +29,12 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_IMPLICIT_GEMM, | |||||
| CUDA_BATCHED_MATMUL, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs { | struct SizeArgs { | ||||
| LocalShareBackwardFilterImpl* opr; | LocalShareBackwardFilterImpl* opr; | ||||
| @@ -75,6 +85,7 @@ public: | |||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } | const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) | |||||
| }; | }; | ||||
| class LocalShareBackwardFilterImpl::AlgoBatchedMatMul final : public AlgoBase { | class LocalShareBackwardFilterImpl::AlgoBatchedMatMul final : public AlgoBase { | ||||
| @@ -88,11 +99,11 @@ public: | |||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||||
| }; | }; | ||||
| class LocalShareBackwardFilterImpl::AlgoPack { | |||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| class LocalShareBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| @@ -101,6 +112,8 @@ public: | |||||
| AlgoBatchedMatMul batched_matmul; | AlgoBatchedMatMul batched_matmul; | ||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -19,8 +19,14 @@ LocalShareForwardImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(&batch_size_aware_chwn_small_image); | all_algos.push_back(&batch_size_aware_chwn_small_image); | ||||
| all_algos.push_back(&batch_size_aware_chwn); | all_algos.push_back(&batch_size_aware_chwn); | ||||
| all_algos.push_back(&batched_matmul); | all_algos.push_back(&batched_matmul); | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareForwardImpl) | |||||
| LocalShareForwardImpl::AlgoPack LocalShareForwardImpl::sm_algo_pack; | LocalShareForwardImpl::AlgoPack LocalShareForwardImpl::sm_algo_pack; | ||||
| LocalShareForwardImpl::AlgoBase::SizeArgs::SizeArgs(LocalShareForwardImpl* o, | LocalShareForwardImpl::AlgoBase::SizeArgs::SizeArgs(LocalShareForwardImpl* o, | ||||
| @@ -14,9 +14,13 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/local_share/opr_impl.h" | #include "src/cuda/local_share/opr_impl.h" | ||||
| #include <unordered_map> | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -25,6 +29,13 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_CHWN_BATCH_SIZE_AWARE, | |||||
| CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE, | |||||
| CUDA_BATCHED_MATMUL | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs { | struct SizeArgs { | ||||
| LocalShareForwardImpl* opr; | LocalShareForwardImpl* opr; | ||||
| @@ -79,6 +90,7 @@ public: | |||||
| const char* name() const override { | const char* name() const override { | ||||
| return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; | return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE) | |||||
| }; | }; | ||||
| class LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage final | class LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage final | ||||
| @@ -95,6 +107,7 @@ public: | |||||
| const char* name() const override { | const char* name() const override { | ||||
| return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE"; | return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE"; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE) | |||||
| }; | }; | ||||
| class LocalShareForwardImpl::AlgoBatchedMatMul final : public AlgoBase { | class LocalShareForwardImpl::AlgoBatchedMatMul final : public AlgoBase { | ||||
| @@ -108,11 +121,11 @@ public: | |||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||||
| }; | }; | ||||
| class LocalShareForwardImpl::AlgoPack { | |||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| class LocalShareForwardImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| @@ -122,6 +135,7 @@ public: | |||||
| AlgoBatchedMatMul batched_matmul; | AlgoBatchedMatMul batched_matmul; | ||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -23,14 +23,6 @@ public: | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
| const TensorLayout& filter, | const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| class AlgoBase; | class AlgoBase; | ||||
| @@ -41,7 +33,17 @@ public: | |||||
| class AlgoPack; | class AlgoPack; | ||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -54,14 +56,6 @@ public: | |||||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | size_t get_workspace_in_bytes(const TensorLayout& filter, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| class AlgoBase; | class AlgoBase; | ||||
| @@ -71,6 +65,17 @@ public: | |||||
| class AlgoPack; | class AlgoPack; | ||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -84,14 +89,6 @@ public: | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| class AlgoBase; | class AlgoBase; | ||||
| @@ -101,6 +98,17 @@ public: | |||||
| class AlgoPack; | class AlgoPack; | ||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -11,6 +11,7 @@ | |||||
| #include "./algos.h" | #include "./algos.h" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include <cuda.h> | #include <cuda.h> | ||||
| #if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
| @@ -33,10 +34,16 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||||
| cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas); | cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas); | ||||
| all_algos.push_back(cublas_bfloat16.get()); | all_algos.push_back(cublas_bfloat16.get()); | ||||
| #endif | #endif | ||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | ||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl) | |||||
| MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, | MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, | ||||
| const TensorLayout& A, | const TensorLayout& A, | ||||
| const TensorLayout& B, | const TensorLayout& B, | ||||
| @@ -67,4 +74,5 @@ std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
| m, k, k, n, m, n, param.transposeA, param.transposeB, | m, k, k, n, m, n, param.transposeA, param.transposeB, | ||||
| layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | ||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -6,14 +6,18 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/cuda/matrix_mul/opr_impl.h" | #include "src/cuda/matrix_mul/opr_impl.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include <unordered_map> | |||||
| #include <cuda.h> | #include <cuda.h> | ||||
| #include <memory> | #include <memory> | ||||
| #if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
| @@ -32,6 +36,15 @@ protected: | |||||
| ~AlgoBase() = default; | ~AlgoBase() = default; | ||||
| public: | public: | ||||
| enum class AlgoType : uint32_t { | |||||
| CUDA_CUBLAS, | |||||
| CUDA_WMMA_UINT4X4X32, | |||||
| CUDA_CUBLASLT, | |||||
| CUDA_NAIVE, | |||||
| CUDA_BFLOAT16 | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
| struct SizeArgs { | struct SizeArgs { | ||||
| MatrixMulForwardImpl* opr; | MatrixMulForwardImpl* opr; | ||||
| @@ -62,12 +75,12 @@ public: | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | ||||
| virtual void exec(const ExecArgs& args) const = 0; | virtual void exec(const ExecArgs& args) const = 0; | ||||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
| } | } | ||||
| bool is_available_reproducible( | bool is_available_reproducible( | ||||
| const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||||
| size_t limit = std::numeric_limits<size_t>::max()) const { | |||||
| return (!reproducible || is_reproducible()) && | return (!reproducible || is_reproducible()) && | ||||
| is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
| } | } | ||||
| @@ -80,8 +93,6 @@ public: | |||||
| name(), req, workspace.size); | name(), req, workspace.size); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| }; | }; | ||||
| class MatrixMulForwardImpl::AlgoCuBlas final : public AlgoBase { | class MatrixMulForwardImpl::AlgoCuBlas final : public AlgoBase { | ||||
| @@ -91,13 +102,10 @@ public: | |||||
| size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { | size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { | ||||
| return 0_z; | return 0_z; | ||||
| } | } | ||||
| const char* name() const override { | |||||
| return "CUBLAS"; | |||||
| } | |||||
| const char* name() const override { return "CUBLAS"; } | |||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | |||||
| }; | }; | ||||
| #if CUDA_VERSION >= 10000 | #if CUDA_VERSION >= 10000 | ||||
| @@ -106,13 +114,10 @@ public: | |||||
| AlgoUInt4x4x32WMMA() = default; | AlgoUInt4x4x32WMMA() = default; | ||||
| bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
| const char* name() const override { | |||||
| return "UINT4x4x32_WMMA"; | |||||
| } | |||||
| const char* name() const override { return "UINT4x4x32_WMMA"; } | |||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| #if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
| @@ -120,13 +125,10 @@ class MatrixMulForwardImpl::AlgoCuBlasLt final : public AlgoBase { | |||||
| public: | public: | ||||
| bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
| const char* name() const override { | |||||
| return "CUBLAS_LT"; | |||||
| } | |||||
| const char* name() const override { return "CUBLAS_LT"; } | |||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| bool is_reproducible() const override { | |||||
| return true; | |||||
| } | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| @@ -140,6 +142,7 @@ public: | |||||
| const char* name() const override { return "NAIVE"; } | const char* name() const override { return "NAIVE"; } | ||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) | |||||
| }; | }; | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| @@ -151,6 +154,13 @@ public: | |||||
| const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
| void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_algorithm, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; | MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; | ||||
| @@ -160,9 +170,9 @@ private: | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| class MatrixMulForwardImpl::AlgoPack { | |||||
| AlgoPack(const AlgoPack&) = delete; | |||||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||||
| class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack(); | AlgoPack(); | ||||
| @@ -178,6 +188,8 @@ public: | |||||
| std::unique_ptr<AlgoBFloat16> cublas_bfloat16; | std::unique_ptr<AlgoBFloat16> cublas_bfloat16; | ||||
| #endif | #endif | ||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -82,7 +82,7 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { | |||||
| args.opr->handle()->create_operator<MatrixMulForward>(); | args.opr->handle()->create_operator<MatrixMulForward>(); | ||||
| matmul_opr->param() = args.opr->param(); | matmul_opr->param() = args.opr->param(); | ||||
| matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | ||||
| matmul_opr->execution_policy() = {m_algorithm}; | |||||
| matmul_opr->execution_policy() = {m_algorithm->info()}; | |||||
| matmul_opr->exec(a, b, c, ctypecvt.workspace()); | matmul_opr->exec(a, b, c, ctypecvt.workspace()); | ||||
| } | } | ||||
| ctypecvt.comp_to_dst_type(c, args.tensor_c); | ctypecvt.comp_to_dst_type(c, args.tensor_c); | ||||
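Editor's note on the change from {m_algorithm} to {m_algorithm->info()}: the sub-operator's execution policy now carries a copyable, comparable AlgorithmInfo value rather than a raw Algorithm*. A toy illustration of that shape follows, with made-up stand-in types (not the real megdnn ExecutionPolicy).

#include <cassert>
#include <string>

// The policy stores a value-type Info instead of a raw algorithm pointer.
struct FakeInfo {
    std::string name;
    bool operator==(const FakeInfo& rhs) const { return name == rhs.name; }
};
struct FakeAlgorithm {
    FakeInfo info() const { return {"CUBLAS"}; }
};
struct FakeExecutionPolicy {
    FakeInfo algo;  // previously this would have been FakeAlgorithm*
};

int main() {
    FakeAlgorithm cublas;
    FakeExecutionPolicy policy{cublas.info()};  // mirrors {m_algorithm->info()}
    assert(policy.algo == FakeInfo{"CUBLAS"});
    return 0;
}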
| @@ -25,15 +25,6 @@ public: | |||||
| bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| const char* get_algorithm_set_name() const override { | const char* get_algorithm_set_name() const override { | ||||
| return "CUDA MATMUL"; | return "CUDA MATMUL"; | ||||
| } | } | ||||
| @@ -55,6 +46,17 @@ public: | |||||
| static const AlgoPack& algo_pack() { | static const AlgoPack& algo_pack() { | ||||
| return sm_algo_pack; | return sm_algo_pack; | ||||
| } | } | ||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| protected: | |||||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, | |||||
| bool reproducible) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -10,10 +10,14 @@ | |||||
| */ | */ | ||||
| #include "src/fallback/conv_bias/algos.h" | #include "src/fallback/conv_bias/algos.h" | ||||
| #include "src/fallback/conv_bias/conv1x1/algos.h" | |||||
| #include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" | |||||
| #include "src/fallback/conv_bias/im2col/algos.h" | |||||
| #include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| #include "src/fallback/conv_bias/winograd/strategy.h" | #include "src/fallback/conv_bias/winograd/strategy.h" | ||||
| #include "src/naive/convolution/helper.h" | #include "src/naive/convolution/helper.h" | ||||
| #include "src/common/algo_base.h" | |||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -176,6 +180,7 @@ void kern_default(const ConvBiasImpl::NCBKernParam& p) { | |||||
| } // namespace | } // namespace | ||||
| MIDOUT_DECL(megdnn_fallback_naive) | MIDOUT_DECL(megdnn_fallback_naive) | ||||
| /* ======================= AlgoNaive ======================== */ | /* ======================= AlgoNaive ======================== */ | ||||
| bool ConvBiasImpl::AlgoNaive::usable( | bool ConvBiasImpl::AlgoNaive::usable( | ||||
| @@ -36,6 +36,7 @@ public: | |||||
| static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | ||||
| return {support_data_type, AlgoCategory::NAIVE}; | return {support_data_type, AlgoCategory::NAIVE}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) | |||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { | class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { | ||||
| @@ -59,6 +60,12 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_F32) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_matmul_algo, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
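Editor's note: the param() overrides added throughout this file append the wrapped matmul algorithm pointer (and, for some algorithms, a tile or block size) with serialize_write_pod, so two wrapper algorithms built over different inner matmul algos end up with distinct Info::Desc values. Below is a free-standing sketch of that round trip using local helpers with simplified names, not the real megdnn ones.

#include <cassert>
#include <cstddef>
#include <cstring>
#include <string>

// Append the raw bytes of a POD value to a string.
template <typename T>
void write_pod(const T& val, std::string& out) {
    out.append(reinterpret_cast<const char*>(&val), sizeof(T));
}

// Read a POD value back from the serialized bytes at the given offset.
template <typename T>
T read_pod(const std::string& data, std::size_t offset = 0) {
    T ret;
    std::memcpy(&ret, data.data() + offset, sizeof(T));
    return ret;
}

int main() {
    // A wrapper algorithm's param(): the inner matmul algo pointer plus a
    // tile size, so wrappers over different matmul algos (or different tile
    // sizes) produce different serialized params, hence distinct descs.
    int dummy_algo = 0;
    const void* matmul_algo = &dummy_algo;
    std::size_t tile_size = 68;

    std::string param;
    write_pod(matmul_algo, param);
    write_pod(tile_size, param);

    assert(read_pod<const void*>(param) == matmul_algo);
    assert(read_pod<std::size_t>(param, sizeof(matmul_algo)) == tile_size);
    return 0;
}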
| @@ -87,6 +94,12 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_4X4_F32) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_matmul_algo, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
| @@ -115,6 +128,12 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_QS8) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_matmul_algo, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
| @@ -143,6 +162,12 @@ public: | |||||
| ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_matmul_algo, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
| @@ -155,6 +155,12 @@ using BiasMode = ConvBiasForward::BiasMode; | |||||
| const NCBKernSizeParam& param) const override; \ | const NCBKernSizeParam& param) const override; \ | ||||
| ConvAlgoTypePack get_algo_type() const override { \ | ConvAlgoTypePack get_algo_type() const override { \ | ||||
| return {_algo_data_type, AlgoCategory::WINOGRAD}; \ | return {_algo_data_type, AlgoCategory::WINOGRAD}; \ | ||||
| } \ | |||||
| std::string param() const override { \ | |||||
| std::string ret; \ | |||||
| serialize_write_pod(m_matmul_algo, ret); \ | |||||
| serialize_write_pod(m_tile_size, ret); \ | |||||
| return ret; \ | |||||
| } \ | } \ | ||||
| \ | \ | ||||
| private: \ | private: \ | ||||
| @@ -60,6 +60,13 @@ public: | |||||
| return {m_matmul_algo->matmul_description().algo_type.data_type, | return {m_matmul_algo->matmul_description().algo_type.data_type, | ||||
| AlgoCategory::IM2COL}; | AlgoCategory::IM2COL}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(FB_CONV1x1) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_matmul_algo, ret); | |||||
| serialize_write_pod(m_oc_block_size, ret); | |||||
| return ret; | |||||
| } | |||||
| protected: | protected: | ||||
| size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | ||||
| @@ -43,6 +43,7 @@ public: | |||||
| static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | ||||
| return {support_data_type, AlgoCategory::IM2COL}; | return {support_data_type, AlgoCategory::IM2COL}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(FB_CONV1x1_GEMV) | |||||
| protected: | protected: | ||||
| size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | ||||
| @@ -68,6 +68,14 @@ public: | |||||
| return {m_matmul_algo->matmul_description().algo_type.data_type, | return {m_matmul_algo->matmul_description().algo_type.data_type, | ||||
| AlgoCategory::IM2COL}; | AlgoCategory::IM2COL}; | ||||
| } | } | ||||
| MEGDNN_DECL_ALGO_TYPE(FB_IM2COL) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_matmul_algo, ret); | |||||
| serialize_write_pod(m_ohw_tile_size, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | private: | ||||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
| @@ -22,6 +22,14 @@ | |||||
| #include "src/naive/convolution/algorithms.h" | #include "src/naive/convolution/algorithms.h" | ||||
| #include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
| #if MEGDNN_X86 | |||||
| #include "src/x86/conv_bias/opr_impl.h" | |||||
| #elif MEGDNN_AARCH64 | |||||
| #include "src/aarch64/conv_bias/opr_impl.h" | |||||
| #elif MEGDNN_ARMV7 | |||||
| #include "src/armv7/conv_bias/opr_impl.h" | |||||
| #endif | |||||
| #include <cstring> | #include <cstring> | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -65,17 +73,19 @@ void incr_ptr(T*& dst, ptrdiff_t delta) { | |||||
| class ConvBiasImpl::AlgoPack : NonCopyableObj { | class ConvBiasImpl::AlgoPack : NonCopyableObj { | ||||
| AlgoNaive algo_naive; | AlgoNaive algo_naive; | ||||
| SmallVector<std::unique_ptr<AlgoBase>> refhold; | SmallVector<std::unique_ptr<AlgoBase>> refhold; | ||||
| SmallVector<AlgoBase*> m_all_algos; | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | public: | ||||
| AlgoPack() { | AlgoPack() { | ||||
| refhold.emplace_back(new AlgoConv1x1Gemv()); | refhold.emplace_back(new AlgoConv1x1Gemv()); | ||||
| all_algos.emplace_back(refhold.back().get()); | |||||
| m_all_algos.emplace_back(refhold.back().get()); | |||||
| static CpuOprDelegationStorage<> storage; | static CpuOprDelegationStorage<> storage; | ||||
| auto matmul_opr = storage.get<MatrixMul>(); | auto matmul_opr = storage.get<MatrixMul>(); | ||||
| auto&& matmul_algos = | |||||
| static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); | |||||
| auto&& matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||||
| ->get_all_packed_algo(); | |||||
| for (auto&& algo : matmul_algos) { | for (auto&& algo : matmul_algos) { | ||||
| #if MEGDNN_X86 | #if MEGDNN_X86 | ||||
| //! As we don't have a direct conv for int8x8x16 yet, if we disable gemv here, it may | //! As we don't have a direct conv for int8x8x16 yet, if we disable gemv here, it may | ||||
| @@ -97,13 +107,13 @@ public: | |||||
| refhold.emplace_back(new AlgoIm2col( | refhold.emplace_back(new AlgoIm2col( | ||||
| static_cast<MatrixMulImpl::AlgoBase*>(algo), | static_cast<MatrixMulImpl::AlgoBase*>(algo), | ||||
| ohw_tile_size)); | ohw_tile_size)); | ||||
| all_algos.emplace_back(refhold.back().get()); | |||||
| m_all_algos.emplace_back(refhold.back().get()); | |||||
| } | } | ||||
| for (size_t oc_tile_size : {48, 24}) { | for (size_t oc_tile_size : {48, 24}) { | ||||
| refhold.emplace_back(new AlgoConv1x1( | refhold.emplace_back(new AlgoConv1x1( | ||||
| static_cast<MatrixMulImpl::AlgoBase*>(algo), | static_cast<MatrixMulImpl::AlgoBase*>(algo), | ||||
| oc_tile_size)); | oc_tile_size)); | ||||
| all_algos.emplace_back(refhold.back().get()); | |||||
| m_all_algos.emplace_back(refhold.back().get()); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -113,26 +123,35 @@ public: | |||||
| //! FIXME: I do not know a better way to do it. | //! FIXME: I do not know a better way to do it. | ||||
| refhold.emplace_back(new AlgoWinogradF32( | refhold.emplace_back(new AlgoWinogradF32( | ||||
| static_cast<MatrixMulImpl::AlgoBase*>(algo))); | static_cast<MatrixMulImpl::AlgoBase*>(algo))); | ||||
| all_algos.emplace_back(refhold.back().get()); | |||||
| m_all_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoWinogradF32_4x4( | refhold.emplace_back(new AlgoWinogradF32_4x4( | ||||
| static_cast<MatrixMulImpl::AlgoBase*>(algo))); | static_cast<MatrixMulImpl::AlgoBase*>(algo))); | ||||
| all_algos.emplace_back(refhold.back().get()); | |||||
| m_all_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoWinogradQS8( | refhold.emplace_back(new AlgoWinogradQS8( | ||||
| static_cast<MatrixMulImpl::AlgoBase*>(algo))); | static_cast<MatrixMulImpl::AlgoBase*>(algo))); | ||||
| all_algos.emplace_back(refhold.back().get()); | |||||
| m_all_algos.emplace_back(refhold.back().get()); | |||||
| refhold.emplace_back(new AlgoWinogradQS8_8x8( | refhold.emplace_back(new AlgoWinogradQS8_8x8( | ||||
| static_cast<MatrixMulImpl::AlgoBase*>(algo))); | static_cast<MatrixMulImpl::AlgoBase*>(algo))); | ||||
| all_algos.emplace_back(refhold.back().get()); | |||||
| m_all_algos.emplace_back(refhold.back().get()); | |||||
| #endif | #endif | ||||
| } | } | ||||
| all_algos.emplace_back(&algo_naive); | |||||
| m_all_algos.emplace_back(&algo_naive); | |||||
| for (auto&& algo : m_all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | } | ||||
| SmallVector<AlgoBase*> all_algos; | |||||
| const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; } | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | }; | ||||
| SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
| static AlgoPack sl_algo_pack; | |||||
| return sl_algo_pack.all_algos; | |||||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||||
| static AlgoPack algo_pack; | |||||
| return algo_pack; | |||||
| } | |||||
| SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() { | |||||
| return algo_pack().all_algos(); | |||||
| } | } | ||||
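
Editor's note: the AlgoPack above follows a common ownership/lookup pattern: dynamically created algorithms are owned by refhold, exposed in registration order through m_all_algos, and additionally indexed by their descriptor in m_all_algos_map so a serialized choice can be resolved back to a live object later. Below is a minimal, self-contained sketch of that pattern; it is not MegDNN code, and every name in it (Desc, MyAlgo, AlgoRegistry, registry) is invented for illustration.

#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

// Serializable identity of an algorithm: enough to find it again later.
struct Desc {
    uint32_t type;
    std::string param;
    bool operator<(const Desc& rhs) const {
        return std::tie(type, param) < std::tie(rhs.type, rhs.param);
    }
};

class MyAlgo {
public:
    MyAlgo(uint32_t type, std::string param)
            : m_type(type), m_param(std::move(param)) {}
    Desc desc() const { return {m_type, m_param}; }
    std::string name() const { return "algo_" + m_param; }

private:
    uint32_t m_type;
    std::string m_param;
};

class AlgoRegistry {
    std::vector<std::unique_ptr<MyAlgo>> m_refhold;  // owns every algorithm
    std::vector<MyAlgo*> m_all;                      // registration order
    std::map<Desc, MyAlgo*> m_map;                   // descriptor -> live object

public:
    AlgoRegistry() {
        for (int tile : {48, 24}) {  // e.g. one variant per tile size
            m_refhold.emplace_back(new MyAlgo(1, std::to_string(tile)));
            m_all.push_back(m_refhold.back().get());
        }
        for (MyAlgo* a : m_all)
            m_map.emplace(a->desc(), a);
    }
    const std::vector<MyAlgo*>& all() const { return m_all; }
    MyAlgo* from_desc(const Desc& d) const {
        auto it = m_map.find(d);
        return it == m_map.end() ? nullptr : it->second;
    }
};

// Function-local static: built once, then shared for the process lifetime.
const AlgoRegistry& registry() {
    static AlgoRegistry instance;
    return instance;
}

int main() {
    Desc saved = registry().all().front()->desc();   // persist this descriptor
    if (MyAlgo* algo = registry().from_desc(saved))  // ...and resolve it later
        std::cout << algo->name() << '\n';
}

Keeping the map alongside the flat vector costs one extra container, but it turns descriptor-based resolution into a direct lookup instead of a linear scan over all algorithms.
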
| SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type( | SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type( | ||||
| @@ -140,7 +159,7 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type( | |||||
| megdnn_assert(nr_type_contain(target_type.data_type), | megdnn_assert(nr_type_contain(target_type.data_type), | ||||
| "ConvBias algo selection only support one type"); | "ConvBias algo selection only support one type"); | ||||
| SmallVector<ConvBiasImpl::AlgoBase*> algos; | SmallVector<ConvBiasImpl::AlgoBase*> algos; | ||||
| for (auto&& algo : algo_pack()) { | |||||
| for (auto&& algo : get_all_packed_algo()) { | |||||
| auto algo_type = algo->get_algo_type(); | auto algo_type = algo->get_algo_type(); | ||||
| if (contain_data_type(algo_type.data_type, target_type.data_type) && | if (contain_data_type(algo_type.data_type, target_type.data_type) && | ||||
| algo_type.algo_category == target_type.algo_category) { | algo_type.algo_category == target_type.algo_category) { | ||||
| @@ -166,7 +185,7 @@ void ConvBiasImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
| workspace.size, preprocessed_filter); | workspace.size, preprocessed_filter); | ||||
| auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, | auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, | ||||
| preprocessed_filter); | preprocessed_filter); | ||||
| ConvBiasImpl::Algorithm* algo = get_algorithm(fparam, workspace.size); | |||||
| auto&& algo = get_algorithm(fparam, workspace.size); | |||||
| if (!is_naive_algo(algo) && | if (!is_naive_algo(algo) && | ||||
| NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) { | NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) { | ||||
| exec_with_ncb_kern(fparam, algo); | exec_with_ncb_kern(fparam, algo); | ||||
| @@ -189,9 +208,10 @@ void ConvBiasImpl::exec_preprocess(const TensorLayout& src_layout, | |||||
| auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, | auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, | ||||
| preprocessed_filter); | preprocessed_filter); | ||||
| //! should not pass the workspace_size limit, otherwise a matching algo cannot be found | //! should not pass the workspace_size limit, otherwise a matching algo cannot be found | ||||
| ConvBiasImpl::Algorithm* algo = get_algorithm(fparam); | |||||
| if (!is_naive_algo(algo) && NCB_ALGO_FUNC(get_preprocess_workspace, algo, | |||||
| fparam) <= workspace.size) { | |||||
| auto&& algo = get_algorithm(fparam); | |||||
| if (!is_naive_algo(algo) && | |||||
| NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= | |||||
| workspace.size) { | |||||
| exec_preprocess_with_ncb_kern(fparam, algo); | exec_preprocess_with_ncb_kern(fparam, algo); | ||||
| } else { | } else { | ||||
| naive::ConvBiasForwardImpl::exec_preprocess( | naive::ConvBiasForwardImpl::exec_preprocess( | ||||
| @@ -207,7 +227,7 @@ size_t ConvBiasImpl::get_workspace_in_bytes( | |||||
| const PreprocessedFilter* preprocessed_filter) { | const PreprocessedFilter* preprocessed_filter) { | ||||
| auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, | auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, | ||||
| preprocessed_filter); | preprocessed_filter); | ||||
| ConvBiasImpl::Algorithm* algo = get_algorithm(fparam); | |||||
| auto&& algo = get_algorithm(fparam); | |||||
| if (is_naive_algo(algo)) { | if (is_naive_algo(algo)) { | ||||
| return naive::ConvBiasForwardImpl::get_workspace_in_bytes( | return naive::ConvBiasForwardImpl::get_workspace_in_bytes( | ||||
| src, filter, bias, z, dst, preprocessed_filter); | src, filter, bias, z, dst, preprocessed_filter); | ||||
| @@ -221,7 +241,7 @@ size_t ConvBiasImpl::get_preprocess_workspace_in_bytes( | |||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
| auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | ||||
| Algorithm* algo = get_algorithm(fparam); | |||||
| auto&& algo = get_algorithm(fparam); | |||||
| if (is_naive_algo(algo)) { | if (is_naive_algo(algo)) { | ||||
| return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes( | return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes( | ||||
| src, filter, bias, z, dst); | src, filter, bias, z, dst); | ||||
| @@ -235,7 +255,7 @@ SmallVector<TensorLayout> ConvBiasImpl::deduce_preprocessed_filter_layout( | |||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
| auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | ||||
| Algorithm* algo = get_algorithm(fparam); | |||||
| auto&& algo = get_algorithm(fparam); | |||||
| if (is_naive_algo(algo)) { | if (is_naive_algo(algo)) { | ||||
| return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout( | return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout( | ||||
| src, filter, bias, z, dst); | src, filter, bias, z, dst); | ||||
| @@ -443,7 +463,7 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||||
| MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
| std::vector<Algorithm*> algos; | std::vector<Algorithm*> algos; | ||||
| std::vector<Algorithm*> prefer_algos; | std::vector<Algorithm*> prefer_algos; | ||||
| for (auto&& algo : algo_pack()) { | |||||
| for (auto&& algo : get_all_packed_algo()) { | |||||
| if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) { | if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) { | ||||
| if (algo->is_preferred(param)) { | if (algo->is_preferred(param)) { | ||||
| prefer_algos.push_back(algo); | prefer_algos.push_back(algo); | ||||
| @@ -457,10 +477,49 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||||
| return algos; | return algos; | ||||
| } | } | ||||
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( | |||||
| const AlgorithmDesc& desc) const { | |||||
| if (!desc.valid()) { | |||||
| return nullptr; | |||||
| } else { | |||||
| switch (desc.handle_type) { | |||||
| case Handle::HandleType::FALLBACK: { | |||||
| const auto& map = algo_pack().all_algos_map(); | |||||
| megdnn_assert(map.find(desc) != map.end()); | |||||
| return map.at(desc); | |||||
| } | |||||
| #if MEGDNN_X86 | |||||
| case Handle::HandleType::X86: | |||||
| return x86::ConvBiasImpl::get_algo_from_desc(desc); | |||||
| #elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| case Handle::HandleType::ARM_COMMON: | |||||
| return arm_common::ConvBiasImpl::get_algo_from_desc(desc); | |||||
| #if MEGDNN_AARCH64 | |||||
| case Handle::HandleType::AARCH64: | |||||
| return aarch64::ConvBiasImpl::get_algo_from_desc(desc); | |||||
| #else | |||||
| case Handle::HandleType::ARMV7: | |||||
| return armv7::ConvBiasImpl::get_algo_from_desc(desc); | |||||
| #endif | |||||
| #endif | |||||
| case Handle::HandleType::NAIVE: { | |||||
| auto algo = static_cast<naive::HandleImpl*>(handle()) | |||||
| ->default_conv_bias_fwd_algo(); | |||||
| megdnn_assert(algo->info().desc == desc); | |||||
| return algo; | |||||
| } | |||||
| default: | |||||
| megdnn_throw("Unknown handle type"); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| } | |||||
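
Editor's note: the shape of get_algo_from_desc above is: branch on the backend recorded in the descriptor, let each backend resolve its own algorithms, and keep the arch-specific branches behind the same preprocessor guards that gate their compilation. A stripped-down sketch of that dispatch follows; the type and function names are illustrative, not MegDNN's, and the guards use standard compiler-defined architecture macros.

#include <cstdint>
#include <cstdio>

enum class Backend : uint32_t { GENERIC, X86, ARM };

struct ExampleDesc {
    Backend backend;
    uint32_t type;
};

const char* resolve(const ExampleDesc& desc) {
    switch (desc.backend) {
        case Backend::GENERIC:
            return "looked up in the generic (fallback) map";
#if defined(__x86_64__) || defined(_M_X64)
        case Backend::X86:
            return "forwarded to the x86 implementation";
#elif defined(__aarch64__) || defined(__arm__)
        case Backend::ARM:
            return "forwarded to the ARM implementation";
#endif
        default:
            return nullptr;  // unknown backend, or not compiled into this build
    }
}

int main() {
    ExampleDesc d{Backend::GENERIC, 42};
    if (const char* where = resolve(d))
        std::printf("%s\n", where);
}
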
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | ||||
| const NCBKernSizeParam& param, size_t workspace_size) { | const NCBKernSizeParam& param, size_t workspace_size) { | ||||
| if (auto set = execution_policy().algorithm) { | |||||
| return set; | |||||
| if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||||
| return algo; | |||||
| } | } | ||||
| if (!m_prev_selected_algo || | if (!m_prev_selected_algo || | ||||
| memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | ||||
| @@ -216,6 +216,86 @@ public: | |||||
| AlgoBase() : Algorithm() { | AlgoBase() : Algorithm() { | ||||
| m_handle_type = Handle::HandleType::FALLBACK; | m_handle_type = Handle::HandleType::FALLBACK; | ||||
| } | } | ||||
| enum class AlgoType : uint32_t { | |||||
| //! fallback | |||||
| FB_NAIVE = 1 << 0, | |||||
| FB_WINOGRAD_F32, | |||||
| FB_WINOGRAD_4X4_F32, | |||||
| FB_WINOGRAD_QS8, | |||||
| FB_WINOGRAD_8X8_QS8, | |||||
| FB_CONV1x1, | |||||
| FB_CONV1x1_GEMV, | |||||
| FB_IM2COL, | |||||
| #if MEGDNN_X86 | |||||
| X86_DIRECT = 1 << 8, | |||||
| X86_DIRECT_STRD2, | |||||
| X86_WINOGRAD_F63_8x8_F32, | |||||
| X86_WINOGRAD_F23_8x8_F32, | |||||
| X86_MKLDNN, | |||||
| X86_CHANWISE_AVX2_STRD1_QINT8, | |||||
| X86_CHANWISE_AVX2_STRD2_QINT8, | |||||
| X86_DIRECT_AVX2_STRD1_INT8, | |||||
| X86_DIRECT_AVX2_STRD2_INT8, | |||||
| X86_MKLDNN_QINT8, | |||||
| X86_MKLDNN_MATMUL_QINT8, | |||||
| #elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| ARM_COMMON_WINOGRAD_F23_FP16 = 1 << 8, | |||||
| ARM_COMMON_WINOGRAD_F45_FP16, | |||||
| ARM_COMMON_WINOGRAD_F63_FP16, | |||||
| ARM_COMMON_WINOGRAD_F23_8X8_FP16, | |||||
| ARM_COMMON_DIRECT_FP16, | |||||
| ARM_COMMON_DIRECT_STRD1_FP16, | |||||
| ARM_COMMON_WINOGRAD_F23_4X4_FP32, | |||||
| ARM_COMMON_WINOGRAD_F63_FP32, | |||||
| ARM_COMMON_WINOGRAD_F63_4X4_FP32, | |||||
| ARM_COMMON_WINOGRAD_F54_FP32, | |||||
| ARM_COMMON_WINOGRAD_F45_FP32, | |||||
| ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, | |||||
| ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, | |||||
| ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, | |||||
| ARM_COMMON_DIRECT_FP32, | |||||
| ARM_COMMON_DIRECT_STRD1_FP32, | |||||
| ARM_COMMON_DIRECT_STRD2_FP32, | |||||
| ARM_COMMON_DIRECT_NCHW44_FP32, | |||||
| ARM_COMMON_DIRECT_NCHW_NCHW44_FP32, | |||||
| ARM_COMMON_CHWNWISE_NCHW44_F32, | |||||
| ARM_COMMON_DIRECT_STRD1_S8, | |||||
| ARM_COMMON_DIRECT_STRD2_S8, | |||||
| ARM_COMMON_DIRECT_NCHW44, | |||||
| ARM_COMMON_DIRECT_NCHW_NCHW44_S8, | |||||
| ARM_COMMON_CHANWISE_STRD1_NCHW44_S8, | |||||
| ARM_COMMON_CHANWISE_STRD2_NCHW44_S8, | |||||
| ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8, | |||||
| ARM_COMMON_DIRECT_STRD1_DOT_S8, | |||||
| ARM_COMMON_DIRECT_STRD2_DOT_S8, | |||||
| ARM_COMMON_DIRECT_NCHW44_DOT_S8, | |||||
| ARM_COMMON_WINOGRAD_F23_8X8_S8, | |||||
| ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32, | |||||
| ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8, | |||||
| ARM_COMMON_DIRECT_INT8X8X16, | |||||
| ARM_COMMON_DIRECT_NCHW44_INT8X8X16, | |||||
| ARM_COMMON_DIRECT_STRD2_INT8X8X16, | |||||
| ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16, | |||||
| ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16, | |||||
| ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16, | |||||
| ARM_COMMON_DIRECT_STRD1_QU8, | |||||
| ARM_COMMON_DIRECT_STRD2_QU8, | |||||
| ARM_COMMON_DIRECT_STRD1_DOT_QU8, | |||||
| ARM_COMMON_DIRECT_STRD2_DOT_QU8, | |||||
| #if MEGDNN_AARCH64 | |||||
| AARCH64_DIRECT_STRD2_FP16, | |||||
| AARCH64_DIRECT_STRD2_FP32, | |||||
| AARCH64_MATMUL_S8, | |||||
| AARCH64_MATMUL_QU8, | |||||
| #else | |||||
| ARMV7_MATMUL_S8, | |||||
| ARMV7_MATMUL_QU8, | |||||
| #endif // MEGDNN_AARCH64 | |||||
| #endif | |||||
| }; | |||||
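
Editor's note on the numbering scheme in this enum: the fallback entries grow upward from 1 << 0, while each backend-specific group starts at 1 << 8, so the two ranges cannot collide as long as the fallback list stays below 256 entries. The tiny example below, with invented names, shows the same layout and the property it guarantees.

#include <cassert>
#include <cstdint>

enum class ExampleAlgoType : uint32_t {
    GENERIC_FIRST = 1u << 0,  // generic algos take 1, 2, 3, ...
    GENERIC_SECOND,
    BACKEND_FIRST = 1u << 8,  // backend-specific algos start at 256
    BACKEND_SECOND,
};

int main() {
    // The groups stay disjoint as long as the generic group has < 256 entries.
    assert(static_cast<uint32_t>(ExampleAlgoType::GENERIC_SECOND) <
           static_cast<uint32_t>(ExampleAlgoType::BACKEND_FIRST));
    return 0;
}
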
| virtual ~AlgoBase() = default; | virtual ~AlgoBase() = default; | ||||
| virtual bool usable( | virtual bool usable( | ||||
| const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
| @@ -255,12 +335,14 @@ public: | |||||
| //! get the type of the algo | //! get the type of the algo | ||||
| virtual ConvAlgoTypePack get_algo_type() const = 0; | virtual ConvAlgoTypePack get_algo_type() const = 0; | ||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| }; | }; | ||||
| using AlgoMapper = AlgoBase::Mapper; | |||||
| /** | /** | ||||
| * \brief get all the algorithms for the opr. | * \brief get all the algorithms for the opr. | ||||
| */ | */ | ||||
| virtual SmallVector<AlgoBase*> algo_pack(); | |||||
| virtual SmallVector<AlgoBase*> get_all_packed_algo(); | |||||
| /** | /** | ||||
| * \brief select algo according to input algo type | * \brief select algo according to input algo type | ||||
| @@ -305,6 +387,8 @@ private: | |||||
| bool is_naive_algo(ConvBiasImpl::Algorithm* algo); | bool is_naive_algo(ConvBiasImpl::Algorithm* algo); | ||||
| Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||||
| //! get algorithm set by user or by heuristic | //! get algorithm set by user or by heuristic | ||||
| Algorithm* get_algorithm( | Algorithm* get_algorithm( | ||||
| const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
| @@ -320,6 +404,8 @@ private: | |||||
| _megdnn_tensor_in bias, _megdnn_tensor_out dst, | _megdnn_tensor_in bias, _megdnn_tensor_out dst, | ||||
| _megdnn_workspace workspace, | _megdnn_workspace workspace, | ||||
| const PreprocessedFilter* preprocessed_filter); | const PreprocessedFilter* preprocessed_filter); | ||||
| static const AlgoPack& algo_pack(); | |||||
| }; | }; | ||||
| inline bool is_enable_filter_preprocess( | inline bool is_enable_filter_preprocess( | ||||