GitOrigin-RevId: 479718ac75
tags/v1.2.0
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| @@ -92,24 +93,72 @@ enum class AlgoDataType : uint32_t { | |||
| /*! | |||
| * \brief Abstract representation of an algorithm for implementing | |||
| * the operator | |||
| * | |||
| * All pointers to Algorithm should be allocated globally and usable | |||
| * across multiple megdnn handles, and they should not be freed by | |||
| * the caller. | |||
| */ | |||
| class Algorithm { | |||
| public: | |||
| static constexpr uint32_t INVALID_ALGO_TYPE = static_cast<uint32_t>(-1); | |||
| /** | |||
| * \brief Algorithm information; the concrete algorithm can be | |||
| * recovered from Algorithm::Info::Desc | |||
| */ | |||
| struct Info { | |||
| struct Desc { | |||
| //! the backend this algorithm belongs to | |||
| Handle::HandleType handle_type; | |||
| //! identifies the concrete algorithm implementation | |||
| uint32_t type = INVALID_ALGO_TYPE; | |||
| //! serialized param of the algo type | |||
| std::string param; | |||
| bool valid() const { return type != INVALID_ALGO_TYPE; } | |||
| void reset() { type = INVALID_ALGO_TYPE; } | |||
| bool operator==(const Desc& rhs) const { | |||
| return handle_type == rhs.handle_type && type == rhs.type && | |||
| param == rhs.param; | |||
| } | |||
| } desc; | |||
| //! algorithm name | |||
| std::string name; | |||
| bool is_reproducible; | |||
| bool valid() const { return desc.valid(); } | |||
| void reset() { desc.reset(); } | |||
| //! equality is determined by desc alone | |||
| bool operator==(const Info& rhs) const { return desc == rhs.desc; } | |||
| }; | |||
| virtual ~Algorithm() = default; | |||
| /** | |||
| * \brief whether the execution result is | |||
| * reproducible across multiple runs. | |||
| */ | |||
| virtual bool is_reproducible() const = 0; | |||
| virtual const char* name() const = 0; | |||
| //! serialized param | |||
| virtual std::string param() const { return {}; } | |||
| virtual uint32_t type() const = 0; | |||
| Handle::HandleType handle_type() const { return m_handle_type; } | |||
| Info info() const { | |||
| return {{handle_type(), type(), param()}, name(), is_reproducible()}; | |||
| } | |||
| template <typename T> | |||
| static void serialize_write_pod(const T& val, std::string& result) { | |||
| result.append(reinterpret_cast<const char*>(&val), sizeof(T)); | |||
| } | |||
| static void serialize_write_pod(const char* val, std::string& result) { | |||
| result.append(val, strlen(val)); | |||
| } | |||
| template <typename T> | |||
| static T deserialize_read_pod(const std::string& data, size_t offset = 0) { | |||
| T ret = *reinterpret_cast<const T*>(&data[offset]); | |||
| return ret; | |||
| } | |||
| protected: | |||
| ~Algorithm() = default; | |||
| Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; | |||
| }; | |||
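The `serialize_write_pod` / `deserialize_read_pod` helpers above are what concrete algorithms are expected to use to pack their POD parameters into `Desc::param` and read them back. A minimal round-trip sketch, assuming the megdnn headers are available; `TileParam` and its fields are illustrative and not part of this diff:

```cpp
#include <cstdint>
#include <string>

// Hypothetical POD parameter block of some tiled algorithm.
struct TileParam {
    uint32_t block_m;
    uint32_t block_n;
};

// Pack the parameters into the string stored in Info::Desc::param.
std::string make_param(const TileParam& p) {
    std::string result;
    megdnn::detail::Algorithm::serialize_write_pod(p, result);
    return result;
}

// Recover the parameters from a previously serialized Desc::param.
TileParam read_param(const std::string& data) {
    return megdnn::detail::Algorithm::deserialize_read_pod<TileParam>(data, 0);
}
```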
| @@ -127,6 +176,8 @@ class MultiAlgoOpr; | |||
| template <class Opr> | |||
| class MultiAlgoOpr<Opr, -1> { | |||
| public: | |||
| using AlgorithmInfo = detail::Algorithm::Info; | |||
| using AlgorithmDesc = detail::Algorithm::Info::Desc; | |||
| using Algorithm = detail::Algorithm; | |||
| /*! | |||
| * \brief get a string representation for current algorithm set; | |||
| @@ -139,8 +190,8 @@ public: | |||
| //! policy for executing the operator | |||
| struct ExecutionPolicy { | |||
| //! nullptr means using heuristic | |||
| Algorithm* algorithm = nullptr; | |||
| //! INVALID_ALGO_TYPE algo_type means using heuristic | |||
| AlgorithmInfo algo; | |||
| }; | |||
| ExecutionPolicy& execution_policy() { return m_execution_policy; } | |||
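With this change `ExecutionPolicy` carries a serializable `AlgorithmInfo` instead of a raw `Algorithm*`; a description whose `type` is `INVALID_ALGO_TYPE` still means "fall back to the heuristic". A small sketch of how a caller might drive it; the helper and its parameter names are illustrative, not part of the diff:

```cpp
// `Policy` / `Info` stand in for MultiAlgoOpr<Opr, -1>::ExecutionPolicy and
// detail::Algorithm::Info; templated so the sketch stays self-contained.
template <typename Policy, typename Info>
void pin_or_clear(Policy& policy, const Info& chosen, bool pin) {
    if (pin) {
        policy.algo = chosen;   // execute exactly this algorithm
    } else {
        policy.algo.reset();    // back to INVALID_ALGO_TYPE => heuristic
    }
    // Backends can later test policy.algo.valid() to decide whether to
    // honour the pinned description or run their own heuristic.
}
```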
| @@ -161,6 +212,39 @@ template <class Opr> | |||
| class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> { | |||
| public: | |||
| using Algorithm = detail::Algorithm; | |||
| using AlgorithmInfo = detail::Algorithm::Info; | |||
| //! get all possible algorithm descriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms(p0, p1, p2)) { | |||
| ret.emplace_back(algo->info()); | |||
| } | |||
| return ret; | |||
| } | |||
| /** | |||
| * \brief Returns the information of the algorithm selected by | |||
| * heuristic. | |||
| * | |||
| * The selected algorithm should not use more workspace than | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| bool reproducible = false) { | |||
| return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, | |||
| reproducible) | |||
| ->info(); | |||
| } | |||
| protected: | |||
| ~MultiAlgoOpr() = default; | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -179,9 +263,6 @@ public: | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| bool reproducible = false) = 0; | |||
| protected: | |||
| ~MultiAlgoOpr() = default; | |||
| }; | |||
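For a three-layout operator the two wrappers above turn the `Algorithm*`-based virtuals into value-typed calls. A usage sketch, assuming `Opr` derives from `MultiAlgoOpr<Opr, 3>` and the megdnn headers are included; the helper name and the 64 MiB workspace limit are arbitrary:

```cpp
#include <cstdio>

// List every candidate and pin the operator to the heuristic choice.
template <typename Opr>
void select_algorithm(Opr* opr, const TensorLayout& p0, const TensorLayout& p1,
                      const TensorLayout& p2) {
    for (auto&& info : opr->get_all_algorithms_info(p0, p1, p2)) {
        std::printf("candidate: %s reproducible=%d\n", info.name.c_str(),
                    static_cast<int>(info.is_reproducible));
    }
    // Best reproducible algorithm within a 64 MiB workspace budget; its
    // description is what gets stored in the execution policy.
    opr->execution_policy().algo = opr->get_algorithm_info_heuristic(
            p0, p1, p2, static_cast<size_t>(64) << 20, /*reproducible=*/true);
}
```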
| //! specialization for nargs == 4 | |||
| @@ -189,6 +270,40 @@ template <class Opr> | |||
| class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> { | |||
| public: | |||
| using Algorithm = detail::Algorithm; | |||
| using AlgorithmInfo = detail::Algorithm::Info; | |||
| //! get all possible algorithm descriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| const TensorLayout& p3) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { | |||
| ret.emplace_back(algo->info()); | |||
| } | |||
| return ret; | |||
| } | |||
| /** | |||
| * \brief Returns the information of the algorithm selected by | |||
| * heuristic. | |||
| * | |||
| * The selected algorithm should not use more workspace than | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| bool reproducible = false) { | |||
| return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, | |||
| reproducible) | |||
| ->info(); | |||
| } | |||
| protected: | |||
| ~MultiAlgoOpr() = default; | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -207,9 +322,6 @@ public: | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| bool reproducible = false) = 0; | |||
| protected: | |||
| ~MultiAlgoOpr() = default; | |||
| }; | |||
| //! specialization for nargs == 5 | |||
| @@ -217,6 +329,42 @@ template <class Opr> | |||
| class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> { | |||
| public: | |||
| using Algorithm = detail::Algorithm; | |||
| using AlgorithmInfo = detail::Algorithm::Info; | |||
| //! get all possible algorithm descriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| const TensorLayout& p3, | |||
| const TensorLayout& p4) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { | |||
| ret.emplace_back(algo->info()); | |||
| } | |||
| return ret; | |||
| } | |||
| /** | |||
| * \brief Returns the information of the algorithm selected by | |||
| * heuristic. | |||
| * | |||
| * The selected algorithm should not use more workspace than | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| bool reproducible = false) { | |||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, | |||
| workspace_limit_in_bytes, reproducible) | |||
| ->info(); | |||
| } | |||
| protected: | |||
| ~MultiAlgoOpr() = default; | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -237,9 +385,6 @@ public: | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| bool reproducible = false) = 0; | |||
| protected: | |||
| ~MultiAlgoOpr() = default; | |||
| }; | |||
| //! specialization for nargs == 8 | |||
| @@ -247,6 +392,42 @@ template <class Opr> | |||
| class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> { | |||
| public: | |||
| using Algorithm = detail::Algorithm; | |||
| using AlgorithmInfo = detail::Algorithm::Info; | |||
| //! get all possible algorithm descriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||
| ret.emplace_back(algo->info()); | |||
| } | |||
| return ret; | |||
| } | |||
| /** | |||
| * \brief Returns the information of the algorithm selected by | |||
| * heuristic. | |||
| * | |||
| * The selected algorithm should not use more workspace than | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| bool reproducible = false) { | |||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, | |||
| workspace_limit_in_bytes, reproducible) | |||
| ->info(); | |||
| } | |||
| protected: | |||
| ~MultiAlgoOpr() = default; | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -269,9 +450,6 @@ public: | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| bool reproducible = false) = 0; | |||
| protected: | |||
| ~MultiAlgoOpr() = default; | |||
| }; | |||
| } // namespace detail | |||
| } // namespace megdnn | |||
| @@ -31,6 +31,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP16) | |||
| }; | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
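Each `MEGDNN_DECL_ALGO_TYPE(...)` line added in this diff stamps the class with the numeric identity that `Algorithm::type()` reports and that ends up in `Info::Desc::type`. The macro itself is defined outside this diff; a plausible sketch of what it could expand to, assuming an `AlgoType` enum that lists identifiers such as `AARCH64_DIRECT_STRD2_FP16`:

```cpp
// Illustrative expansion only -- the real macro lives in the megdnn sources
// and may differ in detail.
#define MEGDNN_DECL_ALGO_TYPE(_type)                    \
    uint32_t type() const override {                    \
        return static_cast<uint32_t>(AlgoType::_type);  \
    }
```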
| @@ -36,6 +36,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP32) | |||
| }; | |||
| } // namespace aarch64 | |||
| @@ -48,6 +48,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_S8) | |||
| }; | |||
| } // namespace aarch64 | |||
| @@ -32,28 +32,54 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
| AlgoF16DirectStride2 f16_direct_stride2; | |||
| #endif | |||
| fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_matmul_algos; | |||
| public: | |||
| AlgoPack() { | |||
| matmul_algos.emplace_back(&qu8_matrix_mul); | |||
| matmul_algos.emplace_back(&s8_matrix_mul); | |||
| m_matmul_algos.emplace_back(&qu8_matrix_mul); | |||
| m_matmul_algos.emplace_back(&s8_matrix_mul); | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| direct_algos.emplace_back(&f16_direct_stride2); | |||
| m_direct_algos.emplace_back(&f16_direct_stride2); | |||
| #endif | |||
| direct_algos.emplace_back(&f32_direct_stride2); | |||
| m_direct_algos.emplace_back(&f32_direct_stride2); | |||
| for (auto&& algo : m_direct_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| for (auto&& algo : m_matmul_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | |||
| return m_direct_algos; | |||
| } | |||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() | |||
| const { | |||
| return m_matmul_algos; | |||
| } | |||
| SmallVector<AlgoBase*> direct_algos; | |||
| SmallVector<AlgoBase*> matmul_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
| static AlgoPack sl_algo_pack; | |||
| auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | |||
| algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | |||
| sl_algo_pack.direct_algos.end()); | |||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
| static AlgoPack algo_pack; | |||
| return algo_pack; | |||
| } | |||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||
| ConvBiasImpl::get_all_packed_algo() { | |||
| auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | |||
| algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||
| algo_pack().direct_algos().end()); | |||
| //! Matmul algos are placed at the beginning because they take precedence | |||
| //! when the prefer check returns true. | |||
| algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(), | |||
| sl_algo_pack.matmul_algos.end()); | |||
| algos.insert(algos.begin(), algo_pack().matmul_algos().begin(), | |||
| algo_pack().matmul_algos().end()); | |||
| return std::move(algos); | |||
| } | |||
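`MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl)` presumably generates the inverse of `Algorithm::info()`: given a serialized description (for example one restored from an `ExecutionPolicy`), look the live algorithm back up in the `Mapper` that the `AlgoPack` constructor filled. A hand-rolled equivalent of that lookup, generic over whatever associative container `AlgoBase::Mapper` actually is; the helper name is illustrative:

```cpp
// Resolve a description back to a live algorithm; nullptr when the desc is
// invalid or unknown to this backend.
template <typename Mapper, typename Desc>
typename Mapper::mapped_type find_algo_by_desc(const Mapper& map,
                                               const Desc& desc) {
    if (!desc.valid())
        return nullptr;
    auto it = map.find(desc);
    return it != map.end() ? it->second : nullptr;
}

// Called from inside the impl, e.g.:
//   find_algo_by_desc(algo_pack().all_algos_map(), policy.algo.desc);
```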
| @@ -25,7 +25,9 @@ public: | |||
| } | |||
| }; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); | |||
| protected: | |||
| const char* get_algorithm_set_name() const override; | |||
| @@ -38,6 +40,7 @@ private: | |||
| class AlgoF16DirectStride2; | |||
| #endif | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack(); | |||
| }; | |||
| } // namespace aarch64 | |||
| @@ -48,6 +48,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_QU8) | |||
| }; | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| @@ -27,6 +27,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K8X12X1) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | |||
| @@ -37,6 +38,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_K8X12X1) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | |||
| @@ -47,6 +49,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K4X16X1) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | |||
| @@ -58,10 +61,17 @@ public: | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32Gemv final | |||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv {}; | |||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||
| public: | |||
| AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | |||
| m_handle_type = Handle::HandleType::AARCH64; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_GEMV) | |||
| }; | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | |||
| @@ -72,6 +82,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_K8X24X1) | |||
| }; | |||
| class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | |||
| @@ -83,6 +94,7 @@ public: | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_8X8) | |||
| }; | |||
| #endif | |||
| @@ -98,6 +110,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X12X4_DOTPROD) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | |||
| @@ -110,6 +123,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD) | |||
| }; | |||
| #else | |||
| @@ -124,6 +138,7 @@ public: | |||
| PackMode packmode() const override { return PackMode::DEFAULT; } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_4X4X16) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | |||
| @@ -136,6 +151,7 @@ public: | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K4X4X16) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | |||
| @@ -147,6 +163,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8) | |||
| }; | |||
| #endif | |||
| @@ -160,6 +177,7 @@ public: | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K8X8X8) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | |||
| @@ -171,6 +189,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | |||
| @@ -186,6 +205,7 @@ public: | |||
| PackMode packmode() const override { return PackMode::DEFAULT; } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_16X12X4) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | |||
| @@ -201,6 +221,7 @@ public: | |||
| PackMode packmode() const override { return PackMode::DEFAULT; } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_K8X8X8) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||
| @@ -214,6 +235,7 @@ public: | |||
| PackMode packmode() const override { return PackMode::DEFAULT; } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_4X4X8) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | |||
| @@ -225,6 +247,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_K12X8X1) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | |||
| @@ -236,6 +259,7 @@ public: | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8) | |||
| }; | |||
| #if __ARM_FEATURE_DOTPROD | |||
| @@ -249,6 +273,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD) | |||
| }; | |||
| class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | |||
| @@ -262,6 +287,7 @@ public: | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD) | |||
| }; | |||
| #else | |||
| @@ -273,6 +299,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8) | |||
| }; | |||
| #endif | |||
| @@ -51,49 +51,66 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| AlgoQuint8K8x8x8 quint8_k8x8x8; | |||
| #endif | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||
| fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||
| AlgoPack() { | |||
| all_algos.emplace_back(&f32_gemv); | |||
| all_algos.emplace_back(&f32K8x12x1); | |||
| all_algos.emplace_back(&f32_mk4_8x12x1); | |||
| all_algos.emplace_back(&f32k4x16x1); | |||
| all_algos.emplace_back(&f32mk4_4x16); | |||
| m_all_algos.emplace_back(&f32_gemv); | |||
| m_all_algos.emplace_back(&f32K8x12x1); | |||
| m_all_algos.emplace_back(&f32_mk4_8x12x1); | |||
| m_all_algos.emplace_back(&f32k4x16x1); | |||
| m_all_algos.emplace_back(&f32mk4_4x16); | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| all_algos.emplace_back(&f16_k8x24x1); | |||
| all_algos.emplace_back(&f16_mk8_8x8); | |||
| m_all_algos.emplace_back(&f16_k8x24x1); | |||
| m_all_algos.emplace_back(&f16_mk8_8x8); | |||
| #endif | |||
| #if __ARM_FEATURE_DOTPROD | |||
| all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | |||
| all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | |||
| m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | |||
| m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | |||
| #else | |||
| all_algos.emplace_back(&int8x8x32_k4x4x16); | |||
| all_algos.emplace_back(&int8x8x32_k8x8x8); | |||
| all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | |||
| m_all_algos.emplace_back(&int8x8x32_k4x4x16); | |||
| m_all_algos.emplace_back(&int8x8x32_k8x8x8); | |||
| m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | |||
| #endif | |||
| all_algos.emplace_back(&int8x8x16_k4x4x16); | |||
| all_algos.emplace_back(&int8x8x16_k8x8x8); | |||
| all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | |||
| all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||
| all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||
| m_all_algos.emplace_back(&int8x8x16_k4x4x16); | |||
| m_all_algos.emplace_back(&int8x8x16_k8x8x8); | |||
| m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | |||
| m_all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||
| m_all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||
| all_algos.emplace_back(&int16x16x32_k12x8x1); | |||
| all_algos.emplace_back(&int16x16x32_mk8_8x8); | |||
| m_all_algos.emplace_back(&int16x16x32_k12x8x1); | |||
| m_all_algos.emplace_back(&int16x16x32_mk8_8x8); | |||
| #if __ARM_FEATURE_DOTPROD | |||
| all_algos.emplace_back(&quint8_gemv_dotprod); | |||
| all_algos.emplace_back(&quint8_k8x8x4_dotprod); | |||
| m_all_algos.emplace_back(&quint8_gemv_dotprod); | |||
| m_all_algos.emplace_back(&quint8_k8x8x4_dotprod); | |||
| #else | |||
| all_algos.emplace_back(&quint8_k8x8x8); | |||
| m_all_algos.emplace_back(&quint8_k8x8x8); | |||
| #endif | |||
| for (auto&& algo : m_all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const { | |||
| return m_all_algos; | |||
| } | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||
| static AlgoPack s_algo_pack; | |||
| auto&& algos = arm_common::MatrixMulImpl::algo_pack(); | |||
| algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||
| s_algo_pack.all_algos.end()); | |||
| const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { | |||
| static AlgoPack algo_pack; | |||
| return algo_pack; | |||
| } | |||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> | |||
| MatrixMulImpl::get_all_packed_algo() { | |||
| auto&& algos = arm_common::MatrixMulImpl::get_all_packed_algo(); | |||
| algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||
| algo_pack().all_algos().end()); | |||
| return std::move(algos); | |||
| } | |||
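This diff never shows the definition of `AlgoBase::Mapper`; if it is a hash map keyed by `Info::Desc`, the key needs a hash consistent with the `operator==` defined earlier in this patch. A purely illustrative hash under that assumption:

```cpp
#include <cstdint>
#include <functional>
#include <string>

// Possible hash for Algorithm::Info::Desc, combining the same three fields
// that operator== compares. Illustrative only; the real Mapper may differ.
struct DescHash {
    size_t operator()(const megdnn::detail::Algorithm::Info::Desc& d) const {
        size_t h = std::hash<std::string>()(d.param);
        h ^= std::hash<uint32_t>()(d.type) + 0x9e3779b9 + (h << 6) + (h >> 2);
        h ^= std::hash<uint32_t>()(static_cast<uint32_t>(d.handle_type)) +
             0x9e3779b9 + (h << 6) + (h >> 2);
        return h;
    }
};
```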
| @@ -25,7 +25,10 @@ public: | |||
| } | |||
| }; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() | |||
| override; | |||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); | |||
| private: | |||
| class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 | |||
| @@ -66,6 +69,8 @@ private: | |||
| class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 | |||
| class AlgoPack; | |||
| public: | |||
| static const AlgoPack& algo_pack(); | |||
| }; | |||
| } // namespace aarch64 | |||
| @@ -30,6 +30,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_FP16) | |||
| }; | |||
| class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { | |||
| @@ -45,7 +46,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP16) | |||
| }; | |||
| class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { | |||
| public: | |||
| @@ -61,6 +62,7 @@ public: | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP16) | |||
| }; | |||
| class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { | |||
| public: | |||
| @@ -75,6 +77,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16) | |||
| }; | |||
| class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | |||
| @@ -94,6 +97,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override{ | |||
| return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP16) | |||
| }; | |||
| class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | |||
| @@ -110,6 +114,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP16) | |||
| }; | |||
| } // namespace arm_common | |||
| @@ -30,6 +30,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) | |||
| }; | |||
| class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | |||
| @@ -45,6 +46,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32) | |||
| }; | |||
| class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | |||
| @@ -60,6 +62,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) | |||
| }; | |||
| class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | |||
| @@ -75,6 +78,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32) | |||
| }; | |||
| class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | |||
| @@ -90,6 +94,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32) | |||
| }; | |||
| //===================== NCHW44 Winograd Support =====================// | |||
| @@ -107,6 +112,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) | |||
| }; | |||
| class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | |||
| @@ -123,6 +129,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) | |||
| }; | |||
| class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | |||
| @@ -139,6 +146,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) | |||
| }; | |||
| // ================================================================= // | |||
| @@ -157,6 +165,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32) | |||
| }; | |||
| class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | |||
| @@ -174,6 +183,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32) | |||
| }; | |||
| class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
| @@ -191,6 +201,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32) | |||
| }; | |||
| class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | |||
| @@ -209,6 +220,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32) | |||
| }; | |||
| class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | |||
| @@ -227,6 +239,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32) | |||
| }; | |||
| class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | |||
| @@ -244,6 +257,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32) | |||
| }; | |||
| } // namespace arm_common | |||
| @@ -33,6 +33,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_S8) | |||
| }; | |||
| class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | |||
| @@ -49,6 +50,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_S8) | |||
| }; | |||
| class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | |||
| @@ -65,6 +67,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44) | |||
| }; | |||
| class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | |||
| @@ -81,6 +84,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_S8) | |||
| }; | |||
| class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { | |||
| @@ -95,6 +99,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD1_NCHW44_S8) | |||
| }; | |||
| class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { | |||
| @@ -109,6 +114,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8) | |||
| }; | |||
| #if __ARM_FEATURE_DOTPROD | |||
| @@ -126,6 +132,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8) | |||
| }; | |||
| class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | |||
| @@ -142,6 +149,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_S8) | |||
| }; | |||
| class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | |||
| @@ -159,6 +167,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_S8) | |||
| }; | |||
| class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { | |||
| @@ -180,6 +189,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_DOT_S8) | |||
| }; | |||
| #endif | |||
| @@ -196,6 +206,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8) | |||
| }; | |||
| //=======================input int8 compute fp32 output int8============ | |||
| @@ -213,6 +224,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32) | |||
| }; | |||
| //=======================input int8 compute int16 output int8============ | |||
| @@ -231,6 +243,7 @@ public: | |||
| } | |||
| MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8) | |||
| }; | |||
| } // namespace arm_common | |||
| @@ -39,6 +39,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_INT8X8X16) | |||
| }; | |||
| class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | |||
| @@ -54,6 +55,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_INT8X8X16) | |||
| }; | |||
| class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | |||
| @@ -80,6 +82,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_INT8X8X16) | |||
| }; | |||
| class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { | |||
| @@ -96,12 +99,16 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16) | |||
| }; | |||
| class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase { | |||
| class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final | |||
| : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; } | |||
| const char* name() const override { | |||
| return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; | |||
| } | |||
| bool usable(const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| size_t get_workspace( | |||
| @@ -111,6 +118,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16) | |||
| }; | |||
| class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { | |||
| @@ -129,6 +137,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16) | |||
| }; | |||
| } // namespace arm_common | |||
| @@ -88,46 +88,50 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
| #endif | |||
| SmallVector<std::unique_ptr<AlgoBase>> refhold; | |||
| fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_winograd_algos; | |||
| public: | |||
| AlgoPack() { | |||
| #if __ARM_FEATURE_DOTPROD | |||
| direct_algos.emplace_back(&ds8_direct_stride1); | |||
| direct_algos.emplace_back(&ds8_direct_stride2); | |||
| direct_algos.emplace_back(&du8_direct_stride1); | |||
| direct_algos.emplace_back(&du8_direct_stride2); | |||
| m_direct_algos.emplace_back(&ds8_direct_stride1); | |||
| m_direct_algos.emplace_back(&ds8_direct_stride2); | |||
| m_direct_algos.emplace_back(&du8_direct_stride1); | |||
| m_direct_algos.emplace_back(&du8_direct_stride2); | |||
| direct_algos.emplace_back(&ds8_direct_nchw44); | |||
| direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | |||
| m_direct_algos.emplace_back(&ds8_direct_nchw44); | |||
| m_direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | |||
| #endif | |||
| direct_algos.emplace_back(&qu8_direct_stride2); | |||
| direct_algos.emplace_back(&qu8_direct_stride1); | |||
| direct_algos.emplace_back(&s8_direct_stride2); | |||
| direct_algos.emplace_back(&s8_direct_nchw44); | |||
| direct_algos.emplace_back(&s8x8x16_direct_nchw44); | |||
| direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||
| direct_algos.emplace_back(&s8_direct_stride1); | |||
| direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44); | |||
| direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | |||
| direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | |||
| m_direct_algos.emplace_back(&qu8_direct_stride2); | |||
| m_direct_algos.emplace_back(&qu8_direct_stride1); | |||
| m_direct_algos.emplace_back(&s8_direct_stride2); | |||
| m_direct_algos.emplace_back(&s8_direct_nchw44); | |||
| m_direct_algos.emplace_back(&s8x8x16_direct_nchw44); | |||
| m_direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||
| m_direct_algos.emplace_back(&s8_direct_stride1); | |||
| m_direct_algos.emplace_back( | |||
| &s8x8x16_channel_wise_stride1_stride2_nchw44); | |||
| m_direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | |||
| m_direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| direct_algos.emplace_back(&f16_direct_stride1); | |||
| direct_algos.emplace_back(&f16_direct); | |||
| m_direct_algos.emplace_back(&f16_direct_stride1); | |||
| m_direct_algos.emplace_back(&f16_direct); | |||
| #endif | |||
| direct_algos.emplace_back(&i8x8x16_direct); | |||
| direct_algos.emplace_back(&i8x8x16_stride2_filter2); | |||
| direct_algos.emplace_back(&i8x8x16_stride2); | |||
| direct_algos.emplace_back(&i8x8x16_nchw_nchw44); | |||
| m_direct_algos.emplace_back(&i8x8x16_direct); | |||
| m_direct_algos.emplace_back(&i8x8x16_stride2_filter2); | |||
| m_direct_algos.emplace_back(&i8x8x16_stride2); | |||
| m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44); | |||
| direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||
| direct_algos.emplace_back(&f32_chanel_wise_nchw44); | |||
| direct_algos.emplace_back(&f32_direct_nchw44); | |||
| m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||
| m_direct_algos.emplace_back(&f32_chanel_wise_nchw44); | |||
| m_direct_algos.emplace_back(&f32_direct_nchw44); | |||
| direct_algos.emplace_back(&f32_direct_stride1); | |||
| direct_algos.emplace_back(&f32_direct_stride2); | |||
| direct_algos.emplace_back(&f32_direct); | |||
| m_direct_algos.emplace_back(&f32_direct_stride1); | |||
| m_direct_algos.emplace_back(&f32_direct_stride2); | |||
| m_direct_algos.emplace_back(&f32_direct); | |||
| static CpuOprDelegationStorage<2> storage; | |||
| auto matmul_opr = storage.get<MatrixMul, 0>(); | |||
| @@ -143,31 +147,31 @@ public: | |||
| refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| //! uncomment this when low precision mode is done | |||
| #if 0 | |||
| refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| #endif | |||
| //! Qint8x8x32 winograd compute with fp32 | |||
| refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| } | |||
| } | |||
| matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
| @@ -180,15 +184,15 @@ public: | |||
| refhold.emplace_back(new AlgoFP32WinogradF63( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoFP32WinogradF54( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoFP32WinogradF45( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| } | |||
| } | |||
| @@ -203,15 +207,15 @@ public: | |||
| refhold.emplace_back(new AlgoFP16WinogradF23( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoFP16WinogradF45( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoFP16WinogradF63( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| } | |||
| } | |||
| matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
| @@ -224,7 +228,7 @@ public: | |||
| refhold.emplace_back(new AlgoFP16WinogradF23_8x8( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| } | |||
| } | |||
| #endif | |||
| @@ -238,25 +242,48 @@ public: | |||
| refhold.emplace_back(new AlgoS8WinogradF23_8x8( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44( | |||
| static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
| tile_size)); | |||
| winograd_algos.emplace_back(refhold.back().get()); | |||
| m_winograd_algos.emplace_back(refhold.back().get()); | |||
| } | |||
| } | |||
| for (auto&& algo : m_direct_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| for (auto&& algo : m_winograd_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| SmallVector<AlgoBase*> direct_algos; | |||
| SmallVector<AlgoBase*> winograd_algos; | |||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() | |||
| const { | |||
| return m_direct_algos; | |||
| } | |||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& winograd_algos() | |||
| const { | |||
| return m_winograd_algos; | |||
| } | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
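Unlike the direct algorithms, the winograd wrappers above are created on the heap at `AlgoPack` construction time (one per delegated matmul algorithm and tile size), so the pack keeps three views of them: `refhold` owns the objects, `m_winograd_algos` records dispatch order, and `m_all_algos_map` indexes them by description. A stripped-down, self-contained sketch of that ownership split with generic names (it assumes an ordered key; the real map may hash instead):

```cpp
#include <map>
#include <memory>
#include <vector>

// Owning storage plus two non-owning views, mirroring the AlgoPack layout.
template <typename Algo, typename Desc>
struct PackSketch {
    std::vector<std::unique_ptr<Algo>> refhold;  // owns dynamically built algos
    std::vector<Algo*> algos;                    // dispatch order
    std::map<Desc, Algo*> by_desc;               // lookup by description

    void add(std::unique_ptr<Algo> algo, const Desc& desc) {
        algos.push_back(algo.get());
        by_desc.emplace(desc, algo.get());
        refhold.push_back(std::move(algo));      // lives as long as the pack
    }
};
```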
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
| static AlgoPack sl_algo_pack; | |||
| auto&& algos = fallback::ConvBiasImpl::algo_pack(); | |||
| algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | |||
| sl_algo_pack.direct_algos.end()); | |||
| algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(), | |||
| sl_algo_pack.winograd_algos.end()); | |||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
| static AlgoPack algo_pack; | |||
| return algo_pack; | |||
| } | |||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||
| ConvBiasImpl::get_all_packed_algo() { | |||
| auto&& algos = fallback::ConvBiasImpl::get_all_packed_algo(); | |||
| algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||
| algo_pack().direct_algos().end()); | |||
| algos.insert(algos.end(), algo_pack().winograd_algos().begin(), | |||
| algo_pack().winograd_algos().end()); | |||
| return std::move(algos); | |||
| } | |||
| @@ -12,6 +12,7 @@ | |||
| #pragma once | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/opr_impl.h" | |||
| #include "src/common/algo_base.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| @@ -27,7 +28,7 @@ public: | |||
| } | |||
| }; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||
| bool is_matmul_quantized_prefer( | |||
| const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) | |||
| @@ -35,7 +36,8 @@ public: | |||
| SmallVector<AlgoCategory> suggest_algo_category_order( | |||
| const NCBKernSizeParam& param) const override; | |||
| class AlgoPack; | |||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); | |||
| protected: | |||
| const char* get_algorithm_set_name() const override; | |||
| @@ -95,6 +97,9 @@ private: | |||
| class AlgoF16Direct; | |||
| class AlgoF16DirectStride1; | |||
| #endif | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack(); | |||
| }; | |||
| } // namespace arm_common | |||
| @@ -32,6 +32,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_QU8) | |||
| }; | |||
| class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | |||
| @@ -48,6 +49,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8) | |||
| }; | |||
| #if __ARM_FEATURE_DOTPROD | |||
| class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | |||
| @@ -65,6 +67,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8) | |||
| }; | |||
| class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | |||
| @@ -81,6 +84,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8) | |||
| }; | |||
| #endif | |||
| } // namespace arm_common | |||
| @@ -36,6 +36,7 @@ public: | |||
| ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | |||
| const NCBKernSizeParam&) const override; | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32) | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final | |||
| @@ -54,6 +55,7 @@ public: | |||
| ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | |||
| const NCBKernSizeParam&) const override; | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_INT8X8X32) | |||
| }; | |||
| #endif | |||
| @@ -1086,6 +1086,10 @@ bool deconv::can_stride1_int8x8x32_dot(const NCBKernSizeParam& param) { | |||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
| FH >= PH + 1 && FW >= PW + 1; | |||
| avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || | |||
| param.filter_type.enumv() == DTypeEnum::Int8) && | |||
| (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || | |||
| param.grad_type.enumv() == DTypeEnum::Int32); | |||
| return avaiable && | |||
| ((FH == 2 && OC <= 8) || | |||
| ((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16))); | |||
| @@ -1180,6 +1180,10 @@ bool deconv::can_stride2_int8x8x32_dot(const NCBKernSizeParam& param) { | |||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
| FH >= PH + 1 && FW >= PW + 1; | |||
| avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || | |||
| param.filter_type.enumv() == DTypeEnum::Int8) && | |||
| (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || | |||
| param.grad_type.enumv() == DTypeEnum::Int32); | |||
| return avaiable && ((FH == 2 && OC <= 4) || (FH == 3 && OC <= 8) || | |||
| (FH == 5 && OC <= 16) || (FH == 7 && OC < 32)); | |||
| } | |||
| @@ -23,15 +23,54 @@ using namespace arm_common; | |||
| /* ===================== ConvolutionBackwardData ===================== */ | |||
| struct ConvolutionBackwardDataImpl::AlgoPack { | |||
| class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
| #if __ARM_FEATURE_DOTPROD | |||
| AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; | |||
| AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; | |||
| AlgoUdot8DirectStride1 quint8_direct_stride1_udot; | |||
| AlgoUdot8DirectStride2 quint8_direct_stride2_udot; | |||
| #endif | |||
| fallback::ConvolutionBackwardDataImpl::AlgoBase::Mapper m_all_algos_map; | |||
| SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*> | |||
| m_all_algos; | |||
| public: | |||
| AlgoPack() { | |||
| #if __ARM_FEATURE_DOTPROD | |||
| m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot); | |||
| m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot); | |||
| m_all_algos.emplace_back(&quint8_direct_stride1_udot); | |||
| m_all_algos.emplace_back(&quint8_direct_stride2_udot); | |||
| #endif | |||
| for (auto&& algo : m_all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| const SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>& | |||
| all_algos() const { | |||
| return m_all_algos; | |||
| } | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; | |||
| const ConvolutionBackwardDataImpl::AlgoPack& | |||
| ConvolutionBackwardDataImpl::algo_pack() { | |||
| static AlgoPack algo_pack; | |||
| return algo_pack; | |||
| } | |||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) | |||
| SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*> | |||
| ConvolutionBackwardDataImpl::get_all_packed_algo() { | |||
| auto&& algos = fallback::ConvolutionBackwardDataImpl::get_all_packed_algo(); | |||
| algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||
| algo_pack().all_algos().end()); | |||
| return std::move(algos); | |||
| } | |||
| ConvolutionBackwardDataImpl::ncb_kern_t | |||
| ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( | |||
| @@ -52,35 +91,6 @@ size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( | |||
| param); | |||
| } | |||
| std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||
| ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( | |||
| const NCBKernSizeParam& param) { | |||
| auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( | |||
| param); | |||
| #if __ARM_FEATURE_DOTPROD | |||
| if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 || | |||
| param.filter_type.enumv() == DTypeEnum::Int8) && | |||
| (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || | |||
| param.grad_type.enumv() == DTypeEnum::Int32)) { | |||
| if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) { | |||
| ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot); | |||
| } | |||
| if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) { | |||
| ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot); | |||
| } | |||
| } else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
| param.grad_type.enumv() == DTypeEnum::QuantizedS32) { | |||
| if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) { | |||
| ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot); | |||
| } | |||
| if (sm_algo_pack.quint8_direct_stride2_udot.usable(this, param)) { | |||
| ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride2_udot); | |||
| } | |||
| } | |||
| #endif | |||
| return ret; | |||
| } | |||
| const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { | |||
| // arm common version 0 | |||
| return "DeconvAC0"; | |||
| @@ -47,11 +47,14 @@ protected: | |||
| size_t ncb_1g_get_workspace(Algorithm* algo, | |||
| const NCBKernSizeParam& param) override; | |||
| std::vector<Algorithm*> ncb_1g_get_all_algorithms( | |||
| const NCBKernSizeParam& param) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*> | |||
| get_all_packed_algo() override; | |||
| public: | |||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl); | |||
| private: | |||
| #if __ARM_FEATURE_DOTPROD | |||
| class AlgoSdot8DirectStride1; | |||
| @@ -59,8 +62,8 @@ private: | |||
| class AlgoUdot8DirectStride1; | |||
| class AlgoUdot8DirectStride2; | |||
| #endif | |||
| struct AlgoPack; | |||
| static AlgoPack sm_algo_pack; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack(); | |||
| }; | |||
| } // namespace arm_common | |||
| @@ -36,6 +36,7 @@ public: | |||
| ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | |||
| const NCBKernSizeParam&) const override; | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8) | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final | |||
| @@ -55,6 +56,7 @@ public: | |||
| ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | |||
| const NCBKernSizeParam&) const override; | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8) | |||
| }; | |||
| #endif | |||
| } // namespace arm_common | |||
| @@ -1236,6 +1236,9 @@ bool deconv::can_stride1_quint8_dot(const NCBKernSizeParam& param) { | |||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
| FH >= PH + 1 && FW >= PW + 1; | |||
| avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || | |||
| param.grad_type.enumv() == DTypeEnum::Int32); | |||
| /** | |||
| * \note In the kernel, we use int32_t to calc the value, in order | |||
| * not to generate negative numbers, we first initialize SHIFT and sub | |||
dummy
| @@ -1337,6 +1337,9 @@ bool deconv::can_stride2_quint8_dot(const NCBKernSizeParam& param) { | |||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
| FH >= PH + 1 && FW >= PW + 1; | |||
| avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || | |||
| param.grad_type.enumv() == DTypeEnum::Int32); | |||
| /** | |||
| * \note In the kernel, we use uint32_t to calc the value, in order | |||
| * not to generate negative numbers, we first initialize SHIFT and sub | |||
| @@ -59,6 +59,7 @@ public: | |||
| virtual bool is_available(const KernParam&) const = 0; | |||
| virtual void exec(const KernParam&) const = 0; | |||
| virtual ~AlgoBase() = default; | |||
| uint32_t type() const override { return INVALID_ALGO_TYPE; } | |||
| }; | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| @@ -26,6 +26,7 @@ public: | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X16) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | |||
| @@ -39,6 +40,7 @@ public: | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | |||
| @@ -52,6 +54,7 @@ public: | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) | |||
| }; | |||
| #if __ARM_FEATURE_DOTPROD | |||
| @@ -66,6 +69,7 @@ public: | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4_DOT) | |||
| }; | |||
| #endif | |||
| @@ -96,6 +100,7 @@ public: | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4) | |||
| }; | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| @@ -110,6 +115,7 @@ public: | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F16_GEMV) | |||
| }; | |||
| #endif | |||
| @@ -130,6 +136,7 @@ public: | |||
| static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||
| static_cast<uint32_t>(AlgoDataType::QINT8X8X32)), | |||
| DEFAULT) | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_GEVM) | |||
| }; | |||
| } // namespace arm_common | |||
| @@ -28,28 +28,47 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| AlgoGevm gevm; | |||
| AlgoF32GemvMK4 f32_gemv_mk4; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||
| fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack() { | |||
| all_algos.emplace_back(&int8x8x16); | |||
| m_all_algos.emplace_back(&int8x8x16); | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| all_algos.emplace_back(&f16gemv); | |||
| m_all_algos.emplace_back(&f16gemv); | |||
| #endif | |||
| #if __ARM_FEATURE_DOTPROD | |||
| all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | |||
| m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | |||
| #endif | |||
| all_algos.emplace_back(&int8x8x32_gemv); | |||
| all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||
| all_algos.emplace_back(&f32_gemv_mk4); | |||
| all_algos.emplace_back(&gevm); | |||
| m_all_algos.emplace_back(&int8x8x32_gemv); | |||
| m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||
| m_all_algos.emplace_back(&f32_gemv_mk4); | |||
| m_all_algos.emplace_back(&gevm); | |||
| for (auto&& algo : m_all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const { | |||
| return m_all_algos; | |||
| } | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||
| const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { | |||
| static AlgoPack algo_pack; | |||
| return algo_pack; | |||
| } | |||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> | |||
| MatrixMulImpl::get_all_packed_algo() { | |||
| static AlgoPack s_algo_pack; | |||
| auto&& algos = fallback::MatrixMulImpl::algo_pack(); | |||
| algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||
| s_algo_pack.all_algos.end()); | |||
| auto&& algos = fallback::MatrixMulImpl::get_all_packed_algo(); | |||
| algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||
| algo_pack().all_algos().end()); | |||
| return std::move(algos); | |||
| } | |||
| @@ -11,6 +11,7 @@ | |||
| #pragma once | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/matrix_mul/opr_impl.h" | |||
| #include "src/common/algo_base.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| @@ -27,7 +28,10 @@ public: | |||
| } | |||
| }; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() | |||
| override; | |||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); | |||
| protected: | |||
| class AlgoF32Gemv; // Arm_common F32 Gemv | |||
| @@ -43,6 +47,9 @@ protected: | |||
| #endif | |||
| class AlgoInt8x8x16; // Arm_common Int 8x8x16 | |||
| class AlgoPack; | |||
| public: | |||
| static const AlgoPack& algo_pack(); | |||
| }; | |||
| } // namespace arm_common | |||
| @@ -10,6 +10,7 @@ | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs/base.h" | |||
| #include "src/fallback/pooling/opr_impl.h" | |||
| namespace megdnn { | |||
| @@ -72,6 +73,8 @@ public: | |||
| virtual ~AlgoBase() = default; | |||
| virtual bool usable(const PoolingKernSizeParam& param) const = 0; | |||
| virtual void exec(const PoolingKernParam& param) const = 0; | |||
| uint32_t type() const override { return INVALID_ALGO_TYPE; } | |||
| }; | |||
| private: | |||
| @@ -40,6 +40,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_S8) | |||
| }; | |||
| } // namespace armv7 | |||
| @@ -24,22 +24,40 @@ using namespace armv7; | |||
| class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
| AlgoS8MatrixMul s8_matrix_mul; | |||
| AlgoQU8MatrixMul qu8_matrix_mul; | |||
| fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_all_algos; | |||
| public: | |||
| AlgoPack() { | |||
| all_algos.emplace_back(&qu8_matrix_mul); | |||
| all_algos.emplace_back(&s8_matrix_mul); | |||
| m_all_algos.emplace_back(&qu8_matrix_mul); | |||
| m_all_algos.emplace_back(&s8_matrix_mul); | |||
| for (auto&& algo : m_all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& all_algos() | |||
| const { | |||
| return m_all_algos; | |||
| } | |||
| SmallVector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
| static AlgoPack sl_algo_pack; | |||
| auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | |||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
| static AlgoPack algo_pack; | |||
| return algo_pack; | |||
| } | |||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||
| ConvBiasImpl::get_all_packed_algo() { | |||
| auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | |||
| //! TODO: fused matmul+bias is currently slower than matmul + elemwise on | |||
| //! armv7 and roughly equal on aarch64, because registers are wasted in | |||
| //! the postprocess stage | |||
| algos.insert(algos.end(), sl_algo_pack.all_algos.begin(), | |||
| sl_algo_pack.all_algos.end()); | |||
| algos.insert(algos.end(), algo_pack().all_algos().begin(), | |||
| algo_pack().all_algos().end()); | |||
| return std::move(algos); | |||
| } | |||
| @@ -25,7 +25,9 @@ public: | |||
| } | |||
| }; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); | |||
| protected: | |||
| const char* get_algorithm_set_name() const override; | |||
| @@ -34,6 +36,7 @@ private: | |||
| class AlgoS8MatrixMul; | |||
| class AlgoQU8MatrixMul; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack(); | |||
| }; | |||
| } // namespace armv7 | |||
| @@ -42,6 +42,7 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_QU8) | |||
| }; | |||
| } // namespace armv7 | |||
| @@ -27,6 +27,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F32) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | |||
| @@ -37,6 +38,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_PACK_4X12) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { | |||
| @@ -48,6 +50,7 @@ public: | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_4x8) | |||
| }; | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| @@ -59,6 +62,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_K4X16X1) | |||
| }; | |||
| class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { | |||
| public: | |||
| @@ -69,6 +73,7 @@ public: | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8) | |||
| }; | |||
| #endif | |||
| #if __ARM_FEATURE_DOTPROD | |||
| @@ -80,6 +85,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_K6X8X4) | |||
| }; | |||
| class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { | |||
| @@ -90,6 +96,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X4) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { | |||
| @@ -102,11 +109,18 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_MK4_8X4X4_DOTPROD) | |||
| }; | |||
| #endif | |||
| class MatrixMulImpl::AlgoF32Gemv final | |||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv {}; | |||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||
| public: | |||
| AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | |||
| m_handle_type = Handle::HandleType::ARMV7; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_GEMV) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { | |||
| public: | |||
| @@ -117,6 +131,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X2X16) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { | |||
| @@ -128,6 +143,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X8X8) | |||
| }; | |||
| class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { | |||
| @@ -138,6 +154,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X8) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { | |||
| @@ -149,6 +166,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X2X16) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { | |||
| @@ -160,6 +178,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X8X8) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | |||
| @@ -171,6 +190,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_MK4_K8X8X4) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { | |||
| @@ -182,6 +202,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_K12X4X1) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { | |||
| @@ -193,6 +214,7 @@ public: | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_MK8_4X8) | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | |||
| @@ -204,6 +226,7 @@ public: | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_MK4_4X2X16) | |||
| }; | |||
| } // namespace armv7 | |||
| @@ -43,42 +43,60 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; | |||
| AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||
| fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||
| AlgoPack() { | |||
| all_algos.emplace_back(&f32_gemv); | |||
| all_algos.emplace_back(&f32); | |||
| all_algos.emplace_back(&f32_mk4_pack_4x12); | |||
| all_algos.emplace_back(&f32_mk4_4x8); | |||
| m_all_algos.emplace_back(&f32_gemv); | |||
| m_all_algos.emplace_back(&f32); | |||
| m_all_algos.emplace_back(&f32_mk4_pack_4x12); | |||
| m_all_algos.emplace_back(&f32_mk4_4x8); | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| all_algos.emplace_back(&f16_k4x16x1); | |||
| all_algos.emplace_back(&f16_mk8_4x8); | |||
| m_all_algos.emplace_back(&f16_k4x16x1); | |||
| m_all_algos.emplace_back(&f16_mk8_4x8); | |||
| #endif | |||
| #if __ARM_FEATURE_DOTPROD | |||
| all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | |||
| all_algos.emplace_back(&int8_k6x8x4); | |||
| all_algos.emplace_back(&quint8_k4x8x4); | |||
| m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | |||
| m_all_algos.emplace_back(&int8_k6x8x4); | |||
| m_all_algos.emplace_back(&quint8_k4x8x4); | |||
| #endif | |||
| all_algos.emplace_back(&int8x8x32_mk4_4x2x16); | |||
| all_algos.emplace_back(&int8x8x32_k4x2x16); | |||
| all_algos.emplace_back(&int8x8x32_k4x8x8); | |||
| all_algos.emplace_back(&quint8_k4x8x8); | |||
| all_algos.emplace_back(&int8x8x16_mk4_8x8x4); | |||
| all_algos.emplace_back(&int8x8x16_k4x2x16); | |||
| all_algos.emplace_back(&int8x8x16_k4x8x8); | |||
| m_all_algos.emplace_back(&int8x8x32_mk4_4x2x16); | |||
| m_all_algos.emplace_back(&int8x8x32_k4x2x16); | |||
| m_all_algos.emplace_back(&int8x8x32_k4x8x8); | |||
| m_all_algos.emplace_back(&quint8_k4x8x8); | |||
| m_all_algos.emplace_back(&int8x8x16_mk4_8x8x4); | |||
| m_all_algos.emplace_back(&int8x8x16_k4x2x16); | |||
| m_all_algos.emplace_back(&int8x8x16_k4x8x8); | |||
| m_all_algos.emplace_back(&int16x16x32_k12x4x1); | |||
| m_all_algos.emplace_back(&int16x16x32_mk8_4x8); | |||
| for (auto&& algo : m_all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| all_algos.emplace_back(&int16x16x32_k12x4x1); | |||
| all_algos.emplace_back(&int16x16x32_mk8_4x8); | |||
| const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const { | |||
| return m_all_algos; | |||
| } | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||
| static AlgoPack s_algo_pack; | |||
| auto algos = arm_common::MatrixMulImpl::algo_pack(); | |||
| algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||
| s_algo_pack.all_algos.end()); | |||
| const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { | |||
| static AlgoPack algo_pack; | |||
| return algo_pack; | |||
| } | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> | |||
| MatrixMulImpl::get_all_packed_algo() { | |||
| auto algos = arm_common::MatrixMulImpl::get_all_packed_algo(); | |||
| algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||
| algo_pack().all_algos().end()); | |||
| return algos; | |||
| } | |||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -25,7 +25,10 @@ public: | |||
| } | |||
| }; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() | |||
| override; | |||
| MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); | |||
| private: | |||
| class AlgoF32; // Armv7 F32 | |||
| @@ -52,6 +55,9 @@ private: | |||
| // DotProduct | |||
| #endif | |||
| class AlgoPack; | |||
| public: | |||
| static const AlgoPack& algo_pack(); | |||
| }; | |||
| } // namespace armv7 | |||
| @@ -0,0 +1,101 @@ | |||
| /** | |||
| * \file dnn/src/common/algo_base.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <functional> | |||
| #include <string> | |||
| #include "megdnn/oprs/base.h" | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| #define MEGDNN_DECL_ALGO_TYPE(_type) \ | |||
| uint32_t type() const override { \ | |||
| return static_cast<std::underlying_type<AlgoType>::type>( \ | |||
| AlgoType::_type); \ | |||
| } | |||
| #define MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr) \ | |||
| static fallback::_opr::AlgoBase* get_algo_from_desc( \ | |||
| const AlgorithmDesc& desc) | |||
| #define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||
| fallback::_opr::AlgoBase* _opr::get_algo_from_desc( \ | |||
| const AlgorithmDesc& desc) { \ | |||
| megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||
| algo_pack().all_algos_map().end()); \ | |||
| return algo_pack().all_algos_map().at(desc); \ | |||
| } | |||
| #define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||
| _opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \ | |||
| megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||
| algo_pack().all_algos_map().end()); \ | |||
| return algo_pack().all_algos_map().at(desc); \ | |||
| } | |||
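To make the macros above concrete: for a class that defines `enum class AlgoType : uint32_t { ..., ARMV7_F32, ... }`, writing `MEGDNN_DECL_ALGO_TYPE(ARMV7_F32)` simply emits the following member (shown here as the textual expansion, not additional code in the patch):

uint32_t type() const override {
    return static_cast<std::underlying_type<AlgoType>::type>(
            AlgoType::ARMV7_F32);
}

MEGDNN_FB_DEF_GET_ALGO_FROM_DESC then inverts this mapping through the pack's all_algos_map(), asserting that the desc is known before returning the pointer.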
| /** | |||
| * \brief construct algo from AlgorithmDesc | |||
| */ | |||
| template <typename AlgoBase> | |||
| class AlgoConstructMixin { | |||
| private: | |||
| std::vector<std::unique_ptr<AlgoBase>> m_refhold; | |||
| protected: | |||
| typename AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| //! construct the algo described by desc, and return the instance | |||
| AlgoBase* construct_and_get_algo( | |||
| const detail::Algorithm::Info::Desc& desc) { | |||
| auto iter = m_all_algos_map.find(desc); | |||
| if (iter != m_all_algos_map.end()) { | |||
| return m_all_algos_map.at(desc); | |||
| } | |||
| std::string serialized_bin; | |||
| AlgoBase::serialize_write_pod(desc.type, serialized_bin); | |||
| serialized_bin += desc.param; | |||
| m_refhold.emplace_back(AlgoBase::deserialize(serialized_bin)); | |||
| m_all_algos_map.emplace(desc, m_refhold.back().get()); | |||
| return m_refhold.back().get(); | |||
| } | |||
| void clear() { | |||
| m_all_algos_map.clear(); | |||
| m_refhold.clear(); | |||
| } | |||
| const typename AlgoBase::Mapper& all_algos_map() const { | |||
| return m_all_algos_map; | |||
| } | |||
| }; | |||
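The mixin's contract is easiest to see in isolation: on a cache miss it deserializes a fresh algorithm, keeps it alive in `m_refhold`, and records it in the map, so repeated requests for the same desc return the same pointer. A simplified, self-contained model (`TinyAlgo::deserialize` stands in for the `AlgoBase::deserialize` factory the real mixin assumes):

#include <map>
#include <memory>
#include <string>
#include <vector>

struct TinyAlgo {
    std::string param;
    static std::unique_ptr<TinyAlgo> deserialize(const std::string& s) {
        return std::unique_ptr<TinyAlgo>(new TinyAlgo{s});
    }
};

class TinyConstructMixin {
    std::vector<std::unique_ptr<TinyAlgo>> m_refhold;  // owns constructed algos
    std::map<std::string, TinyAlgo*> m_map;            // keyed by serialized desc
public:
    TinyAlgo* construct_and_get_algo(const std::string& desc) {
        auto it = m_map.find(desc);
        if (it != m_map.end())
            return it->second;                                // cache hit
        m_refhold.emplace_back(TinyAlgo::deserialize(desc));  // miss: build, keep alive
        m_map.emplace(desc, m_refhold.back().get());
        return m_refhold.back().get();
    }
};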
| } // namespace megdnn | |||
| namespace std { | |||
| template <> | |||
| struct hash<megdnn::detail::Algorithm::Info::Desc> { | |||
| std::size_t operator()( | |||
| const megdnn::detail::Algorithm::Info::Desc& desc) const { | |||
| return megdnn::hash_combine<size_t>( | |||
| megdnn::hash_combine<size_t>( | |||
| std::hash<std::string>()(desc.param), | |||
| std::hash<uint32_t>()(desc.type)), | |||
| std::hash<uint32_t>()(static_cast<uint32_t>(desc.handle_type))); | |||
| } | |||
| }; | |||
| } // namespace std | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -25,15 +25,34 @@ namespace megdnn { | |||
| */ | |||
| template <class Opr, typename... Args> | |||
| typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||
| typename Opr::Algorithm* ret; | |||
| if (auto set = opr->execution_policy().algorithm) { | |||
| typename Opr::AlgorithmInfo ret; | |||
| auto set = opr->execution_policy().algo; | |||
| if (set.valid()) { | |||
| ret = set; | |||
| } else { | |||
| ret = opr->get_algorithm_heuristic(std::forward<Args>(args)..., | |||
| std::numeric_limits<size_t>::max(), | |||
| false); | |||
| ret = opr->get_algorithm_info_heuristic( | |||
| std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | |||
| false); | |||
| } | |||
| return opr->get_algo_from_desc(ret.desc); | |||
| } | |||
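What changes for callers is that `execution_policy()` now carries a copyable `AlgorithmInfo` value (validity signalled by `Desc::valid()`) instead of a raw `Algorithm*`, and the pointer is recovered on demand from the desc. A toy model of just that switch, using simplified stand-ins for `Info`/`Desc`; the real resolution goes through `get_algo_from_desc`, which is not reproduced here, and the algorithm name is illustrative:

#include <cassert>
#include <cstdint>
#include <string>

constexpr uint32_t INVALID = static_cast<uint32_t>(-1);

struct Desc {
    uint32_t type;
    std::string param;
    bool valid() const { return type != INVALID; }
};
struct Info {
    Desc desc;
    std::string name;
    bool valid() const { return desc.valid(); }
};
struct Policy {
    Info algo;  // was: Algorithm* algorithm
};

int main() {
    Policy policy{{{INVALID, ""}, ""}};         // nothing configured yet
    assert(!policy.algo.valid());               // -> fall back to the heuristic path
    policy.algo = {{7, ""}, "EXAMPLE_DIRECT"};  // user picked an algo
    assert(policy.algo.valid());                // -> resolve policy.algo.desc to an AlgoBase*
    return 0;
}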
| /*! | |||
| * \brief get the user-configured algorithm, or a heuristic algorithm. Used for | |||
| * OpenCL, whose algos need to be constructed each time. | |||
| */ | |||
| template <class Opr, typename... Args> | |||
| typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { | |||
| typename Opr::AlgorithmInfo ret; | |||
| auto set = opr->execution_policy().algo; | |||
| if (set.valid()) { | |||
| return opr->algo_pack().construct_and_get_algo(set.desc); | |||
| } else { | |||
| ret = opr->get_algorithm_info_heuristic( | |||
| std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | |||
| false); | |||
| return opr->get_algo_from_desc(ret.desc); | |||
| } | |||
| return static_cast<typename Opr::AlgoBase*>(ret); | |||
| } | |||
| /*! | |||
| @@ -9,6 +9,32 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| /** | |||
| * Boost Software License - Version 1.0 - August 17th, 2003 | |||
| * | |||
| * Permission is hereby granted, free of charge, to any person or organization | |||
| * obtaining a copy of the software and accompanying documentation covered by | |||
| * this license (the "Software") to use, reproduce, display, distribute, | |||
| * execute, and transmit the Software, and to prepare derivative works of the | |||
| * Software, and to permit third-parties to whom the Software is furnished to | |||
| * do so, all subject to the following: | |||
| * | |||
| * The copyright notices in the Software and this entire statement, including | |||
| * the above license grant, this restriction and the following disclaimer, | |||
| * must be included in all copies of the Software, in whole or in part, and | |||
| * all derivative works of the Software, unless such copies or derivative | |||
| * works are solely in the form of machine-executable object code generated by | |||
| * a source language processor. | |||
| * | |||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||
| * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT | |||
| * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE | |||
| * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, | |||
| * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | |||
| * DEALINGS IN THE SOFTWARE. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/arch.h" | |||
| @@ -263,6 +289,13 @@ constexpr uint32_t operator"" _hash(char const* str, size_t count) { | |||
| return XXHash64CT::hash(str, count, 20160701); | |||
| } | |||
| // refer to https://www.boost.org/doc/libs/1_64_0/boost/functional/hash/hash.hpp | |||
| template <typename T> | |||
| inline T hash_combine(T seed, T value) { | |||
| seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); | |||
| return seed; | |||
| } | |||
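For reference, this is the boost-style combiner that the `std::hash<Algorithm::Info::Desc>` specialization in algo_base.h builds on. A small standalone illustration of how the three desc fields fold into one hash (the helper is copied locally only to keep the snippet self-contained):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>

template <typename T>
T hash_combine(T seed, T value) {
    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
}

std::size_t desc_hash(const std::string& param, uint32_t type,
                      uint32_t handle_type) {
    // same nesting order as the std::hash<Desc> specialization above
    std::size_t h = hash_combine<std::size_t>(std::hash<std::string>()(param),
                                              std::hash<uint32_t>()(type));
    return hash_combine<std::size_t>(h, std::hash<uint32_t>()(handle_type));
}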
| template <typename Vec> | |||
| std::string vec2str(Vec&& vec) { | |||
| std::string res; | |||
| @@ -18,8 +18,14 @@ using namespace cuda; | |||
| BatchConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&int8_nchw4_gemm_dotprod); | |||
| all_algos.push_back(&int8_nchw4_implicit_gemm_dotprod); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchConvBiasForwardImpl) | |||
| BatchConvBiasForwardImpl::AlgoPack BatchConvBiasForwardImpl::sm_algo_pack; | |||
| BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| @@ -11,13 +11,16 @@ | |||
| #pragma once | |||
| #include <csetjmp> | |||
| #include <unordered_map> | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/batch_conv_bias/opr_impl.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -26,6 +29,12 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_GEMM_NCHW4_DOTPROD_INT8, | |||
| CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| BatchConvBiasForwardImpl* opr; | |||
| @@ -85,6 +94,7 @@ public: | |||
| const char* name() const override { | |||
| return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD"; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GEMM_NCHW4_DOTPROD_INT8) | |||
| }; | |||
| class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp final | |||
| @@ -99,15 +109,16 @@ public: | |||
| const char* name() const override { | |||
| return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD"; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8) | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| const SizeArgs& args) const; | |||
| }; | |||
| class BatchConvBiasForwardImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| class BatchConvBiasForwardImpl::AlgoPack : NonCopyableObj { | |||
| private: | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| @@ -116,6 +127,8 @@ public: | |||
| AlgoInt8NCHW4DotProdImplicitGemmPrecomp int8_nchw4_implicit_gemm_dotprod; | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| @@ -26,6 +26,18 @@ public: | |||
| const TensorLayout& bias, | |||
| const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| class AlgoInt8NCHW4DotProdGemm; | |||
| class AlgoInt8NCHW4DotProdImplicitGemmPrecomp; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| @@ -37,15 +49,6 @@ public: | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| class AlgoInt8NCHW4DotProdGemm; | |||
| class AlgoInt8NCHW4DotProdImplicitGemmPrecomp; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| @@ -60,4 +60,12 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
| for (auto& algo : brute_force_algos) { | |||
| all_algos.push_back(&algo); | |||
| } | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -16,6 +16,8 @@ | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/batched_matrix_mul/opr_impl.h" | |||
| #include "src/cuda/matrix_mul/cublasLt_wrapper.h" | |||
| #include "src/common/metahelper.h" | |||
| #if CUDA_VERSION >= 10010 | |||
| #include <cublasLt.h> | |||
| #endif | |||
| @@ -28,6 +30,14 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_BRUTE_FORCE, | |||
| CUDA_CUBLAS, | |||
| CUDA_CUBLASLT, | |||
| CUDA_INT8X8X32, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| BatchedMatrixMulForwardImpl* opr; | |||
| @@ -90,6 +100,13 @@ public: | |||
| void exec(const ExecArgs& args) const final; | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return m_name.c_str(); } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_algorithm, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
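The `param()` overrides above rely on the pod serializers declared on `Algorithm`; the round trip is just a raw-byte append and read-back. A minimal standalone sketch (the helpers are re-declared locally, the enum is a made-up stand-in for e.g. a cuDNN enum, and the read side uses memcpy instead of the original's reinterpret_cast purely to keep the snippet portable):

#include <cassert>
#include <cstring>
#include <string>

enum class FakeAlgoEnum : int { FOO = 3, BAR = 7 };

template <typename T>
void serialize_write_pod(const T& val, std::string& result) {
    result.append(reinterpret_cast<const char*>(&val), sizeof(T));
}

template <typename T>
T deserialize_read_pod(const std::string& data, size_t offset = 0) {
    T ret;
    std::memcpy(&ret, data.data() + offset, sizeof(T));  // read back the raw bytes
    return ret;
}

int main() {
    std::string buf;
    serialize_write_pod(FakeAlgoEnum::BAR, buf);  // what a param() override does
    assert(deserialize_read_pod<FakeAlgoEnum>(buf) == FakeAlgoEnum::BAR);
    return 0;
}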
| class BatchedMatrixMulForwardImpl::AlgoCublas final | |||
| : public BatchedMatrixMulForwardImpl::AlgoBase { | |||
| @@ -100,6 +117,7 @@ public: | |||
| void exec(const ExecArgs& args) const final; | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "CUBLAS"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | |||
| }; | |||
| #if CUDA_VERSION >= 10010 | |||
| class BatchedMatrixMulForwardImpl::AlgoCublasLt final : public AlgoBase { | |||
| @@ -110,6 +128,7 @@ public: | |||
| void exec(const ExecArgs& args) const final; | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "CUBLAS_LT"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | |||
| }; | |||
| #endif | |||
| class BatchedMatrixMulForwardImpl::AlgoInt8x8x32 final | |||
| @@ -121,11 +140,13 @@ public: | |||
| void exec(const ExecArgs& args) const final; | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "INT8x8x32"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32) | |||
| }; | |||
| class BatchedMatrixMulForwardImpl::AlgoPack { | |||
| class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||
| private: | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| MatrixMulForwardImpl::AlgoPack mm_pack; | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| public: | |||
| AlgoPack(); | |||
| @@ -137,6 +158,8 @@ public: | |||
| AlgoInt8x8x32 int8x8x32; | |||
| std::vector<AlgoBase*> all_algos; | |||
| std::vector<AlgoBruteForce> brute_force_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -24,7 +24,7 @@ bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available( | |||
| const SizeArgs& args) const { | |||
| MatrixMulForwardImpl mm{args.opr->handle()}; | |||
| mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB}; | |||
| mm.execution_policy() = {m_algorithm}; | |||
| mm.execution_policy() = {m_algorithm->info()}; | |||
| auto mm_layout_a = args.layout_a.remove_axis(0); | |||
| auto mm_layout_b = args.layout_b.remove_axis(0); | |||
| @@ -39,7 +39,7 @@ size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( | |||
| auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
| mm_opr->param() = {args.opr->param().transposeA, | |||
| args.opr->param().transposeB}; | |||
| mm_opr->execution_policy() = {m_algorithm}; | |||
| mm_opr->execution_policy() = {m_algorithm->info()}; | |||
| return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b, | |||
| args.layout_c); | |||
| @@ -50,7 +50,7 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( | |||
| auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
| mm_opr->param() = {args.opr->param().transposeA, | |||
| args.opr->param().transposeB}; | |||
| mm_opr->execution_policy() = {m_algorithm}; | |||
| mm_opr->execution_policy() = {m_algorithm->info()}; | |||
| rep(n, N) { | |||
| TensorND A_, B_, C_; | |||
| auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | |||
| @@ -32,6 +32,16 @@ public: | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| const char* get_algorithm_set_name() const override { | |||
| return "BATCHED_MATMUL"; | |||
| } | |||
| bool is_thread_safe() const override { return true; } | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| @@ -40,12 +50,6 @@ public: | |||
| const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| const char* get_algorithm_set_name() const override { | |||
| return "BATCHED_MATMUL"; | |||
| } | |||
| bool is_thread_safe() const override { return true; } | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| @@ -100,10 +100,16 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||
| for (size_t i = all_algo_size; i < all_algos.size(); ++i) { | |||
| non_cudnn_algos.push_back(all_algos[i]); | |||
| } | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack; | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl) | |||
| ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| ConvBiasForwardImpl* o, const TensorLayout& src, | |||
| const TensorLayout& filter, const TensorLayout& bias, | |||
| @@ -172,43 +178,10 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
| } | |||
| void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() { | |||
| #define V1(v) #v | |||
| #define V(v) V1(v) | |||
| #define DEF_ALGO(NAME, REPROD) \ | |||
| cudnn_conv_bias_activations.push_back( \ | |||
| {REPROD, \ | |||
| "CUDNN:ConvBiasActivation:" #NAME \ | |||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL), \ | |||
| NAME}); \ | |||
| cudnn_convs.push_back( \ | |||
| {REPROD, \ | |||
| "CUDNN:Convolution:" #NAME \ | |||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL), \ | |||
| NAME}) | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true); | |||
| #if CUDNN_MAJOR >= 5 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true); | |||
| #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true); | |||
| #endif | |||
| #endif | |||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
| #pragma message "not latest cudnn" | |||
| #endif | |||
| #undef DEF_ALGO | |||
| #undef V | |||
| #undef V1 | |||
| for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) { | |||
| cudnn_conv_bias_activations.push_back(algo.first); | |||
| cudnn_convs.push_back(algo.first); | |||
| } | |||
| } | |||
| #if CUDA_VERSION >= 10000 | |||
| @@ -6,19 +6,23 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/cuda/conv_bias/conv_bias_int8.cuh" | |||
| #include "src/cuda/conv_bias/helper.h" | |||
| #include "src/cuda/conv_bias/opr_impl.h" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/cudnn_wrapper.h" | |||
| #include <cuda.h> | |||
| #include <memory> | |||
| @@ -38,11 +42,39 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_CUDNN_CONVBIAS, | |||
| CUDA_CHANWISE, | |||
| CUDA_CHANWISE_SMALL, | |||
| CUDA_CHANWISE_INT8X8X32, | |||
| CUDA_CUDNN_CONV, | |||
| CUDA_INPLACE_MATMUL, | |||
| CUDA_MATMUL, | |||
| CUDA_MATMUL_INT8X8X32, | |||
| CUDA_1X1, | |||
| CUDA_BATCHED_MATMUL, | |||
| CUDA_GROUP_CONV_GENERAL, | |||
| CUDA_WMMA_UINT4X4X32, | |||
| CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8, | |||
| CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, | |||
| CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8, | |||
| CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8, | |||
| CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8, | |||
| CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8, | |||
| CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8, | |||
| CUDA_BFLOAT16, | |||
| CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8, | |||
| CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, | |||
| CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8, | |||
| CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs : public conv_bias::BiasForwardSizeArgs { | |||
| ConvBiasForwardImpl* opr; | |||
| const PreprocessedFilter* preprocessed_filter; | |||
| std::string to_string() const; | |||
| SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src, | |||
| const TensorLayout& filter, const TensorLayout& bias, | |||
| @@ -80,13 +112,17 @@ public: | |||
| virtual void exec(const ExecArgs& args) const = 0; | |||
| virtual size_t get_preprocess_workspace_in_bytes( | |||
| const SizeArgs& args) const { | |||
| MEGDNN_MARK_USED_VAR(args); | |||
| return 0; | |||
| } | |||
| virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | |||
| const SizeArgs& args) const { | |||
| MEGDNN_MARK_USED_VAR(args); | |||
| return {}; | |||
| } | |||
| virtual void exec_preprocess(const ExecArgs& args) const {} | |||
| virtual void exec_preprocess(const ExecArgs& args) const { | |||
| MEGDNN_MARK_USED_VAR(args); | |||
| } | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| @@ -114,11 +150,14 @@ public: | |||
| class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase { | |||
| public: | |||
| AlgoCUDNNConvBiasActivation(bool is_reproducible, const char* name, | |||
| cudnnConvolutionFwdAlgo_t cudnn_enum) | |||
| : m_is_reproducible(is_reproducible), | |||
| m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})), | |||
| m_cudnn_enum(cudnn_enum) {} | |||
| AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum) | |||
| : m_cudnn_enum(cudnn_enum) { | |||
| megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != | |||
| CudnnAlgoPack::conv_fwd_algos().end()); | |||
| m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); | |||
| m_name = ConvBiasForward::algo_name<DefaultParam>( | |||
| "CUDNN:ConvBiasActivation:" + m_attr.name, {}); | |||
| } | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| @@ -127,16 +166,24 @@ public: | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return m_is_reproducible; } | |||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
| cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; } | |||
| bool is_cudnn() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_cudnn_enum, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| bool m_is_reproducible; | |||
| std::string m_name; | |||
| cudnnConvolutionFwdAlgo_t m_cudnn_enum; | |||
| CudnnAlgoPack::Attr m_attr; | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase { | |||
| @@ -154,6 +201,8 @@ public: | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
| private: | |||
| mutable std::string m_name; | |||
| }; | |||
| @@ -172,6 +221,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | |||
| private: | |||
| mutable std::string m_name; | |||
| @@ -190,6 +240,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32) | |||
| private: | |||
| mutable std::string m_name; | |||
| @@ -197,27 +248,39 @@ private: | |||
| class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase { | |||
| public: | |||
| AlgoCUDNNConv(bool is_reproducible, const char* name, | |||
| cudnnConvolutionFwdAlgo_t cudnn_enum) | |||
| : m_is_reproducible(is_reproducible), | |||
| m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})), | |||
| m_cudnn_enum(cudnn_enum) {} | |||
| AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) | |||
| : m_cudnn_enum(cudnn_enum) { | |||
| megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != | |||
| CudnnAlgoPack::conv_fwd_algos().end()); | |||
| m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); | |||
| m_name = ConvBiasForward::algo_name<DefaultParam>( | |||
| "CUDNN:Convolution:" + m_attr.name, {}); | |||
| } | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { return m_is_reproducible; } | |||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
| const char* name() const override { return m_name.c_str(); } | |||
| cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
| bool is_cudnn() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_cudnn_enum, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| bool m_is_reproducible; | |||
| std::string m_name; | |||
| cudnnConvolutionFwdAlgo_t m_cudnn_enum; | |||
| CudnnAlgoPack::Attr m_attr; | |||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
| }; | |||
| @@ -237,6 +300,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | |||
| private: | |||
| mutable std::string m_name; | |||
| @@ -261,6 +325,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
| @@ -281,6 +346,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32) | |||
| private: | |||
| bool need_src_unroll(const SizeArgs& args) const; | |||
| @@ -310,6 +376,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_1X1) | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
| @@ -333,6 +400,7 @@ public: | |||
| return m_name.c_str(); | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
| @@ -354,6 +422,13 @@ public: | |||
| static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||
| TensorLayout& dst_pg, TensorLayout& bias_pg); | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_impl, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
| @@ -370,10 +445,13 @@ public: | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return "QUINT4x4x32_WMMA"; } | |||
| bool is_reproducible() const override { return true; } | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| const SizeArgs& args) const; | |||
| bool use_kernel_fhxfw(const SizeArgs& args) const; | |||
| size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const; | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) | |||
| }; | |||
| #endif | |||
| @@ -395,6 +473,7 @@ public: | |||
| const convolution::ConvParam& param, float alpha, float beta, | |||
| float gamma, float scale, cudaStream_t stream, | |||
| param::ConvBias::NonlineMode nonlinear_mode); | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8) | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final | |||
| @@ -415,8 +494,9 @@ public: | |||
| warp_k == 32 && stage == 2) { | |||
| return ""; | |||
| } | |||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, | |||
| threadblock_k, warp_m, warp_n, warp_k, stage); | |||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, | |||
| threadblock_n, threadblock_k, warp_m, warp_n, | |||
| warp_k, stage); | |||
| } | |||
| }; | |||
| AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) | |||
| @@ -433,6 +513,13 @@ public: | |||
| SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | |||
| const SizeArgs& args) const override; | |||
| void exec_preprocess(const ExecArgs& args) const override; | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_algo_param, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| @@ -457,9 +544,7 @@ public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { | |||
| return m_name.c_str(); | |||
| } | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return true; } | |||
| template <typename BiasVisitor> | |||
| static void dispatch_nonlinear_mode( | |||
| @@ -471,6 +556,14 @@ public: | |||
| MMATileSize mma_tile_size); | |||
| static std::string to_string(MMATileSize mma_tile_size); | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_mma_tile_size, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| MMATileSize m_mma_tile_size; | |||
| std::string m_name; | |||
| @@ -488,10 +581,16 @@ public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { | |||
| return m_name.c_str(); | |||
| } | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_mma_tile_size, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| const SizeArgs& args) const; | |||
| @@ -513,6 +612,13 @@ public: | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_mma_tile_size, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| MMATileSize m_mma_tile_size; | |||
| @@ -533,6 +639,13 @@ public: | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_mma_tile_size, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| MMATileSize m_mma_tile_size; | |||
| @@ -570,6 +683,13 @@ public: | |||
| SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | |||
| const SizeArgs& args) const override; | |||
| void exec_preprocess(const ExecArgs& args) const override; | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_algo_param, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| @@ -592,6 +712,14 @@ public: | |||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_impl, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr, | |||
| TensorLayout& fsrc, TensorLayout& ffilter, | |||
| @@ -603,17 +731,16 @@ private: | |||
| }; | |||
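The `param()` overrides added throughout this diff all follow one pattern: append the algorithm's distinguishing POD field to a string with `serialize_write_pod`, so that `Algorithm::Info::Desc::param` tells apart variants that share the same algo type (tile shape, MMA tile size, cuDNN enum). A minimal self-contained sketch of that round-trip; `TileParam` and `main` are illustrative stand-ins, not megdnn code:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>

// Illustrative POD standing in for AlgoParam / MMATileSize.
struct TileParam {
    int threadblock_m, threadblock_n, threadblock_k, stage;
};

// Mirrors Algorithm::serialize_write_pod: raw bytes appended to a string.
template <typename T>
void serialize_write_pod(const T& val, std::string& result) {
    result.append(reinterpret_cast<const char*>(&val), sizeof(T));
}

// Mirrors Algorithm::deserialize_read_pod (memcpy is used here to avoid
// unaligned reads; the in-tree helper reinterpret_casts instead).
template <typename T>
T deserialize_read_pod(const std::string& data, std::size_t offset = 0) {
    T ret;
    std::memcpy(&ret, data.data() + offset, sizeof(T));
    return ret;
}

int main() {
    TileParam p{128, 128, 32, 2};
    std::string param;
    serialize_write_pod(p, param);  // what a param() override returns
    TileParam q = deserialize_read_pod<TileParam>(param);
    return (q.threadblock_m == p.threadblock_m && q.stage == p.stage) ? 0 : 1;
}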
| class ConvBiasForwardImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| class ConvBiasForwardImpl::AlgoPack : NonCopyableObj { | |||
| private: | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| std::vector<AlgoBase*> all_algos, | |||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | |||
| non_cudnn_algos, | |||
| bfloat16_algos; | |||
| non_cudnn_algos, bfloat16_algos; | |||
| std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations; | |||
| std::vector<AlgoCUDNNConv> cudnn_convs; | |||
| AlgoChanwise chanwise; | |||
| @@ -646,6 +773,8 @@ public: | |||
| AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo); | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| private: | |||
| #if CUDA_VERSION >= 10000 | |||
| void fill_imma_algos(); | |||
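`AlgoPack` now inherits from `NonCopyableObj` (pulled in from `src/common/metahelper.h`) instead of spelling out the deleted copy constructor and copy assignment itself; every other `AlgoPack` below gets the same treatment. The base class is not shown in this diff, so the following is only an assumed shape of it:

// Assumed shape of NonCopyableObj; the real definition lives in
// src/common/metahelper.h and may differ. Privately inheriting it deletes
// copy construction and copy assignment for the derived class.
class NonCopyableObj {
public:
    NonCopyableObj() = default;
    NonCopyableObj(const NonCopyableObj&) = delete;
    NonCopyableObj& operator=(const NonCopyableObj&) = delete;
};

// Illustrative user, mirroring "class AlgoPack : NonCopyableObj".
class AlgoPackLike : NonCopyableObj {
public:
    AlgoPackLike() = default;
};

int main() {
    AlgoPackLike pack;
    // AlgoPackLike copy = pack;  // would not compile: copy ctor is deleted
    (void)pack;
    return 0;
}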
| @@ -47,7 +47,7 @@ ConvBiasForwardImpl::AlgoBFloat16::float_args( | |||
| change_dtype(fdst); | |||
| opr->param() = args.opr->param(); | |||
| opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
| opr->execution_policy() = {m_impl}; | |||
| opr->execution_policy() = {m_impl->info()}; | |||
| return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst); | |||
| } | |||
| @@ -110,7 +110,7 @@ void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { | |||
| auto convbias_opr = args.handle->create_operator<ConvBias>(); | |||
| convbias_opr->param() = args.opr->param(); | |||
| convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
| convbias_opr->execution_policy() = {m_impl}; | |||
| convbias_opr->execution_policy() = {m_impl->info()}; | |||
| convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, | |||
| fdst_tensor, nullptr, cvter.workspace()); | |||
| } | |||
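Both BFloat16 call sites above now hand the inner operator `{m_impl->info()}` instead of the raw `Algorithm*` `{m_impl}`: the execution policy becomes plain, copyable data that the inner operator can match against its own algorithm list. A self-contained sketch of that idea with stand-in types (this is not the megdnn `ExecutionPolicy` API):

#include <string>
#include <vector>

// Stand-in for Algorithm::Info: comparable, copyable plain data.
struct Info {
    std::string name;
    std::string param;
    bool operator==(const Info& rhs) const {
        return name == rhs.name && param == rhs.param;
    }
};

struct AlgoLike {
    virtual ~AlgoLike() = default;
    virtual Info info() const = 0;
};

// The policy holds Info rather than an AlgoLike pointer.
struct ExecutionPolicy {
    Info algo;
};

// The operator owning the algorithms resolves the policy against them.
const AlgoLike* resolve(const std::vector<const AlgoLike*>& all_algos,
                        const ExecutionPolicy& policy) {
    for (const AlgoLike* a : all_algos)
        if (a->info() == policy.algo)
            return a;
    return nullptr;  // caller would fall back to the heuristic path
}

struct Matmul final : AlgoLike {
    Info info() const override { return {"MATMUL", ""}; }
};

int main() {
    Matmul matmul;
    std::vector<const AlgoLike*> algos{&matmul};
    ExecutionPolicy policy{matmul.info()};  // analogous to {m_impl->info()}
    return resolve(algos, policy) == &matmul ? 0 : 1;
}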
| @@ -63,12 +63,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||
| auto conv_args = args; | |||
| auto cudnn_conv_bias_act_from_enum_wrapper = | |||
| [this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||
| [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||
| return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo); | |||
| }; | |||
| auto cudnn_conv_from_enum_wrapper = | |||
| [this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||
| [](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||
| return sm_algo_pack.cudnn_conv_from_enum(algo); | |||
| }; | |||
| @@ -24,17 +24,6 @@ public: | |||
| _megdnn_tensor_out dst, | |||
| const PreprocessedFilter* preprocessed_filter, | |||
| _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& bias, | |||
| const TensorLayout& z, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&, | |||
| @@ -80,6 +69,20 @@ public: | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& bias, | |||
| const TensorLayout& z, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
| @@ -52,8 +52,14 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(bfloat16_refhold.back().get()); | |||
| bfloat16_algos.push_back(bfloat16_refhold.back().get()); | |||
| } | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) | |||
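Each `AlgoPack` constructor in this diff now ends with the same loop, registering every algorithm in `m_all_algos_map` under its `info().desc`, and `MEGDNN_DEF_GET_ALGO_FROM_DESC(...)` then defines the operator's `get_algo_from_desc`. The macro body is defined elsewhere, so the sketch below only approximates the register-then-look-up pattern, with stand-in types for `AlgorithmDesc` and `AlgoBase`:

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-ins for AlgorithmDesc and AlgoBase; the real desc also carries the
// handle type, omitted here for brevity.
struct Desc {
    uint32_t type;
    std::string param;
    bool operator==(const Desc& rhs) const {
        return type == rhs.type && param == rhs.param;
    }
};

struct DescHash {
    std::size_t operator()(const Desc& d) const {
        return std::hash<uint32_t>()(d.type) * 31 +
               std::hash<std::string>()(d.param);
    }
};

struct Algo {
    Desc desc;
    const char* name;
    Desc info_desc() const { return desc; }  // stands in for info().desc
};

struct AlgoPack {
    std::vector<Algo> all_algos;
    std::unordered_map<Desc, Algo*, DescHash> all_algos_map;

    AlgoPack() {
        all_algos.push_back({{0, ""}, "MATMUL"});
        all_algos.push_back({{1, ""}, "CHANNEL_WISE"});
        for (auto& algo : all_algos)  // mirrors the emplace loop in the diff
            all_algos_map.emplace(algo.info_desc(), &algo);
    }

    // Roughly what MEGDNN_DEF_GET_ALGO_FROM_DESC is assumed to generate.
    Algo* get_algo_from_desc(const Desc& desc) {
        auto it = all_algos_map.find(desc);
        return it == all_algos_map.end() ? nullptr : it->second;
    }
};

int main() {
    AlgoPack pack;
    return pack.get_algo_from_desc({0, ""}) != nullptr ? 0 : 1;
}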
| ConvolutionBackwardDataImpl::AlgoCUDNN* | |||
| ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum( | |||
| cudnnConvolutionBwdDataAlgo_t algo) { | |||
| @@ -11,8 +11,11 @@ | |||
| #pragma once | |||
| #include "src/cuda/convolution/helper.h" | |||
| #include <unordered_map> | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/cuda/convolution/helper.h" | |||
| #include "src/cuda/cudnn_wrapper.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -23,154 +26,146 @@ namespace cuda { | |||
| * All the algo impls should try to support non-contiguous batch dim, for group | |||
| * conv execution. | |||
| */ | |||
| class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| HandleImpl *handle; | |||
| CanonizedFilterMeta filter_meta; | |||
| const TensorLayout *diff_layout, *grad_layout, *filter_layout; | |||
| ConvolutionBackwardDataImpl *opr; | |||
| std::string to_string() const; | |||
| void init_desc(convolution::CUDNNBwdDataDescs &desc) const { | |||
| desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||
| } | |||
| SizeArgs(ConvolutionBackwardDataImpl* opr, | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad); | |||
| SizeArgs(ConvolutionBackwardDataImpl* opr, | |||
| const TensorLayout& filter, | |||
| const CanonizedFilterMeta& filter_meta, | |||
| const TensorLayout& diff, const TensorLayout& grad); | |||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||
| return {handle, grad_layout, filter_layout, filter_meta, | |||
| diff_layout}; | |||
| } | |||
| }; | |||
| struct ExecArgs: public SizeArgs { | |||
| const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||
| Workspace workspace; | |||
| ExecArgs(ConvolutionBackwardDataImpl *opr, | |||
| _megdnn_tensor_in filter, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs &args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||
| virtual void exec(const ExecArgs &args) const = 0; | |||
| bool is_available_wk(const SizeArgs &args, size_t limit) { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace( | |||
| const SizeArgs &args, const Workspace &workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert(req <= workspace.size, | |||
| "conv bwd data algo %s: " | |||
| "required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_CUDNN, | |||
| CUDA_MATMUL, | |||
| CUDA_CHANWISE, | |||
| CUDA_CHANWISE_SMALL, | |||
| CUDA_BFLOAT16, | |||
| CUDA_GROUP_CONV_GENERAL, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| HandleImpl* handle; | |||
| CanonizedFilterMeta filter_meta; | |||
| const TensorLayout *diff_layout, *grad_layout, *filter_layout; | |||
| ConvolutionBackwardDataImpl* opr; | |||
| std::string to_string() const; | |||
| void init_desc(convolution::CUDNNBwdDataDescs& desc) const { | |||
| desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||
| } | |||
| virtual bool is_cudnn() const { | |||
| return false; | |||
| SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, | |||
| const TensorLayout& diff, const TensorLayout& grad); | |||
| SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, | |||
| const CanonizedFilterMeta& filter_meta, | |||
| const TensorLayout& diff, const TensorLayout& grad); | |||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||
| return {handle, grad_layout, filter_layout, filter_meta, | |||
| diff_layout}; | |||
| } | |||
| }; | |||
| struct ExecArgs : public SizeArgs { | |||
| const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||
| Workspace workspace; | |||
| ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
| virtual void exec(const ExecArgs& args) const = 0; | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace(const SizeArgs& args, | |||
| const Workspace& workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert(req <= workspace.size, | |||
| "conv bwd data algo %s: " | |||
| "required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| virtual bool is_cudnn() const { return false; } | |||
| }; | |||
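The `AlgoBase` helpers reproduced above encode the selection rule used by the heuristics: an algorithm qualifies only if it is available for the given args, fits the workspace limit, and, when the caller demands reproducibility, is itself reproducible. A stripped-down sketch of that filter (the `Candidate` struct and `pick` are illustrative):

#include <cstddef>
#include <limits>
#include <vector>

// Illustrative candidate, standing in for an AlgoBase evaluated on SizeArgs.
struct Candidate {
    bool available;
    bool reproducible;
    std::size_t workspace;
};

// Same predicate as AlgoBase::is_available_reproducible / is_available_wk.
bool is_available_reproducible(
        const Candidate& c, bool reproducible = true,
        std::size_t limit = std::numeric_limits<std::size_t>::max()) {
    return (!reproducible || c.reproducible) && c.available &&
           c.workspace <= limit;
}

// Heuristic-style scan: the first qualifying candidate wins.
int pick(const std::vector<Candidate>& cands, bool reproducible,
         std::size_t limit) {
    for (std::size_t i = 0; i < cands.size(); ++i)
        if (is_available_reproducible(cands[i], reproducible, limit))
            return static_cast<int>(i);
    return -1;
}

int main() {
    std::vector<Candidate> cands{{true, false, 1u << 20}, {true, true, 1u << 10}};
    // The reproducibility requirement and the 64 KiB limit both rule out the
    // first candidate, so the second one is chosen.
    return pick(cands, /*reproducible=*/true, /*limit=*/1u << 16) == 1 ? 0 : 1;
}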
| class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase { | |||
| bool m_is_reproducible; | |||
| const char *m_name; | |||
| cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; | |||
| CudnnAlgoPack::Attr m_attr; | |||
| public: | |||
| public: | |||
| AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) | |||
| : m_cudnn_enum(cudnn_enum) { | |||
| megdnn_assert(CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) != | |||
| CudnnAlgoPack::conv_bwd_data_algos().end()); | |||
| m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum); | |||
| } | |||
| AlgoCUDNN(bool is_reproducible, const char *name, | |||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum): | |||
| m_is_reproducible(is_reproducible), | |||
| m_name(name), | |||
| m_cudnn_enum(cudnn_enum) | |||
| {} | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_available(const SizeArgs &args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
| void exec(const ExecArgs &args) const override; | |||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
| bool is_reproducible() const override { | |||
| return m_is_reproducible; | |||
| } | |||
| const char* name() const override { return m_attr.name.c_str(); } | |||
| const char* name() const override { | |||
| return m_name; | |||
| } | |||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { | |||
| return m_cudnn_enum; | |||
| } | |||
| bool is_cudnn() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
| bool is_cudnn() const override { | |||
| return true; | |||
| } | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_cudnn_enum, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
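`AlgoCUDNN` no longer takes its name and reproducibility flag as constructor arguments; it keeps only the cuDNN enum and looks the rest up in `CudnnAlgoPack::conv_bwd_data_algos()`. That table lives outside this diff, so its shape below is an assumption, seeded with the reproducibility flags visible in the removed `DEF_ALGO` list further down (ALGO_0 non-reproducible, the others reproducible):

#include <map>
#include <string>

// Illustrative enum; the real code keys on cudnnConvolutionBwdDataAlgo_t.
enum class BwdDataAlgo { ALGO_0, ALGO_1, FFT, WINOGRAD };

// Assumed shape of CudnnAlgoPack::Attr: one record per backend enum.
struct Attr {
    std::string name;
    bool is_reproducible;
};

const std::map<BwdDataAlgo, Attr>& conv_bwd_data_algos() {
    static const std::map<BwdDataAlgo, Attr> table{
            {BwdDataAlgo::ALGO_0, {"CUDNN_CONVOLUTION_BWD_DATA_ALGO_0", false}},
            {BwdDataAlgo::ALGO_1, {"CUDNN_CONVOLUTION_BWD_DATA_ALGO_1", true}},
            {BwdDataAlgo::FFT, {"CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT", true}},
            {BwdDataAlgo::WINOGRAD,
             {"CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD", true}},
    };
    return table;
}

// Mirrors the new AlgoCUDNN constructor: keep the enum, fetch the Attr once.
class AlgoCUDNNLike {
public:
    explicit AlgoCUDNNLike(BwdDataAlgo e)
            : m_enum(e), m_attr(conv_bwd_data_algos().at(e)) {}
    const char* name() const { return m_attr.name.c_str(); }
    bool is_reproducible() const { return m_attr.is_reproducible; }

private:
    BwdDataAlgo m_enum;
    Attr m_attr;
};

int main() {
    AlgoCUDNNLike algo(BwdDataAlgo::ALGO_1);
    return algo.is_reproducible() ? 0 : 1;
}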
| //! im2col and matmul, with dilation | |||
| class ConvolutionBackwardDataImpl::AlgoMatmul final: public AlgoBase { | |||
| template<typename T> | |||
| static void exec_internal(const ExecArgs &args); | |||
| class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase { | |||
| template <typename T> | |||
| static void exec_internal(const ExecArgs& args); | |||
| public: | |||
| bool is_available(const SizeArgs &args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
| void exec(const ExecArgs &args) const override; | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { | |||
| return "MATMUL"; | |||
| } | |||
| bool is_reproducible() const override { | |||
| return true; | |||
| } | |||
| const char* name() const override { return "MATMUL"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoChanwise final: public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs &args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
| void exec(const ExecArgs &args) const override; | |||
| class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { | |||
| return "CHANNEL_WISE"; | |||
| } | |||
| bool is_reproducible() const override { | |||
| return true; | |||
| } | |||
| const char* name() const override { return "CHANNEL_WISE"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs &args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
| void exec(const ExecArgs &args) const override; | |||
| class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { | |||
| return "CHANNEL_WISE_SMALL"; | |||
| } | |||
| bool is_reproducible() const override { | |||
| return true; | |||
| } | |||
| const char* name() const override { return "CHANNEL_WISE_SMALL"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { | |||
| @@ -190,61 +185,72 @@ private: | |||
| TensorLayout& fsrc, TensorLayout& ffilter, | |||
| TensorLayout& fdst) const; | |||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_algorithm, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
| //! implement group conv by another algo | |||
| class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
| AlgoBase *m_impl; | |||
| class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final | |||
| : public AlgoBase { | |||
| AlgoBase* m_impl; | |||
| std::string m_name; | |||
| public: | |||
| AlgoGroupConvGeneral(AlgoBase *impl); | |||
| public: | |||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||
| bool is_available(const SizeArgs &args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
| void exec(const ExecArgs &args) const override; | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { | |||
| return m_name.c_str(); | |||
| } | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { | |||
| return m_impl->is_reproducible(); | |||
| } | |||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
| static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | |||
| TensorLayout& grad_pg); | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
| static void modify_size_args(SizeArgs &args, | |||
| TensorLayout &diff_pg, TensorLayout &grad_pg); | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_impl, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoPack { | |||
| class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
| // defined in cudnn.cpp | |||
| void fill_cudnn_algos(); | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator = (const AlgoPack &) = delete; | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| public: | |||
| AlgoPack(); | |||
| std::vector<AlgoCUDNN> cudnn; | |||
| AlgoMatmul matmul; | |||
| AlgoChanwise chanwise; | |||
| AlgoChanwiseSmall chanwise_small; | |||
| std::vector<AlgoGroupConvGeneral> gconv; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
| std::vector<AlgoCUDNN> cudnn; | |||
| AlgoMatmul matmul; | |||
| AlgoChanwise chanwise; | |||
| AlgoChanwiseSmall chanwise_small; | |||
| std::vector<AlgoGroupConvGeneral> gconv; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
| std::vector<AlgoBase*> | |||
| std::vector<AlgoBase*> | |||
| //! all algorithms | |||
| all_algos, | |||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | |||
| non_cudnn_algos, | |||
| bfloat16_algos; | |||
| non_cudnn_algos, bfloat16_algos; | |||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -42,7 +42,7 @@ ConvolutionBackwardDataImpl::AlgoBFloat16::float_args( | |||
| change_dtype(fgrad); | |||
| opr->param() = args.opr->param(); | |||
| opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
| opr->execution_policy() = {m_algorithm}; | |||
| opr->execution_policy() = {m_algorithm->info()}; | |||
| return SizeArgs(opr, ffilter, fdiff, fgrad); | |||
| } | |||
| @@ -105,7 +105,7 @@ void ConvolutionBackwardDataImpl::AlgoBFloat16::exec( | |||
| args.handle->create_operator<ConvolutionBackwardData>(); | |||
| conv_back_data_opr->param() = args.opr->param(); | |||
| conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
| conv_back_data_opr->execution_policy() = {m_algorithm}; | |||
| conv_back_data_opr->execution_policy() = {m_algorithm->info()}; | |||
| conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor, | |||
| cvter.workspace()); | |||
| } | |||
| @@ -98,35 +98,9 @@ void ConvolutionBackwardDataImpl::AlgoCUDNN::exec( | |||
| } | |||
| void ConvolutionBackwardDataImpl::AlgoPack::fill_cudnn_algos() { | |||
| #define V1(v) #v | |||
| #define V(v) V1(v) | |||
| #define DEF_ALGO(NAME, REPROD) \ | |||
| cudnn.push_back({ \ | |||
| REPROD, #NAME \ | |||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||
| "." V(CUDNN_PATCHLEVEL), \ | |||
| NAME}) | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true); | |||
| #if CUDNN_MAJOR >= 5 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true); | |||
| #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true); | |||
| #endif | |||
| #endif | |||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
| #pragma message "not latest cudnn" | |||
| #endif | |||
| #undef DEF_ALGO | |||
| #undef V | |||
| #undef V1 | |||
| for (auto&& algo : CudnnAlgoPack::conv_bwd_data_algos()) { | |||
| cudnn.push_back(algo.first); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -49,8 +49,14 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(bfloat16_refhold.back().get()); | |||
| bfloat16_algos.push_back(bfloat16_refhold.back().get()); | |||
| } | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardFilterImpl) | |||
| ConvolutionBackwardFilterImpl::AlgoCUDNN* | |||
| ConvolutionBackwardFilterImpl::AlgoPack::cudnn_from_enum( | |||
| cudnnConvolutionBwdFilterAlgo_t algo) { | |||
| @@ -6,13 +6,16 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/cuda/convolution/helper.h" | |||
| #include <unordered_map> | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/cuda/convolution/helper.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -23,141 +26,134 @@ namespace cuda { | |||
| * All the algo impls should try to support non-contiguous batch dim, for group | |||
| * conv execution. | |||
| */ | |||
| class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| HandleImpl *handle; | |||
| const TensorLayout *src_layout, *diff_layout, *grad_layout; | |||
| CanonizedFilterMeta grad_filter_meta; | |||
| ConvolutionBackwardFilterImpl *opr; | |||
| std::string to_string() const; | |||
| void init_desc(convolution::CUDNNBwdFilterDescs &desc) const { | |||
| desc.set(*src_layout, *diff_layout, grad_filter_meta, | |||
| opr->param()); | |||
| } | |||
| SizeArgs(ConvolutionBackwardFilterImpl *opr, | |||
| const TensorLayout &src, const TensorLayout &diff, | |||
| const TensorLayout &grad); | |||
| SizeArgs(ConvolutionBackwardFilterImpl* opr, | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| const CanonizedFilterMeta& grad_meta); | |||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||
| return {handle, src_layout, grad_layout, grad_filter_meta, | |||
| diff_layout}; | |||
| } | |||
| }; | |||
| struct ExecArgs: public SizeArgs { | |||
| const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||
| Workspace workspace; | |||
| ExecArgs(ConvolutionBackwardFilterImpl *opr, | |||
| _megdnn_tensor_in src, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs &args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||
| virtual void exec(const ExecArgs &args) const = 0; | |||
| bool is_available_wk(const SizeArgs &args, size_t limit) { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| class ConvolutionBackwardFilterImpl::AlgoBase : public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_CUDNN, | |||
| CUDA_MATMUL, | |||
| CUDA_CHANWISE, | |||
| CUDA_BFLOAT16, | |||
| CUDA_GROUP_CONV_GENERAL, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| HandleImpl* handle; | |||
| const TensorLayout *src_layout, *diff_layout, *grad_layout; | |||
| CanonizedFilterMeta grad_filter_meta; | |||
| ConvolutionBackwardFilterImpl* opr; | |||
| std::string to_string() const; | |||
| void init_desc(convolution::CUDNNBwdFilterDescs& desc) const { | |||
| desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); | |||
| } | |||
| AlgoBase& check_workspace( | |||
| const SizeArgs &args, const Workspace &workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert(req <= workspace.size, | |||
| "conv bwd filter algo %s: " | |||
| "required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| virtual bool is_cudnn() const { | |||
| return false; | |||
| SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, | |||
| const TensorLayout& diff, const TensorLayout& grad); | |||
| SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| const CanonizedFilterMeta& grad_meta); | |||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||
| return {handle, src_layout, grad_layout, grad_filter_meta, | |||
| diff_layout}; | |||
| } | |||
| }; | |||
| struct ExecArgs : public SizeArgs { | |||
| const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||
| Workspace workspace; | |||
| ExecArgs(ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
| virtual void exec(const ExecArgs& args) const = 0; | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace(const SizeArgs& args, | |||
| const Workspace& workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert(req <= workspace.size, | |||
| "conv bwd filter algo %s: " | |||
| "required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| virtual bool is_cudnn() const { return false; } | |||
| }; | |||
| class ConvolutionBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { | |||
| bool m_is_reproducible; | |||
| const char *m_name; | |||
| cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; | |||
| CudnnAlgoPack::Attr m_attr; | |||
| public: | |||
| public: | |||
| AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) | |||
| : m_cudnn_enum(cudnn_enum) { | |||
| megdnn_assert(CudnnAlgoPack::conv_bwd_flt_algos().find(cudnn_enum) != | |||
| CudnnAlgoPack::conv_bwd_flt_algos().end()); | |||
| m_attr = CudnnAlgoPack::conv_bwd_flt_algos().at(cudnn_enum); | |||
| } | |||
| AlgoCUDNN(bool is_reproducible, const char *name, | |||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum): | |||
| m_is_reproducible(is_reproducible), | |||
| m_name(name), | |||
| m_cudnn_enum(cudnn_enum) | |||
| {} | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_available(const SizeArgs &args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
| void exec(const ExecArgs &args) const override; | |||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
| bool is_reproducible() const override { | |||
| return m_is_reproducible; | |||
| } | |||
| const char* name() const override { return m_attr.name.c_str(); } | |||
| const char* name() const override { | |||
| return m_name; | |||
| } | |||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { | |||
| return m_cudnn_enum; | |||
| } | |||
| bool is_cudnn() const override { return true; } | |||
| bool is_cudnn() const override { | |||
| return true; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_cudnn_enum, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
| //! im2col and matmul, with dilation | |||
| class ConvolutionBackwardFilterImpl::AlgoMatmul final: public AlgoBase { | |||
| template<typename T> | |||
| static void exec_internal(const ExecArgs &args); | |||
| class ConvolutionBackwardFilterImpl::AlgoMatmul final : public AlgoBase { | |||
| template <typename T> | |||
| static void exec_internal(const ExecArgs& args); | |||
| public: | |||
| bool is_available(const SizeArgs &args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
| void exec(const ExecArgs &args) const override; | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { | |||
| return "MATMUL"; | |||
| } | |||
| bool is_reproducible() const override { | |||
| return true; | |||
| } | |||
| const char* name() const override { return "MATMUL"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
| }; | |||
| class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs &args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
| void exec(const ExecArgs &args) const override; | |||
| class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { | |||
| return "CHANNEL_WISE"; | |||
| } | |||
| bool is_reproducible() const override { | |||
| return true; | |||
| } | |||
| const char* name() const override { return "CHANNEL_WISE"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
| }; | |||
| class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { | |||
| @@ -169,6 +165,13 @@ public: | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_algorithm, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| std::string m_name; | |||
| @@ -180,57 +183,62 @@ private: | |||
| }; | |||
| //! implement group conv by another algo | |||
| class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
| AlgoBase *m_impl; | |||
| class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final | |||
| : public AlgoBase { | |||
| AlgoBase* m_impl; | |||
| std::string m_name; | |||
| public: | |||
| AlgoGroupConvGeneral(AlgoBase *impl); | |||
| public: | |||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_available(const SizeArgs &args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
| void exec(const ExecArgs &args) const override; | |||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
| const char* name() const override { | |||
| return m_name.c_str(); | |||
| } | |||
| static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||
| TensorLayout& diff_pg); | |||
| bool is_reproducible() const override { | |||
| return m_impl->is_reproducible(); | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
| static void modify_size_args(SizeArgs &args, | |||
| TensorLayout &src_pg, TensorLayout &diff_pg); | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_impl, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
| class ConvolutionBackwardFilterImpl::AlgoPack { | |||
| class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||
| // defined in cudnn.cpp | |||
| void fill_cudnn_algos(); | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator = (const AlgoPack &) = delete; | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| public: | |||
| AlgoPack(); | |||
| std::vector<AlgoCUDNN> cudnn; | |||
| AlgoMatmul matmul; | |||
| AlgoChanwise chanwise; | |||
| std::vector<AlgoGroupConvGeneral> gconv; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
| std::vector<AlgoCUDNN> cudnn; | |||
| AlgoMatmul matmul; | |||
| AlgoChanwise chanwise; | |||
| std::vector<AlgoGroupConvGeneral> gconv; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
| std::vector<AlgoBase*> | |||
| std::vector<AlgoBase*> | |||
| //! all algorithms | |||
| all_algos, | |||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | |||
| non_cudnn_algos, | |||
| bfloat16_algos; | |||
| non_cudnn_algos, bfloat16_algos; | |||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -42,7 +42,7 @@ ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args( | |||
| change_dtype(fgrad); | |||
| opr->param() = args.opr->param(); | |||
| opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
| opr->execution_policy() = {m_algorithm}; | |||
| opr->execution_policy() = {m_algorithm->info()}; | |||
| return SizeArgs(opr, fsrc, fdiff, fgrad); | |||
| } | |||
| @@ -107,7 +107,7 @@ void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec( | |||
| conv_back_filter_opr->param() = args.opr->param(); | |||
| conv_back_filter_opr->param().compute_mode = | |||
| Param::ComputeMode::DEFAULT; | |||
| conv_back_filter_opr->execution_policy() = {m_algorithm}; | |||
| conv_back_filter_opr->execution_policy() = {m_algorithm->info()}; | |||
| conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, | |||
| cvter.workspace()); | |||
| } | |||
| @@ -80,35 +80,9 @@ void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec( | |||
| } | |||
| void ConvolutionBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { | |||
| #define V1(v) #v | |||
| #define V(v) V1(v) | |||
| #define DEF_ALGO(NAME, REPROD) \ | |||
| cudnn.push_back({ \ | |||
| REPROD, #NAME \ | |||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||
| "." V(CUDNN_PATCHLEVEL), \ | |||
| NAME}) | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true); | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false); | |||
| #if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, true); | |||
| #if CUDNN_MAJOR >= 6 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true); | |||
| #endif | |||
| #endif | |||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
| #pragma message "not latest cudnn" | |||
| #endif | |||
| #undef DEF_ALGO | |||
| #undef V | |||
| #undef V1 | |||
| for(auto&& algo : CudnnAlgoPack::conv_bwd_flt_algos()) { | |||
| cudnn.push_back(algo.first); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -70,7 +70,7 @@ ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src, | |||
| conv_param.dilate_w, | |||
| 0, | |||
| conv_param.compute_mode}; | |||
| ret.convbias_opr->execution_policy() = {this->execution_policy().algorithm}; | |||
| ret.convbias_opr->execution_policy() = {this->execution_policy().algo}; | |||
| return ret; | |||
| } | |||
| @@ -183,15 +183,6 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||
| CUDNNBwdDataDescs desc; | |||
| args.init_desc(desc); | |||
| // disabled: segfault in megbrain, needs further investigation. | |||
| #if 0 | |||
| bool is_heuristic_success= convolution:: | |||
| PerformanceModelBackwardData::get_algo_backward_data_success( | |||
| args, desc, workspace_limit_in_bytes, &algo); | |||
| if (is_heuristic_success) { | |||
| return sm_algo_pack.cudnn_from_enum(algo); | |||
| } | |||
| #endif | |||
| #if CUDNN_MAJOR >= 7 | |||
| int max_count = 0; | |||
| cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | |||
| @@ -24,14 +24,6 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||
| const PreprocessedFilter* preprocessed_filter, | |||
| _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
| const TensorLayout &filter, | |||
| const TensorLayout &dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst, | |||
| @@ -60,99 +52,129 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||
| TensorLayout bias_layout; | |||
| TensorLayout z_layout; | |||
| }; | |||
| private: | |||
| ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&); | |||
| }; | |||
| class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||
| public: | |||
| using ConvolutionBackwardData::ConvolutionBackwardData; | |||
| void exec(_megdnn_tensor_in filter, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& filter, | |||
| const CanonizedFilterMeta& filter_meta, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, bool reproducible); | |||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| class AlgoCUDNN; | |||
| class AlgoMatmul; | |||
| class AlgoChanwise; | |||
| class AlgoChanwiseSmall; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoBFloat16; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { | |||
| return sm_algo_pack; | |||
| } | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&); | |||
| }; | |||
| class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
| public: | |||
| using ConvolutionBackwardFilter::ConvolutionBackwardFilter; | |||
| void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& gradk, | |||
| const CanonizedFilterMeta& grad_meta, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible); | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| class AlgoCUDNN; | |||
| class AlgoMatmul; | |||
| class AlgoChanwise; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoBFloat16; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { | |||
| return sm_algo_pack; | |||
| } | |||
| class ConvolutionBackwardDataImpl : public ConvolutionBackwardData { | |||
| public: | |||
| using ConvolutionBackwardData::ConvolutionBackwardData; | |||
| void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||
| return get_algorithm_heuristic(filter, filter_meta, diff, grad, | |||
| workspace_limit_in_bytes, reproducible) | |||
| ->info(); | |||
| } | |||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| class AlgoCUDNN; | |||
| class AlgoMatmul; | |||
| class AlgoChanwise; | |||
| class AlgoChanwiseSmall; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoBFloat16; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
| const CanonizedFilterMeta& filter_meta, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible); | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter { | |||
| public: | |||
| using ConvolutionBackwardFilter::ConvolutionBackwardFilter; | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| const CanonizedFilterMeta& grad_meta, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) { | |||
| return get_algorithm_heuristic(src, diff, grad, grad_meta, | |||
| workspace_limit_in_bytes, reproducible) | |||
| ->info(); | |||
| } | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| class AlgoCUDNN; | |||
| class AlgoMatmul; | |||
| class AlgoChanwise; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoBFloat16; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| const CanonizedFilterMeta& grad_meta, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible); | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
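In the `opr_impl.h` changes above, the `Algorithm*`-returning `get_all_algorithms` / `get_algorithm_heuristic` overloads move behind `protected:`, while the new public `get_algorithm_info_heuristic` wrappers return `->info()`. A minimal sketch of that interface split with toy types (not the megdnn declarations): callers receive copyable info, and the concrete algorithm pointer never leaves the operator.

#include <cstddef>
#include <string>

// Stand-in for Algorithm::Info.
struct AlgorithmInfo {
    std::string name;
};

class AlgorithmLike {
public:
    virtual ~AlgorithmLike() = default;
    virtual const char* name() const = 0;
    AlgorithmInfo info() const { return {name()}; }
};

class OperatorLike {
public:
    // Public surface: hand out plain data, not a pointer the caller could
    // retain or dereference across handles.
    AlgorithmInfo get_algorithm_info_heuristic(std::size_t workspace_limit,
                                               bool reproducible) {
        return get_algorithm_heuristic(workspace_limit, reproducible)->info();
    }

protected:
    // Selection itself stays an implementation detail.
    virtual AlgorithmLike* get_algorithm_heuristic(std::size_t workspace_limit,
                                                   bool reproducible) = 0;
};

class MatmulOnly final : public OperatorLike {
protected:
    AlgorithmLike* get_algorithm_heuristic(std::size_t, bool) override {
        static struct : AlgorithmLike {
            const char* name() const override { return "MATMUL"; }
        } algo;
        return &algo;
    }
};

int main() {
    MatmulOnly opr;
    return opr.get_algorithm_info_heuristic(0, true).name == "MATMUL" ? 0 : 1;
}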
| @@ -39,8 +39,14 @@ Convolution3DBackwardDataImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&i); | |||
| } | |||
| megdnn_assert(all_algos_data == all_algos.data()); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardDataImpl) | |||
| Convolution3DBackwardDataImpl::AlgoCUDNN* | |||
| Convolution3DBackwardDataImpl::AlgoPack::cudnn_from_enum( | |||
| cudnnConvolutionBwdDataAlgo_t algo) { | |||
| @@ -96,7 +102,7 @@ std::string Convolution3DBackwardDataImpl::AlgoBase::SizeArgs::to_string() const | |||
| fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], | |||
| diff_layout->to_string().c_str(), | |||
| grad_layout->to_string().c_str(), | |||
| fm.padding[0], fm.padding[1], fm.padding[2], | |||
| fm.padding[0], fm.padding[1], fm.padding[2], | |||
| fm.stride[0], fm.stride[1], fm.stride[2], | |||
| fm.dilation[0], fm.dilation[1] ,fm.dilation[2], | |||
| !fm.should_flip, | |||
| @@ -6,13 +6,16 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/cuda/convolution3d/helper.h" | |||
| #include <unordered_map> | |||
| #include "src/cuda/convolution3d/helper.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -23,170 +26,174 @@ namespace cuda { | |||
| * All the algo impls should try to support non-contiguous batch dim, for group | |||
| * conv execution. | |||
| */ | |||
| class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| HandleImpl *handle; | |||
| CanonizedFilterMeta filter_meta; | |||
| const TensorLayout *diff_layout, *grad_layout; | |||
| Convolution3DBackwardDataImpl *opr; | |||
| std::string to_string() const; | |||
| void init_desc(convolution3d::CUDNNBwdDataDescs &desc) const { | |||
| desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||
| } | |||
| SizeArgs(Convolution3DBackwardDataImpl *opr, | |||
| const TensorLayout &filter, const TensorLayout &diff, | |||
| const TensorLayout &grad); | |||
| SizeArgs(Convolution3DBackwardDataImpl *opr, | |||
| const CanonizedFilterMeta &filter, const TensorLayout &diff, | |||
| const TensorLayout &grad); | |||
| convolution3d::ForwardSizeArgs as_fwd_args() const { | |||
| return {handle, grad_layout, filter_meta, diff_layout, | |||
| opr->param().data_type}; | |||
| } | |||
| }; | |||
| struct ExecArgs: public SizeArgs { | |||
| const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||
| Workspace workspace; | |||
| ExecArgs(Convolution3DBackwardDataImpl *opr, | |||
| _megdnn_tensor_in filter, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs &args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||
| virtual void exec(const ExecArgs &args) const = 0; | |||
| bool is_available_wk(const SizeArgs &args, size_t limit) { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace( | |||
| const SizeArgs &args, const Workspace &workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert(req <= workspace.size, | |||
| "conv bwd data algo %s: " | |||
| "required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| class Convolution3DBackwardDataImpl::AlgoBase : public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_GROUP_CONV_GENERAL, | |||
| CUDA_CUDNN, | |||
| CUDA_CHANWISE, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| HandleImpl* handle; | |||
| CanonizedFilterMeta filter_meta; | |||
| const TensorLayout *diff_layout, *grad_layout; | |||
| Convolution3DBackwardDataImpl* opr; | |||
| std::string to_string() const; | |||
| void init_desc(convolution3d::CUDNNBwdDataDescs& desc) const { | |||
| desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||
| } | |||
| virtual bool is_cudnn() const { | |||
| return false; | |||
| SizeArgs(Convolution3DBackwardDataImpl* opr, const TensorLayout& filter, | |||
| const TensorLayout& diff, const TensorLayout& grad); | |||
| SizeArgs(Convolution3DBackwardDataImpl* opr, | |||
| const CanonizedFilterMeta& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad); | |||
| convolution3d::ForwardSizeArgs as_fwd_args() const { | |||
| return {handle, grad_layout, filter_meta, diff_layout, | |||
| opr->param().data_type}; | |||
| } | |||
| }; | |||
| struct ExecArgs : public SizeArgs { | |||
| const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||
| Workspace workspace; | |||
| ExecArgs(Convolution3DBackwardDataImpl* opr, _megdnn_tensor_in filter, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
| virtual void exec(const ExecArgs& args) const = 0; | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace(const SizeArgs& args, | |||
| const Workspace& workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert(req <= workspace.size, | |||
| "conv bwd data algo %s: " | |||
| "required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| virtual bool is_cudnn() const { return false; } | |||
| }; | |||
| class Convolution3DBackwardDataImpl::AlgoCUDNN final : public AlgoBase { | |||
| cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; | |||
| CudnnAlgoPack::Attr m_attr; | |||
| public: | |||
| AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) | |||
| : m_cudnn_enum(cudnn_enum) { | |||
| megdnn_assert(CudnnAlgoPack::conv3d_bwd_data_algos().find(cudnn_enum) != | |||
| CudnnAlgoPack::conv3d_bwd_data_algos().end()); | |||
| m_attr = CudnnAlgoPack::conv3d_bwd_data_algos().at(cudnn_enum); | |||
| } | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
| const char* name() const override { return m_attr.name.c_str(); } | |||
| cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
| bool is_cudnn() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_cudnn_enum, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
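| // Editorial sketch (not part of this diff): param() above stores the raw enum | |||
| // bytes via serialize_write_pod, so a stored Desc::param string can be decoded | |||
| // again with a plain POD read, e.g.: | |||
| //   cudnnConvolutionBwdDataAlgo_t e; | |||
| //   memcpy(&e, desc.param.data(), sizeof(e));  // assumes param holds one enum | |||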
| class Convolution3DBackwardDataImpl::AlgoChanwise final : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return "CHANNEL_WISE"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
| }; | |||
| //! implement group conv by another algo | |||
| class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final | |||
| : public AlgoBase { | |||
| AlgoBase* m_impl; | |||
| std::string m_name; | |||
| public: | |||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
| static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | |||
| TensorLayout& grad_pg); | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_impl, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
| class Convolution3DBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
| // defined in cudnn.cpp | |||
| void fill_cudnn_algos(); | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| std::vector<AlgoCUDNN> cudnn; | |||
| AlgoChanwise chanwise; | |||
| std::vector<AlgoGroupConvGeneral> gconv; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| std::vector<AlgoBase*> | |||
| //! all algorithms | |||
| all_algos, | |||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | |||
| non_cudnn_algos; | |||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
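| // Editorial sketch (not part of this diff): the Mapper keyed by info().desc | |||
| // lets a serialized AlgorithmDesc be resolved back to a concrete algorithm, | |||
| // assuming a `desc` of type Algorithm::Info::Desc is in scope: | |||
| //   const auto& m = Convolution3DBackwardDataImpl::algo_pack().all_algos_map(); | |||
| //   auto it = m.find(desc); | |||
| //   AlgoBase* algo = it == m.end() ? nullptr : it->second; | |||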
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -80,27 +80,9 @@ void Convolution3DBackwardDataImpl::AlgoCUDNN::exec( | |||
| } | |||
| void Convolution3DBackwardDataImpl::AlgoPack::fill_cudnn_algos() { | |||
| for (auto&& algo : CudnnAlgoPack::conv3d_bwd_data_algos()) { | |||
| cudnn.push_back(algo.first); | |||
| } | |||
| } | |||
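| // Editorial note: each key of conv3d_bwd_data_algos() is a | |||
| // cudnnConvolutionBwdDataAlgo_t, so cudnn.push_back(algo.first) goes through | |||
| // the AlgoCUDNN(cudnn_enum) constructor above, which also pulls the matching | |||
| // name/reproducibility Attr from the same table. | |||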
| // vim: syntax=cpp.doxygen | |||
| @@ -17,7 +17,7 @@ using namespace cuda; | |||
| Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
| non_cudnn_algos.push_back(&chanwise); | |||
| non_cudnn_algos.push_back(&inplace_matmul); | |||
| all_algos.push_back(&chanwise); // prefer chanwise | |||
| fill_cudnn_algos(); | |||
| @@ -41,8 +41,14 @@ Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
| } | |||
| megdnn_assert(all_algos_data == all_algos.data()); | |||
| non_cudnn_algos.push_back(all_algos.rbegin()[0]); //group inplace_matmul | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardFilterImpl) | |||
| Convolution3DBackwardFilterImpl::AlgoCUDNN* | |||
| Convolution3DBackwardFilterImpl::AlgoPack::cudnn_from_enum( | |||
| cudnnConvolutionBwdFilterAlgo_t algo) { | |||
| @@ -99,9 +105,9 @@ Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const { | |||
| "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", | |||
| src_layout->to_string().c_str(), | |||
| diff_layout->to_string().c_str(), | |||
| fm.group, fm.ocpg, fm.icpg, | |||
| fm.spatial[0], fm.spatial[1], fm.spatial[2], | |||
| fm.padding[0], fm.padding[1], fm.padding[2], | |||
| fm.stride[0], fm.stride[1], fm.stride[2], | |||
| fm.dilation[0], fm.dilation[1], fm.dilation[2], | |||
| !fm.should_flip, | |||
| @@ -6,198 +6,198 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/cuda/convolution3d/helper.h" | |||
| #include <unordered_map> | |||
| #include "src/cuda/convolution3d/helper.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| class Convolution3DBackwardFilterImpl::AlgoBase : public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_GROUP_CONV_GENERAL, | |||
| CUDA_CUDNN, | |||
| CUDA_INPLACE_MATMUL, | |||
| CUDA_CHANWISE, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| struct SizeArgs { | |||
| HandleImpl* handle; | |||
| const TensorLayout *src_layout, *diff_layout; | |||
| CanonizedFilterMeta grad_filter_meta; | |||
| Convolution3DBackwardFilterImpl* opr; | |||
| std::string to_string() const; | |||
| void init_desc(convolution3d::CUDNNBwdFilterDescs& desc) const { | |||
| desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); | |||
| } | |||
| SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, | |||
| const TensorLayout& diff, const TensorLayout& grad); | |||
| SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, | |||
| const TensorLayout& diff, const CanonizedFilterMeta& grad); | |||
| convolution3d::ForwardSizeArgs as_fwd_args() const { | |||
| return {handle, src_layout, grad_filter_meta, diff_layout, | |||
| opr->param().data_type}; | |||
| } | |||
| }; | |||
| struct ExecArgs : public SizeArgs { | |||
| const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||
| Workspace workspace; | |||
| ExecArgs(Convolution3DBackwardFilterImpl* opr, _megdnn_tensor_in src, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
| virtual void exec(const ExecArgs& args) const = 0; | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace(const SizeArgs& args, | |||
| const Workspace& workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert(req <= workspace.size, | |||
| "conv bwd filter algo %s: " | |||
| "required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| virtual bool is_cudnn() const { return false; } | |||
| }; | |||
| class Convolution3DBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { | |||
| cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; | |||
| CudnnAlgoPack::Attr m_attr; | |||
| public: | |||
| AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) | |||
| : m_cudnn_enum(cudnn_enum) { | |||
| megdnn_assert(CudnnAlgoPack::conv3d_bwd_flt_algos().find(cudnn_enum) != | |||
| CudnnAlgoPack::conv3d_bwd_flt_algos().end()); | |||
| m_attr = CudnnAlgoPack::conv3d_bwd_flt_algos().at(cudnn_enum); | |||
| } | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
| const char* name() const override { return m_attr.name.c_str(); } | |||
| cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
| bool is_cudnn() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_cudnn_enum, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
| class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final | |||
| : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return "INPLACE_MATMUL"; } | |||
| bool is_reproducible() const override { return false; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | |||
| }; | |||
| class Convolution3DBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return "CHANNEL_WISE"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
| }; | |||
| //! implement group conv by another algo | |||
| class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final | |||
| : public AlgoBase { | |||
| AlgoBase* m_impl; | |||
| std::string m_name; | |||
| public: | |||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
| static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||
| TensorLayout& diff_pg); | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_impl, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
| class Convolution3DBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||
| // defined in cudnn.cpp | |||
| void fill_cudnn_algos(); | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| std::vector<AlgoCUDNN> cudnn; | |||
| AlgoInplaceMatmul inplace_matmul; | |||
| AlgoChanwise chanwise; | |||
| std::vector<AlgoGroupConvGeneral> gconv; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| std::vector<AlgoBase*> | |||
| //! all algorithms | |||
| all_algos, | |||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | |||
| non_cudnn_algos; | |||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -66,29 +66,9 @@ void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec( | |||
| } | |||
| void Convolution3DBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { | |||
| for (auto&& algo : CudnnAlgoPack::conv3d_bwd_flt_algos()) { | |||
| cudnn.push_back(algo.first); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -21,13 +21,13 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { | |||
| non_cudnn_algos.push_back(&a1x1x1); | |||
| all_algos.push_back(&chanwise); | |||
| fill_cudnn_algos(); | |||
| for (auto &&i: cudnn) { | |||
| all_algos.push_back(&i); | |||
| } | |||
| all_algos.push_back(&inplace_matmul); | |||
| all_algos.push_back(&a1x1x1); | |||
| all_algos.reserve(all_algos.size() * 2); | |||
| // add gconv algos by AlgoGroupConvGeneral | |||
| @@ -42,10 +42,16 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&i); | |||
| } | |||
| megdnn_assert(all_algos_data == all_algos.data()); | |||
| non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul | |||
| non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1x1 | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DForwardImpl) | |||
| Convolution3DForwardImpl::AlgoCUDNN* | |||
| Convolution3DForwardImpl::AlgoPack::cudnn_from_enum( | |||
| cudnnConvolutionFwdAlgo_t algo) { | |||
| @@ -99,7 +105,7 @@ std::string Convolution3DForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
| "src=%s, filter=%u{%u,%u,%u,%u,%u}, dst=%s, " | |||
| "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", | |||
| src_layout->to_string().c_str(), | |||
| fm.group, fm.ocpg, fm.icpg, | |||
| fm.spatial[0], fm.spatial[1], fm.spatial[2], | |||
| dst_layout->to_string().c_str(), | |||
| fm.padding[0], fm.padding[1], fm.padding[2], | |||
| @@ -6,17 +6,20 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/convolution3d/helper.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/convolution3d/opr_impl.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include <unordered_map> | |||
| @@ -29,195 +32,189 @@ namespace cuda { | |||
| * All the algo impls should try to support non-contiguous batch dim, for group | |||
| * conv execution. | |||
| */ | |||
| class Convolution3DForwardImpl::AlgoBase : public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_1X1X1, | |||
| CUDA_GROUP_CONV_GENERAL, | |||
| CUDA_CUDNN, | |||
| CUDA_INPLACE_MATMUL, | |||
| CUDA_CHANWISE, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs : public convolution3d::ForwardSizeArgs { | |||
| Convolution3DForwardImpl* opr; | |||
| std::string to_string() const; | |||
| void init_desc(convolution3d::CUDNNForwardDescs& desc) const { | |||
| desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); | |||
| } | |||
| SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src, | |||
| const TensorLayout& filter, const TensorLayout& dst); | |||
| SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src, | |||
| const CanonizedFilterMeta& filter, const TensorLayout& dst); | |||
| }; | |||
| struct ExecArgs : public SizeArgs { | |||
| const TensorND *src_tensor, *filter_tensor, *dst_tensor; | |||
| Workspace workspace; | |||
| ExecArgs(Convolution3DForwardImpl* opr, _megdnn_tensor_in src, | |||
| _megdnn_tensor_in filter, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
| virtual void exec(const ExecArgs& args) const = 0; | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace(const SizeArgs& args, | |||
| const Workspace& workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert( | |||
| req <= workspace.size, | |||
| "conv3d fwd algo %s: required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| virtual bool is_cudnn() const { return false; } | |||
| }; | |||
| class Convolution3DForwardImpl::Algo1x1x1 final : public AlgoBase { | |||
| static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A, | |||
| TensorLayout& B, TensorLayout& C); | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return "1x1x1"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1) | |||
| }; | |||
| //! implement group conv by another algo | |||
| class Convolution3DForwardImpl::AlgoGroupConvGeneral final : public AlgoBase { | |||
| AlgoBase* m_impl; | |||
| std::string m_name; | |||
| public: | |||
| AlgoGroupConvGeneral(AlgoBase* impl); | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
| static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||
| TensorLayout& dst_pg); | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_impl, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
| class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase { | |||
| cudnnConvolutionFwdAlgo_t m_cudnn_enum; | |||
| CudnnAlgoPack::Attr m_attr; | |||
| public: | |||
| AlgoCUDNN(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { | |||
| megdnn_assert(CudnnAlgoPack::conv3d_fwd_algos().find(cudnn_enum) != | |||
| CudnnAlgoPack::conv3d_fwd_algos().end()); | |||
| m_attr = CudnnAlgoPack::conv3d_fwd_algos().at(cudnn_enum); | |||
| } | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
| const char* name() const override { return m_attr.name.c_str(); } | |||
| cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
| bool is_cudnn() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_cudnn_enum, ret); | |||
| return ret; | |||
| } | |||
| }; | |||
| class Convolution3DForwardImpl::AlgoInplaceMatmul final : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return "INPLACE_MATMUL"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | |||
| }; | |||
| class Convolution3DForwardImpl::AlgoChanwise final : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { return "CHANNEL_WISE"; } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
| }; | |||
| class Convolution3DForwardImpl::AlgoPack : NonCopyableObj { | |||
| // defined in cudnn.cpp | |||
| void fill_cudnn_algos(); | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| std::vector<AlgoCUDNN> cudnn; | |||
| Algo1x1x1 a1x1x1; | |||
| AlgoInplaceMatmul inplace_matmul; | |||
| AlgoChanwise chanwise; | |||
| std::vector<AlgoGroupConvGeneral> gconv; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| std::vector<AlgoBase*> | |||
| //! all algorithms | |||
| all_algos, | |||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | |||
| non_cudnn_algos; | |||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo); | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -78,30 +78,10 @@ void Convolution3DForwardImpl::AlgoCUDNN::exec( | |||
| cudnnGetErrorString(status), args.to_string().c_str()); | |||
| } | |||
| void Convolution3DForwardImpl::AlgoPack::fill_cudnn_algos() { | |||
| for (auto&& algo : CudnnAlgoPack::conv3d_fwd_algos()) { | |||
| cudnn.push_back(algo.first); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| @@ -15,126 +16,155 @@ | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| class Convolution3DForwardImpl : public Convolution3DForward { | |||
| public: | |||
| using Convolution3DForward::Convolution3DForward; | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||
| const CanonizedFilterMeta& filter, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) { | |||
| return get_algorithm_heuristic(src, filter, dst, | |||
| workspace_limit_in_bytes, reproducible) | |||
| ->info(); | |||
| } | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| class AlgoCUDNN; | |||
| class Algo1x1x1; | |||
| class AlgoInplaceMatmul; | |||
| class AlgoChanwise; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const CanonizedFilterMeta& filter, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible); | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
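| // Editorial sketch (assumption, not part of this diff): with the new API a | |||
| // caller holds an AlgorithmInfo and can recover the concrete algorithm from | |||
| // its serialized desc when needed; for a Convolution3DForwardImpl* `impl` and | |||
| // hypothetical `src`, `filter_meta`, `dst`, `workspace_limit` in scope: | |||
| //   AlgorithmInfo info = impl->get_algorithm_info_heuristic( | |||
| //           src, filter_meta, dst, workspace_limit, /*reproducible=*/true); | |||
| //   AlgoBase* algo = Convolution3DForwardImpl::get_algo_from_desc(info.desc); | |||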
| class Convolution3DBackwardDataImpl : public Convolution3DBackwardData { | |||
| public: | |||
| using Convolution3DBackwardData::Convolution3DBackwardData; | |||
| void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const CanonizedFilterMeta& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| bool reproducible) { | |||
| return get_algorithm_heuristic(filter, diff, grad, | |||
| workspace_limit_in_bytes, reproducible) | |||
| ->info(); | |||
| } | |||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| class AlgoCUDNN; | |||
| class AlgoInplaceMatmul; | |||
| class AlgoChanwise; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible); | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
| class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter { | |||
| public: | |||
| using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const CanonizedFilterMeta& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) { | |||
| return get_algorithm_heuristic(src, diff, grad, | |||
| workspace_limit_in_bytes, reproducible) | |||
| ->info(); | |||
| } | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| class AlgoCUDNN; | |||
| class AlgoInplaceMatmul; | |||
| class AlgoChanwise; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const CanonizedFilterMeta& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible); | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -433,6 +433,137 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { | |||
| desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT)); | |||
| } | |||
| ////////////////////////// CudnnAlgoPack ////////////////////////// | |||
| #define V1(v) #v | |||
| #define V(v) V1(v) | |||
| #define DEF_NAME(NAME) \ | |||
| #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | |||
| #define DEF_ALGO(NAME, PROD) \ | |||
| { \ | |||
| NAME, { DEF_NAME(NAME), PROD } \ | |||
| } | |||
| #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
| #pragma message "not latest cudnn" | |||
| #endif | |||
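| // Editorial note: for reference, DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true) | |||
| // expands to a map entry of the form | |||
| //   { CUDNN_CONVOLUTION_FWD_ALGO_GEMM, | |||
| //     { "CUDNN_CONVOLUTION_FWD_ALGO_GEMMv7.6.5", true } } | |||
| // where the version suffix is stitched from the CUDNN_MAJOR/MINOR/PATCHLEVEL | |||
| // macros of the cuDNN headers the build sees (7.6.5 is only illustrative). | |||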
| const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr> | |||
| CudnnAlgoPack::conv_bwd_data_algos() { | |||
| static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | |||
| CudnnAlgoPack::Attr> | |||
| algos = { | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), | |||
| #if CUDNN_MAJOR >= 5 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true), | |||
| #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, | |||
| true), | |||
| #endif | |||
| #endif | |||
| }; | |||
| return algos; | |||
| } | |||
| const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr> | |||
| CudnnAlgoPack::conv_bwd_flt_algos() { | |||
| static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | |||
| CudnnAlgoPack::Attr> | |||
| algos = { | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), | |||
| #if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, | |||
| true), | |||
| #if CUDNN_MAJOR >= 6 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true), | |||
| #endif | |||
| #endif | |||
| }; | |||
| return algos; | |||
| } | |||
| const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> | |||
| CudnnAlgoPack::conv_fwd_algos() { | |||
| static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | |||
| CudnnAlgoPack::Attr> | |||
| algos = { | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, | |||
| true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), | |||
| #if CUDNN_MAJOR >= 5 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true), | |||
| #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true), | |||
| #endif | |||
| #endif | |||
| }; | |||
| return algos; | |||
| } | |||
| const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr> | |||
| CudnnAlgoPack::conv3d_bwd_data_algos() { | |||
| static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | |||
| CudnnAlgoPack::Attr> | |||
| algos = { | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), | |||
| }; | |||
| return algos; | |||
| } | |||
| const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr> | |||
| CudnnAlgoPack::conv3d_bwd_flt_algos() { | |||
| #pragma message \ | |||
| "fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc" | |||
| static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | |||
| CudnnAlgoPack::Attr> | |||
| algos = { | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), | |||
| }; | |||
| return algos; | |||
| } | |||
| const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> | |||
| CudnnAlgoPack::conv3d_fwd_algos() { | |||
| static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | |||
| CudnnAlgoPack::Attr> | |||
| algos = { | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, | |||
| true), | |||
| DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), | |||
| }; | |||
| return algos; | |||
| } | |||
| #undef DEF_ALGO | |||
| #undef DEF_NAME | |||
| #undef V | |||
| #undef V1 | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -10,6 +10,7 @@ | |||
| */ | |||
| #pragma once | |||
| #include <unordered_map> | |||
| #include "megdnn/basic_types.h" | |||
| #include "megdnn/oprs/nn.h" | |||
| #include "src/cuda/cudnn_with_check.h" | |||
| @@ -27,7 +28,7 @@ class TensorDesc { | |||
| public: | |||
| TensorDesc(); | |||
| //! default layout is nchw | |||
| void set(const TensorLayout& layout, const param::Convolution::Format = | |||
| param::Convolution::Format::NCHW); | |||
| ~TensorDesc(); | |||
| cudnnTensorDescriptor_t desc; | |||
| @@ -103,9 +104,52 @@ class Conv3DDesc { | |||
| cudnnConvolutionDescriptor_t desc; | |||
| }; | |||
| class CudnnAlgoPack { | |||
| public: | |||
| //! algorithm attr | |||
| struct Attr { | |||
| std::string name; | |||
| bool is_reproducible; | |||
| }; | |||
| static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr> | |||
| conv_bwd_data_algos(); | |||
| static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr> | |||
| conv_bwd_flt_algos(); | |||
| static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> | |||
| conv_fwd_algos(); | |||
| static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr> | |||
| conv3d_bwd_data_algos(); | |||
| static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr> | |||
| conv3d_bwd_flt_algos(); | |||
| static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> | |||
| conv3d_fwd_algos(); | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| namespace std { | |||
| #define DEF_HASH(_type) \ | |||
| template <> \ | |||
| struct hash<_type> { \ | |||
| std::size_t operator()(const _type& algo) const { \ | |||
| return std::hash<uint32_t>()(static_cast<uint32_t>(algo)); \ | |||
| } \ | |||
| } | |||
| DEF_HASH(cudnnConvolutionBwdDataAlgo_t); | |||
| DEF_HASH(cudnnConvolutionBwdFilterAlgo_t); | |||
| DEF_HASH(cudnnConvolutionFwdAlgo_t); | |||
| #undef DEF_HASH | |||
| } // namespace std | |||
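| // Editorial sketch (not part of this diff): the std::hash specializations above | |||
| // let the cuDNN algo enums key the unordered_maps returned by CudnnAlgoPack: | |||
| //   const auto& algos = megdnn::cuda::CudnnAlgoPack::conv3d_fwd_algos(); | |||
| //   auto it = algos.find(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM); | |||
| //   if (it != algos.end()) { /* it->second.name, it->second.is_reproducible */ } | |||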
| // vim: syntax=cpp.doxygen | |||
| @@ -19,7 +19,12 @@ using OprImpl = DeformableConvBackwardDataImpl; | |||
| OprImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&algo_matmul); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardDataImpl) | |||
| OprImpl::AlgoPack OprImpl::sm_algo_pack; | |||
| @@ -13,11 +13,15 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/deformable_conv/opr_impl.h" | |||
| #include <unordered_map> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -26,6 +30,10 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_MATMUL, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| DeformableConvBackwardDataImpl* opr; | |||
| @@ -107,17 +115,18 @@ public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "AlgoMatmul"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
| }; | |||
| class DeformableConvBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| AlgoMatmul algo_matmul; | |||
| //! all algorithms | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| @@ -20,7 +20,11 @@ using OprImpl = DeformableConvBackwardFilterImpl; | |||
| OprImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&algo_matmul); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardFilterImpl) | |||
| OprImpl::AlgoPack OprImpl::sm_algo_pack; | |||
| @@ -13,11 +13,15 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/deformable_conv/opr_impl.h" | |||
| #include <unordered_map> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -26,6 +30,11 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_MATMUL, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| DeformableConvBackwardFilterImpl* opr; | |||
| @@ -97,18 +106,18 @@ public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "AlgoMatmul"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
| }; | |||
| class DeformableConvBackwardFilterImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| class DeformableConvBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| AlgoMatmul algo_matmul; | |||
| //! all algorithms | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| @@ -22,8 +22,14 @@ using OprImpl = DeformableConvForwardImpl; | |||
| OprImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&algo_matmul); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvForwardImpl) | |||
| OprImpl::AlgoPack OprImpl::sm_algo_pack; | |||
| OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, | |||
| @@ -13,9 +13,13 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/cuda/deformable_conv/opr_impl.h" | |||
| #include "src/cuda/utils.h" | |||
| #include <unordered_map> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -24,6 +28,11 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_MATMUL, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| DeformableConvForwardImpl* opr; | |||
| @@ -92,17 +101,17 @@ public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "AlgoMatmul"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
| }; | |||
| class DeformableConvForwardImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| class DeformableConvForwardImpl::AlgoPack : NonCopyableObj { | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| AlgoMatmul algo_matmul; | |||
| //! all algorithms | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| @@ -29,19 +30,6 @@ public: | |||
| const TensorLayout& mask, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& offset, | |||
| const TensorLayout& mask, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
| const CanonizedFilterMeta& filter, | |||
| const TensorLayout& offset, | |||
| @@ -58,31 +46,35 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& offset, | |||
| const TensorLayout& mask, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
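| // Pattern repeated throughout these opr_impl headers: the raw | |||
| // get_all_algorithms()/get_algorithm_heuristic() overloads move from public | |||
| // to protected, while algo_pack() and the new get_algo_from_desc() stay | |||
| // public, so callers identify an algorithm by its AlgorithmDesc instead of | |||
| // holding a raw Algorithm pointer. | |||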
| class DeformableConvBackwardFilterImpl: public DeformableConvBackwardFilter { | |||
| class DeformableConvBackwardFilterImpl : public DeformableConvBackwardFilter { | |||
| public: | |||
| using DeformableConvBackwardFilter::DeformableConvBackwardFilter; | |||
| void exec(_megdnn_tensor_in im,_megdnn_tensor_in offset, _megdnn_tensor_in mask, | |||
| _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, | |||
| void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, | |||
| _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, | |||
| _megdnn_tensor_out filter_grad, | |||
| _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& out_grad, const TensorLayout& filter_grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
| const TensorLayout& offset, | |||
| const TensorLayout& mask, | |||
| const TensorLayout& out_grad, | |||
| const TensorLayout& filter_grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
| const TensorLayout& offset, | |||
| const TensorLayout& mask, | |||
| @@ -91,9 +83,11 @@ public: | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible); | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& out_grad, const TensorLayout& filter_grad) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& im, | |||
| const TensorLayout& offset, | |||
| const TensorLayout& mask, | |||
| const TensorLayout& out_grad, | |||
| const TensorLayout& filter_grad) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| @@ -103,6 +97,21 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& im, const TensorLayout& offset, | |||
| const TensorLayout& mask, const TensorLayout& out_grad, | |||
| const TensorLayout& filter_grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
| const TensorLayout& offset, | |||
| const TensorLayout& mask, | |||
| const TensorLayout& out_grad, | |||
| const TensorLayout& filter_grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| @@ -118,19 +127,6 @@ public: | |||
| _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, | |||
| _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | |||
| size_t workspace_limit_in_bytes, bool reproducible) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& im, const CanonizedFilterMeta& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| @@ -138,11 +134,14 @@ public: | |||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | |||
| size_t workspace_limit_in_bytes, bool reproducible); | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& im, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& offset, | |||
| const TensorLayout& mask, | |||
| const TensorLayout& out_grad, | |||
| const TensorLayout& im_grad, | |||
| const TensorLayout& offset_grad, | |||
| const TensorLayout& mask_grad) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| @@ -152,6 +151,22 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
| const TensorLayout& offset_grad, | |||
| const TensorLayout& mask_grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | |||
| size_t workspace_limit_in_bytes, bool reproducible) override; | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| @@ -18,8 +18,14 @@ using namespace cuda; | |||
| LocalShareBackwardDataImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&implicit_gemm); | |||
| all_algos.push_back(&batched_matmul); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardDataImpl) | |||
| LocalShareBackwardDataImpl::AlgoPack LocalShareBackwardDataImpl::sm_algo_pack; | |||
| LocalShareBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| @@ -13,10 +13,14 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/local_share/opr_impl.h" | |||
| #include <unordered_map> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -25,6 +29,13 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_IMPLICIT_GEMM, | |||
| CUDA_BATCHED_MATMUL, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| LocalShareBackwardDataImpl* opr; | |||
| @@ -77,6 +88,7 @@ public: | |||
| const char* name() const override { | |||
| return "LOCAL_SHARE_IMPLICIT_GEMM"; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) | |||
| }; | |||
| class LocalShareBackwardDataImpl::AlgoBatchedMatMul final | |||
| @@ -93,11 +105,11 @@ public: | |||
| const char* name() const override { | |||
| return "LOCAL_SHARE_BATCHED_MATMUL"; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||
| }; | |||
| class LocalShareBackwardDataImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| class LocalShareBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| @@ -106,6 +118,7 @@ public: | |||
| AlgoBatchedMatMul batched_matmul; | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| @@ -18,8 +18,14 @@ using namespace cuda; | |||
| LocalShareBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&implicit_gemm); | |||
| all_algos.push_back(&batched_matmul); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardFilterImpl) | |||
| LocalShareBackwardFilterImpl::AlgoPack LocalShareBackwardFilterImpl::sm_algo_pack; | |||
| LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| @@ -13,10 +13,14 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/local_share/opr_impl.h" | |||
| #include <unordered_map> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -25,6 +29,12 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_IMPLICIT_GEMM, | |||
| CUDA_BATCHED_MATMUL, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| LocalShareBackwardFilterImpl* opr; | |||
| @@ -75,6 +85,7 @@ public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) | |||
| }; | |||
| class LocalShareBackwardFilterImpl::AlgoBatchedMatMul final : public AlgoBase { | |||
| @@ -88,11 +99,11 @@ public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||
| }; | |||
| class LocalShareBackwardFilterImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| class LocalShareBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| @@ -101,6 +112,8 @@ public: | |||
| AlgoBatchedMatMul batched_matmul; | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| @@ -19,8 +19,14 @@ LocalShareForwardImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&batch_size_aware_chwn_small_image); | |||
| all_algos.push_back(&batch_size_aware_chwn); | |||
| all_algos.push_back(&batched_matmul); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareForwardImpl) | |||
| LocalShareForwardImpl::AlgoPack LocalShareForwardImpl::sm_algo_pack; | |||
| LocalShareForwardImpl::AlgoBase::SizeArgs::SizeArgs(LocalShareForwardImpl* o, | |||
| @@ -14,9 +14,13 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/local_share/opr_impl.h" | |||
| #include <unordered_map> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -25,6 +29,13 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_CHWN_BATCH_SIZE_AWARE, | |||
| CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE, | |||
| CUDA_BATCHED_MATMUL | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| LocalShareForwardImpl* opr; | |||
| @@ -79,6 +90,7 @@ public: | |||
| const char* name() const override { | |||
| return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE) | |||
| }; | |||
| class LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage final | |||
| @@ -95,6 +107,7 @@ public: | |||
| const char* name() const override { | |||
| return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE"; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE) | |||
| }; | |||
| class LocalShareForwardImpl::AlgoBatchedMatMul final : public AlgoBase { | |||
| @@ -108,11 +121,11 @@ public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||
| }; | |||
| class LocalShareForwardImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| class LocalShareForwardImpl::AlgoPack : NonCopyableObj { | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| @@ -122,6 +135,7 @@ public: | |||
| AlgoBatchedMatMul batched_matmul; | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| @@ -23,14 +23,6 @@ public: | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| @@ -41,7 +33,17 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
| @@ -54,14 +56,6 @@ public: | |||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| @@ -71,6 +65,17 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| @@ -84,14 +89,6 @@ public: | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| class AlgoBase; | |||
| @@ -101,6 +98,17 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| @@ -11,6 +11,7 @@ | |||
| #include "./algos.h" | |||
| #include "src/cuda/utils.h" | |||
| #include "src/common/algo_base.h" | |||
| #include <cuda.h> | |||
| #if CUDA_VERSION >= 10010 | |||
| @@ -33,10 +34,16 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
| cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas); | |||
| all_algos.push_back(cublas_bfloat16.get()); | |||
| #endif | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl) | |||
| MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, | |||
| const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| @@ -67,4 +74,5 @@ std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
| m, k, k, n, m, n, param.transposeA, param.transposeB, | |||
| layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,14 +6,18 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/matrix_mul/opr_impl.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include <unordered_map> | |||
| #include <cuda.h> | |||
| #include <memory> | |||
| #if CUDA_VERSION >= 10010 | |||
| @@ -32,6 +36,15 @@ protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| CUDA_CUBLAS, | |||
| CUDA_WMMA_UINT4X4X32, | |||
| CUDA_CUBLASLT, | |||
| CUDA_NAIVE, | |||
| CUDA_BFLOAT16 | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
| struct SizeArgs { | |||
| MatrixMulForwardImpl* opr; | |||
| @@ -62,12 +75,12 @@ public: | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
| virtual void exec(const ExecArgs& args) const = 0; | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) { | |||
| size_t limit = std::numeric_limits<size_t>::max()) const { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| @@ -80,8 +93,6 @@ public: | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| }; | |||
| class MatrixMulForwardImpl::AlgoCuBlas final : public AlgoBase { | |||
| @@ -91,13 +102,10 @@ public: | |||
| size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { | |||
| return 0_z; | |||
| } | |||
| const char* name() const override { | |||
| return "CUBLAS"; | |||
| } | |||
| const char* name() const override { return "CUBLAS"; } | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { | |||
| return true; | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | |||
| }; | |||
| #if CUDA_VERSION >= 10000 | |||
| @@ -106,13 +114,10 @@ public: | |||
| AlgoUInt4x4x32WMMA() = default; | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| const char* name() const override { | |||
| return "UINT4x4x32_WMMA"; | |||
| } | |||
| const char* name() const override { return "UINT4x4x32_WMMA"; } | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { | |||
| return true; | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) | |||
| }; | |||
| #endif | |||
| #if CUDA_VERSION >= 10010 | |||
| @@ -120,13 +125,10 @@ class MatrixMulForwardImpl::AlgoCuBlasLt final : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| const char* name() const override { | |||
| return "CUBLAS_LT"; | |||
| } | |||
| const char* name() const override { return "CUBLAS_LT"; } | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { | |||
| return true; | |||
| } | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | |||
| }; | |||
| #endif | |||
| @@ -140,6 +142,7 @@ public: | |||
| const char* name() const override { return "NAIVE"; } | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) | |||
| }; | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| @@ -151,6 +154,13 @@ public: | |||
| const char* name() const override { return m_name.c_str(); } | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_algorithm, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; | |||
| @@ -160,9 +170,9 @@ private: | |||
| }; | |||
| #endif | |||
| class MatrixMulForwardImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||
| private: | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| @@ -178,6 +188,8 @@ public: | |||
| std::unique_ptr<AlgoBFloat16> cublas_bfloat16; | |||
| #endif | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace cuda | |||
| @@ -82,7 +82,7 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { | |||
| args.opr->handle()->create_operator<MatrixMulForward>(); | |||
| matmul_opr->param() = args.opr->param(); | |||
| matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
| matmul_opr->execution_policy() = {m_algorithm}; | |||
| matmul_opr->execution_policy() = {m_algorithm->info()}; | |||
| matmul_opr->exec(a, b, c, ctypecvt.workspace()); | |||
| } | |||
| ctypecvt.comp_to_dst_type(c, args.tensor_c); | |||
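| // The execution policy above now carries an Algorithm::Info (handle type, | |||
| // algo type and serialized param) rather than a raw Algorithm*. A sketch of | |||
| // how the nested operator can resolve it back to a concrete algorithm; the | |||
| // names follow the surrounding code, the exact call site may differ: | |||
| // | |||
| //     auto&& desc = matmul_opr->execution_policy().algo.desc; | |||
| //     if (auto algo = MatrixMulForwardImpl::get_algo_from_desc(desc)) { | |||
| //         /* dispatch through the resolved AlgoBase* */ | |||
| //     } | |||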
| @@ -25,15 +25,6 @@ public: | |||
| bool is_thread_safe() const override { return true; } | |||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| const char* get_algorithm_set_name() const override { | |||
| return "CUDA MATMUL"; | |||
| } | |||
| @@ -55,6 +46,17 @@ public: | |||
| static const AlgoPack& algo_pack() { | |||
| return sm_algo_pack; | |||
| } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| private: | |||
| static AlgoPack sm_algo_pack; | |||
| @@ -10,10 +10,14 @@ | |||
| */ | |||
| #include "src/fallback/conv_bias/algos.h" | |||
| #include "src/fallback/conv_bias/conv1x1/algos.h" | |||
| #include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" | |||
| #include "src/fallback/conv_bias/im2col/algos.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "src/fallback/conv_bias/winograd/strategy.h" | |||
| #include "src/naive/convolution/helper.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "midout.h" | |||
| @@ -176,6 +180,7 @@ void kern_default(const ConvBiasImpl::NCBKernParam& p) { | |||
| } // namespace | |||
| MIDOUT_DECL(megdnn_fallback_naive) | |||
| /* ======================= AlgoNaive ======================== */ | |||
| bool ConvBiasImpl::AlgoNaive::usable( | |||
| @@ -36,6 +36,7 @@ public: | |||
| static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | |||
| return {support_data_type, AlgoCategory::NAIVE}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) | |||
| }; | |||
| class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { | |||
| @@ -59,6 +60,12 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_F32) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_matmul_algo, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
| @@ -87,6 +94,12 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_4X4_F32) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_matmul_algo, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
| @@ -115,6 +128,12 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_QS8) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_matmul_algo, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
| @@ -143,6 +162,12 @@ public: | |||
| ConvAlgoTypePack get_algo_type() const override { | |||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_matmul_algo, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
| @@ -155,6 +155,12 @@ using BiasMode = ConvBiasForward::BiasMode; | |||
| const NCBKernSizeParam& param) const override; \ | |||
| ConvAlgoTypePack get_algo_type() const override { \ | |||
| return {_algo_data_type, AlgoCategory::WINOGRAD}; \ | |||
| } \ | |||
| std::string param() const override { \ | |||
| std::string ret; \ | |||
| serialize_write_pod(m_matmul_algo, ret); \ | |||
| serialize_write_pod(m_tile_size, ret); \ | |||
| return ret; \ | |||
| } \ | |||
| \ | |||
| private: \ | |||
| @@ -60,6 +60,13 @@ public: | |||
| return {m_matmul_algo->matmul_description().algo_type.data_type, | |||
| AlgoCategory::IM2COL}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(FB_CONV1x1) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_matmul_algo, ret); | |||
| serialize_write_pod(m_oc_block_size, ret); | |||
| return ret; | |||
| } | |||
| protected: | |||
| size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | |||
| @@ -43,6 +43,7 @@ public: | |||
| static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | |||
| return {support_data_type, AlgoCategory::IM2COL}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(FB_CONV1x1_GEMV) | |||
| protected: | |||
| size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | |||
| @@ -68,6 +68,14 @@ public: | |||
| return {m_matmul_algo->matmul_description().algo_type.data_type, | |||
| AlgoCategory::IM2COL}; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(FB_IM2COL) | |||
| std::string param() const override { | |||
| std::string ret; | |||
| serialize_write_pod(m_matmul_algo, ret); | |||
| serialize_write_pod(m_ohw_tile_size, ret); | |||
| return ret; | |||
| } | |||
| private: | |||
| MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
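| // Sketch of the read side matching the param() blobs above, assuming a | |||
| // deserialize_read_pod helper mirroring serialize_write_pod in the | |||
| // Algorithm base class; the tile-size type is assumed to be size_t here: | |||
| // | |||
| //     std::string blob = algo->param(); | |||
| //     auto* matmul = Algorithm::deserialize_read_pod< | |||
| //             MatrixMulImpl::AlgoBase*>(blob); | |||
| //     auto ohw_tile_size = Algorithm::deserialize_read_pod<size_t>( | |||
| //             blob, sizeof(MatrixMulImpl::AlgoBase*)); | |||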
| @@ -22,6 +22,14 @@ | |||
| #include "src/naive/convolution/algorithms.h" | |||
| #include "src/naive/handle.h" | |||
| #if MEGDNN_X86 | |||
| #include "src/x86/conv_bias/opr_impl.h" | |||
| #elif MEGDNN_AARCH64 | |||
| #include "src/aarch64/conv_bias/opr_impl.h" | |||
| #elif MEGDNN_ARMV7 | |||
| #include "src/armv7/conv_bias/opr_impl.h" | |||
| #endif | |||
| #include <cstring> | |||
| using namespace megdnn; | |||
| @@ -65,17 +73,19 @@ void incr_ptr(T*& dst, ptrdiff_t delta) { | |||
| class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
| AlgoNaive algo_naive; | |||
| SmallVector<std::unique_ptr<AlgoBase>> refhold; | |||
| SmallVector<AlgoBase*> m_all_algos; | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack() { | |||
| refhold.emplace_back(new AlgoConv1x1Gemv()); | |||
| all_algos.emplace_back(refhold.back().get()); | |||
| m_all_algos.emplace_back(refhold.back().get()); | |||
| static CpuOprDelegationStorage<> storage; | |||
| auto matmul_opr = storage.get<MatrixMul>(); | |||
| auto&& matmul_algos = | |||
| static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); | |||
| auto&& matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||
| ->get_all_packed_algo(); | |||
| for (auto&& algo : matmul_algos) { | |||
| #if MEGDNN_X86 | |||
| //! As we don't have direct conv for int8x8x16 yet, if we disable gemv here, it may | |||
| @@ -97,13 +107,13 @@ public: | |||
| refhold.emplace_back(new AlgoIm2col( | |||
| static_cast<MatrixMulImpl::AlgoBase*>(algo), | |||
| ohw_tile_size)); | |||
| all_algos.emplace_back(refhold.back().get()); | |||
| m_all_algos.emplace_back(refhold.back().get()); | |||
| } | |||
| for (size_t oc_tile_size : {48, 24}) { | |||
| refhold.emplace_back(new AlgoConv1x1( | |||
| static_cast<MatrixMulImpl::AlgoBase*>(algo), | |||
| oc_tile_size)); | |||
| all_algos.emplace_back(refhold.back().get()); | |||
| m_all_algos.emplace_back(refhold.back().get()); | |||
| } | |||
| #endif | |||
| @@ -113,26 +123,35 @@ public: | |||
| //! FIXME: I do not know a better way to do it. | |||
| refhold.emplace_back(new AlgoWinogradF32( | |||
| static_cast<MatrixMulImpl::AlgoBase*>(algo))); | |||
| all_algos.emplace_back(refhold.back().get()); | |||
| m_all_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoWinogradF32_4x4( | |||
| static_cast<MatrixMulImpl::AlgoBase*>(algo))); | |||
| all_algos.emplace_back(refhold.back().get()); | |||
| m_all_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoWinogradQS8( | |||
| static_cast<MatrixMulImpl::AlgoBase*>(algo))); | |||
| all_algos.emplace_back(refhold.back().get()); | |||
| m_all_algos.emplace_back(refhold.back().get()); | |||
| refhold.emplace_back(new AlgoWinogradQS8_8x8( | |||
| static_cast<MatrixMulImpl::AlgoBase*>(algo))); | |||
| all_algos.emplace_back(refhold.back().get()); | |||
| m_all_algos.emplace_back(refhold.back().get()); | |||
| #endif | |||
| } | |||
| all_algos.emplace_back(&algo_naive); | |||
| m_all_algos.emplace_back(&algo_naive); | |||
| for (auto&& algo : m_all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| SmallVector<AlgoBase*> all_algos; | |||
| const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; } | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
| static AlgoPack sl_algo_pack; | |||
| return sl_algo_pack.all_algos; | |||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
| static AlgoPack algo_pack; | |||
| return algo_pack; | |||
| } | |||
| SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() { | |||
| return algo_pack().all_algos(); | |||
| } | |||
| SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type( | |||
| @@ -140,7 +159,7 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type( | |||
| megdnn_assert(nr_type_contain(target_type.data_type), | |||
| "ConvBias algo selection only support one type"); | |||
| SmallVector<ConvBiasImpl::AlgoBase*> algos; | |||
| for (auto&& algo : algo_pack()) { | |||
| for (auto&& algo : get_all_packed_algo()) { | |||
| auto algo_type = algo->get_algo_type(); | |||
| if (contain_data_type(algo_type.data_type, target_type.data_type) && | |||
| algo_type.algo_category == target_type.algo_category) { | |||
| @@ -166,7 +185,7 @@ void ConvBiasImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
| workspace.size, preprocessed_filter); | |||
| auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, | |||
| preprocessed_filter); | |||
| ConvBiasImpl::Algorithm* algo = get_algorithm(fparam, workspace.size); | |||
| auto&& algo = get_algorithm(fparam, workspace.size); | |||
| if (!is_naive_algo(algo) && | |||
| NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) { | |||
| exec_with_ncb_kern(fparam, algo); | |||
| @@ -189,9 +208,10 @@ void ConvBiasImpl::exec_preprocess(const TensorLayout& src_layout, | |||
| auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, | |||
| preprocessed_filter); | |||
| //! should not pass workspace_size limit otherwise can not find match algo | |||
| ConvBiasImpl::Algorithm* algo = get_algorithm(fparam); | |||
| if (!is_naive_algo(algo) && NCB_ALGO_FUNC(get_preprocess_workspace, algo, | |||
| fparam) <= workspace.size) { | |||
| auto&& algo = get_algorithm(fparam); | |||
| if (!is_naive_algo(algo) && | |||
| NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= | |||
| workspace.size) { | |||
| exec_preprocess_with_ncb_kern(fparam, algo); | |||
| } else { | |||
| naive::ConvBiasForwardImpl::exec_preprocess( | |||
| @@ -207,7 +227,7 @@ size_t ConvBiasImpl::get_workspace_in_bytes( | |||
| const PreprocessedFilter* preprocessed_filter) { | |||
| auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, | |||
| preprocessed_filter); | |||
| ConvBiasImpl::Algorithm* algo = get_algorithm(fparam); | |||
| auto&& algo = get_algorithm(fparam); | |||
| if (is_naive_algo(algo)) { | |||
| return naive::ConvBiasForwardImpl::get_workspace_in_bytes( | |||
| src, filter, bias, z, dst, preprocessed_filter); | |||
| @@ -221,7 +241,7 @@ size_t ConvBiasImpl::get_preprocess_workspace_in_bytes( | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) { | |||
| auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | |||
| Algorithm* algo = get_algorithm(fparam); | |||
| auto&& algo = get_algorithm(fparam); | |||
| if (is_naive_algo(algo)) { | |||
| return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes( | |||
| src, filter, bias, z, dst); | |||
| @@ -235,7 +255,7 @@ SmallVector<TensorLayout> ConvBiasImpl::deduce_preprocessed_filter_layout( | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) { | |||
| auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | |||
| Algorithm* algo = get_algorithm(fparam); | |||
| auto&& algo = get_algorithm(fparam); | |||
| if (is_naive_algo(algo)) { | |||
| return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout( | |||
| src, filter, bias, z, dst); | |||
| @@ -443,7 +463,7 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||
| MEGDNN_MARK_USED_VAR(param); | |||
| std::vector<Algorithm*> algos; | |||
| std::vector<Algorithm*> prefer_algos; | |||
| for (auto&& algo : algo_pack()) { | |||
| for (auto&& algo : get_all_packed_algo()) { | |||
| if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) { | |||
| if (algo->is_preferred(param)) { | |||
| prefer_algos.push_back(algo); | |||
| @@ -457,10 +477,49 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||
| return algos; | |||
| } | |||
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( | |||
| const AlgorithmDesc& desc) const { | |||
| if (!desc.valid()) { | |||
| return nullptr; | |||
| } else { | |||
| switch (desc.handle_type) { | |||
| case Handle::HandleType::FALLBACK: { | |||
| const auto& map = algo_pack().all_algos_map(); | |||
| megdnn_assert(map.find(desc) != map.end()); | |||
| return map.at(desc); | |||
| }; | |||
| #if MEGDNN_X86 | |||
| case Handle::HandleType::X86: | |||
| return x86::ConvBiasImpl::get_algo_from_desc(desc); | |||
| #elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| case Handle::HandleType::ARM_COMMON: | |||
| return arm_common::ConvBiasImpl::get_algo_from_desc(desc); | |||
| #if MEGDNN_AARCH64 | |||
| case Handle::HandleType::AARCH64: | |||
| return aarch64::ConvBiasImpl::get_algo_from_desc(desc); | |||
| #else | |||
| case Handle::HandleType::ARMV7: | |||
| return armv7::ConvBiasImpl::get_algo_from_desc(desc); | |||
| #endif | |||
| #endif | |||
| case Handle::HandleType::NAIVE: { | |||
| auto algo = static_cast<naive::HandleImpl*>(handle()) | |||
| ->default_conv_bias_fwd_algo(); | |||
| megdnn_assert(algo->info().desc == desc); | |||
| return algo; | |||
| } | |||
| default: | |||
| megdnn_throw("Unknown handle type"); | |||
| return nullptr; | |||
| } | |||
| } | |||
| } | |||
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | |||
| const NCBKernSizeParam& param, size_t workspace_size) { | |||
| if (auto set = execution_policy().algorithm) { | |||
| return set; | |||
| if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
| return algo; | |||
| } | |||
| if (!m_prev_selected_algo || | |||
| memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | |||
| @@ -216,6 +216,86 @@ public: | |||
| AlgoBase() : Algorithm() { | |||
| m_handle_type = Handle::HandleType::FALLBACK; | |||
| } | |||
| enum class AlgoType : uint32_t { | |||
| //! fallback | |||
| FB_NAIVE = 1 << 0, | |||
| FB_WINOGRAD_F32, | |||
| FB_WINOGRAD_4X4_F32, | |||
| FB_WINOGRAD_QS8, | |||
| FB_WINOGRAD_8X8_QS8, | |||
| FB_CONV1x1, | |||
| FB_CONV1x1_GEMV, | |||
| FB_IM2COL, | |||
| #if MEGDNN_X86 | |||
| X86_DIRECT = 1 << 8, | |||
| X86_DIRECT_STRD2, | |||
| X86_WINOGRAD_F63_8x8_F32, | |||
| X86_WINOGRAD_F23_8x8_F32, | |||
| X86_MKLDNN, | |||
| X86_CHANWISE_AVX2_STRD1_QINT8, | |||
| X86_CHANWISE_AVX2_STRD2_QINT8, | |||
| X86_DIRECT_AVX2_STRD1_INT8, | |||
| X86_DIRECT_AVX2_STRD2_INT8, | |||
| X86_MKLDNN_QINT8, | |||
| X86_MKLDNN_MATMUL_QINT8, | |||
| #elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| ARM_COMMON_WINOGRAD_F23_FP16 = 1 << 8, | |||
| ARM_COMMON_WINOGRAD_F45_FP16, | |||
| ARM_COMMON_WINOGRAD_F63_FP16, | |||
| ARM_COMMON_WINOGRAD_F23_8X8_FP16, | |||
| ARM_COMMON_DIRECT_FP16, | |||
| ARM_COMMON_DIRECT_STRD1_FP16, | |||
| ARM_COMMON_WINOGRAD_F23_4X4_FP32, | |||
| ARM_COMMON_WINOGRAD_F63_FP32, | |||
| ARM_COMMON_WINOGRAD_F63_4X4_FP32, | |||
| ARM_COMMON_WINOGRAD_F54_FP32, | |||
| ARM_COMMON_WINOGRAD_F45_FP32, | |||
| ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, | |||
| ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, | |||
| ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, | |||
| ARM_COMMON_DIRECT_FP32, | |||
| ARM_COMMON_DIRECT_STRD1_FP32, | |||
| ARM_COMMON_DIRECT_STRD2_FP32, | |||
| ARM_COMMON_DIRECT_NCHW44_FP32, | |||
| ARM_COMMON_DIRECT_NCHW_NCHW44_FP32, | |||
| ARM_COMMON_CHWNWISE_NCHW44_F32, | |||
| ARM_COMMON_DIRECT_STRD1_S8, | |||
| ARM_COMMON_DIRECT_STRD2_S8, | |||
| ARM_COMMON_DIRECT_NCHW44, | |||
| ARM_COMMON_DIRECT_NCHW_NCHW44_S8, | |||
| ARM_COMMON_CHANWISE_STRD1_NCHW44_S8, | |||
| ARM_COMMON_CHANWISE_STRD2_NCHW44_S8, | |||
| ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8, | |||
| ARM_COMMON_DIRECT_STRD1_DOT_S8, | |||
| ARM_COMMON_DIRECT_STRD2_DOT_S8, | |||
| ARM_COMMON_DIRECT_NCHW44_DOT_S8, | |||
| ARM_COMMON_WINOGRAD_F23_8X8_S8, | |||
| ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32, | |||
| ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8, | |||
| ARM_COMMON_DIRECT_INT8X8X16, | |||
| ARM_COMMON_DIRECT_NCHW44_INT8X8X16, | |||
| ARM_COMMON_DIRECT_STRD2_INT8X8X16, | |||
| ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16, | |||
| ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16, | |||
| ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16, | |||
| ARM_COMMON_DIRECT_STRD1_QU8, | |||
| ARM_COMMON_DIRECT_STRD2_QU8, | |||
| ARM_COMMON_DIRECT_STRD1_DOT_QU8, | |||
| ARM_COMMON_DIRECT_STRD2_DOT_QU8, | |||
| #if MEGDNN_AARCH64 | |||
| AARCH64_DIRECT_STRD2_FP16, | |||
| AARCH64_DIRECT_STRD2_FP32, | |||
| AARCH64_MATMUL_S8, | |||
| AARCH64_MATMUL_QU8, | |||
| #else | |||
| ARMV7_MATMUL_S8, | |||
| ARMV7_MATMUL_QU8, | |||
| #endif // MEGDNN_AARCH64 | |||
| #endif | |||
| }; | |||
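| // Fallback entries occupy the low byte (starting at 1 << 0) while each | |||
| // backend opens its own range at 1 << 8, presumably so backend-specific | |||
| // type() values never collide with the shared fallback ones. | |||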
| virtual ~AlgoBase() = default; | |||
| virtual bool usable( | |||
| const NCBKernSizeParam& param, | |||
| @@ -255,12 +335,14 @@ public: | |||
| //! get the type of the algo | |||
| virtual ConvAlgoTypePack get_algo_type() const = 0; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| }; | |||
| using AlgoMapper = AlgoBase::Mapper; | |||
| /** | |||
| * \brief get all the algorithm for the opr. | |||
| */ | |||
| virtual SmallVector<AlgoBase*> algo_pack(); | |||
| virtual SmallVector<AlgoBase*> get_all_packed_algo(); | |||
| /** | |||
| * \brief select algo according to input algo type | |||
| @@ -305,6 +387,8 @@ private: | |||
| bool is_naive_algo(ConvBiasImpl::Algorithm* algo); | |||
| Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||
| //! get algorithm set by user or by heuristic | |||
| Algorithm* get_algorithm( | |||
| const NCBKernSizeParam& param, | |||
| @@ -320,6 +404,8 @@ private: | |||
| _megdnn_tensor_in bias, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace, | |||
| const PreprocessedFilter* preprocessed_filter); | |||
| static const AlgoPack& algo_pack(); | |||
| }; | |||
| inline bool is_enable_filter_preprocess( | |||