GitOrigin-RevId: dea03a0f7a
tags/v1.3.0
| @@ -0,0 +1,107 @@ | |||||
| /** | |||||
| * \file dnn/src/fallback/batched_matrix_mul/algos.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/fallback/batched_matrix_mul/algos.h" | |||||
| #include "src/common/algo_base.h" | |||||
| #include "src/naive/handle.h" | |||||
| using namespace megdnn; | |||||
| using namespace fallback; | |||||
| BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(&algo_default); | |||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack; | |||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) | |||||
| BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||||
| BatchedMatrixMulForwardImpl* o, const TensorLayout& A, | |||||
| const TensorLayout& B, const TensorLayout& C) | |||||
| : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} | |||||
| BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( | |||||
| BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) | |||||
| : SizeArgs(opr, A.layout, B.layout, C.layout), | |||||
| tensor_a{A}, | |||||
| tensor_b{B}, | |||||
| tensor_c{C}, | |||||
| workspace{workspace} {} | |||||
| std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
| auto&& param = opr->param(); | |||||
| size_t m = layout_a.shape[0], n = layout_b.shape[1], | |||||
| k = layout_a.shape[param.transposeA ? 0 : 1]; | |||||
| MEGDNN_MARK_USED_VAR(m); | |||||
| MEGDNN_MARK_USED_VAR(n); | |||||
| MEGDNN_MARK_USED_VAR(k); | |||||
| return megdnn_mangle(ssprintf( | |||||
| "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " | |||||
| "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", | |||||
| m, k, k, n, m, n, param.transposeA, param.transposeB, | |||||
| layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | |||||
| } | |||||
| /* ===================== default algo ===================== */ | |||||
| size_t BatchedMatrixMulForwardImpl::AlgoDefault::get_workspace_in_bytes( | |||||
| const SizeArgs& args) const { | |||||
| auto opr = inplace_cpu_handle()->create_operator<MatrixMul>(); | |||||
| auto A_ = args.layout_a.remove_axis(0), B_ = args.layout_b.remove_axis(0), | |||||
| C_ = args.layout_c.remove_axis(0); | |||||
| opr->param() = args.opr->param(); | |||||
| return opr->get_workspace_in_bytes(A_, B_, C_); | |||||
| } | |||||
| void BatchedMatrixMulForwardImpl::AlgoDefault::exec( | |||||
| const ExecArgs& args) const { | |||||
| //! As megbrain may modify param when checking all transpose situations, so | |||||
| //! here we should copy the param when dispatching kern | |||||
| auto param = args.opr->param(); | |||||
| auto kern = [args, param]() { | |||||
| auto N = args.layout_a.shape[0]; | |||||
| TensorND A_, B_, C_; | |||||
| A_.raw_ptr = args.tensor_a.raw_ptr; | |||||
| A_.layout = args.layout_a.remove_axis(0); | |||||
| B_.raw_ptr = args.tensor_b.raw_ptr; | |||||
| B_.layout = args.layout_b.remove_axis(0); | |||||
| C_.raw_ptr = args.tensor_c.raw_ptr; | |||||
| C_.layout = args.layout_c.remove_axis(0); | |||||
| auto Astrd = args.layout_a.dtype.size() * args.layout_a.stride[0], | |||||
| Bstrd = args.layout_b.dtype.size() * args.layout_b.stride[0], | |||||
| Cstrd = args.layout_c.dtype.size() * args.layout_c.stride[0]; | |||||
| auto advance_ptr = [](TensorND& dest, ptrdiff_t d) { | |||||
| dest.raw_ptr = | |||||
| static_cast<void*>(static_cast<dt_byte*>(dest.raw_ptr) + d); | |||||
| }; | |||||
| auto opr = inplace_cpu_handle()->create_operator<MatrixMul>(); | |||||
| opr->param() = param; | |||||
| rep(n, N) { | |||||
| opr->exec(A_, B_, C_, args.workspace); | |||||
| advance_ptr(A_, Astrd); | |||||
| advance_ptr(B_, Bstrd); | |||||
| advance_ptr(C_, Cstrd); | |||||
| } | |||||
| }; | |||||
| static_cast<naive::HandleImpl*>(args.opr->handle())->dispatch_kern(kern); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,109 @@ | |||||
| /** | |||||
| * \file dnn/src/fallback/batched_matrix_mul/algos.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/fallback/batched_matrix_mul/opr_impl.h" | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| namespace megdnn { | |||||
| namespace fallback { | |||||
| /*! | |||||
| * \brief base class for matrix mul algos | |||||
| * | |||||
| */ | |||||
| class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| enum class AlgoType : uint32_t { | |||||
| fallback_BLAS, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::FALLBACK; } | |||||
| struct SizeArgs { | |||||
| BatchedMatrixMulForwardImpl* opr; | |||||
| TensorLayout layout_a, layout_b, layout_c; | |||||
| std::string to_string() const; | |||||
| SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A, | |||||
| const TensorLayout& B, const TensorLayout& C); | |||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| TensorND tensor_a, tensor_b, tensor_c; | |||||
| Workspace workspace; | |||||
| ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs&) const = 0; | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) const { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert( | |||||
| req <= workspace.size, | |||||
| "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| }; | |||||
| class BatchedMatrixMulForwardImpl::AlgoDefault final : public AlgoBase { | |||||
| public: | |||||
| AlgoDefault() = default; | |||||
| bool is_available(const SizeArgs&) const override { return true; } | |||||
| size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override; | |||||
| const char* name() const override { return "DEFAULT"; } | |||||
| virtual void exec(const ExecArgs&) const override; | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(fallback_BLAS) | |||||
| }; | |||||
| class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| AlgoDefault algo_default; | |||||
| std::vector<AlgoBase*> all_algos; | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | |||||
| } // namespace fallback | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -6,67 +6,61 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "./opr_impl.h" | #include "./opr_impl.h" | ||||
| #include "src/naive/handle.h" | |||||
| #include "./algos.h" | |||||
| #include "hcc_detail/hcc_defs_prologue.h" | |||||
| #include "src/common/algo_chooser.h" | |||||
| #include "src/common/utils.cuh" | |||||
| #include "src/fallback/handle.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace fallback; | using namespace fallback; | ||||
| BatchedMatrixMulImpl::BatchedMatrixMulImpl(Handle *handle): | |||||
| BatchedMatrixMulForwardImpl(handle), | |||||
| m_storage(new CpuOprDelegationStorage<>), | |||||
| m_opr(m_storage->get<MatrixMul>()) | |||||
| { | |||||
| std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||||
| BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | |||||
| } | } | ||||
| size_t BatchedMatrixMulImpl::get_workspace_in_bytes( | |||||
| const TensorLayout &A, const TensorLayout &B, | |||||
| const TensorLayout &C) { | |||||
| auto A_ = A.remove_axis(0), B_ = B.remove_axis(0), C_ = C.remove_axis(0); | |||||
| m_opr->param() = param(); | |||||
| return m_opr->get_workspace_in_bytes(A_, B_, C_); | |||||
| BatchedMatrixMulForwardImpl::Algorithm* | |||||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| if (sm_algo_pack.algo_default.is_available_reproducible( | |||||
| args, reproducible, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.algo_default; | |||||
| } | |||||
| if (reproducible) { | |||||
| return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward"); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward"); | |||||
| } | |||||
| } | } | ||||
| void BatchedMatrixMulImpl::exec(_megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | |||||
| m_opr->param() = this->param(); | |||||
| auto kern = [this, A, B, C, workspace]() { | |||||
| auto N = A.layout.shape[0]; | |||||
| TensorND A_, B_, C_; | |||||
| A_.raw_ptr = A.raw_ptr; | |||||
| A_.layout = A.layout.remove_axis(0); | |||||
| B_.raw_ptr = B.raw_ptr; | |||||
| B_.layout = B.layout.remove_axis(0); | |||||
| C_.raw_ptr = C.raw_ptr; | |||||
| C_.layout = C.layout.remove_axis(0); | |||||
| auto Astrd = A.layout.dtype.size() * A.layout.stride[0], | |||||
| Bstrd = B.layout.dtype.size() * B.layout.stride[0], | |||||
| Cstrd = C.layout.dtype.size() * C.layout.stride[0]; | |||||
| auto advance_ptr = [](TensorND &dest, ptrdiff_t d) { | |||||
| dest.raw_ptr = static_cast<void*>( | |||||
| static_cast<dt_byte*>(dest.raw_ptr) + d); | |||||
| }; | |||||
| rep(n, N) { | |||||
| m_opr->exec(A_, B_, C_, workspace); | |||||
| advance_ptr(A_, Astrd); | |||||
| advance_ptr(B_, Bstrd); | |||||
| advance_ptr(C_, Cstrd); | |||||
| } | |||||
| }; | |||||
| static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern); | |||||
| size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); | |||||
| } | } | ||||
| void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | |||||
| AlgoBase::ExecArgs args(this, A, B, C, workspace); | |||||
| auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); | |||||
| algo->check_workspace(args, workspace).exec(args); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -15,26 +15,42 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace fallback { | namespace fallback { | ||||
| class BatchedMatrixMulImpl: public naive::BatchedMatrixMulForwardImpl { | |||||
| public: | |||||
| BatchedMatrixMulImpl(Handle *handle); | |||||
| void exec( | |||||
| _megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout &A, | |||||
| const TensorLayout &B, | |||||
| const TensorLayout &C) override; | |||||
| private: | |||||
| std::unique_ptr<CpuOprDelegationStorage<>> m_storage; | |||||
| MatrixMulForward* m_opr; | |||||
| class BatchedMatrixMulForwardImpl: public naive::BatchedMatrixMulForwardImpl { | |||||
| public: | |||||
| using naive::BatchedMatrixMulForwardImpl::BatchedMatrixMulForwardImpl; | |||||
| void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&) override; | |||||
| bool is_thread_safe() const override { return true; } | |||||
| class AlgoBase; | |||||
| class AlgoDefault; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, | |||||
| size_t /*workspace_limit_in_bytes*/, | |||||
| bool /*reproducible*/) override; | |||||
| const char* get_algorithm_set_name() const override { | |||||
| return "FALLBACK BATCHED MATMUL"; | |||||
| } | |||||
| static AlgoPack sm_algo_pack; | |||||
| }; | }; | ||||
| } // namespace fallback | |||||
| } // namespace megdnn | |||||
| } // namespace fallback | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -473,6 +473,13 @@ public: | |||||
| PostprocessMode::NO_PROCESS, | PostprocessMode::NO_PROCESS, | ||||
| "NoPackStrategyType::FLOAT16_FLOAT16"_hash); | "NoPackStrategyType::FLOAT16_FLOAT16"_hash); | ||||
| break; | break; | ||||
| #endif | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| case StrategyType::FLOAT_FP16: | |||||
| cb1(NCHW, NO_PACK, dt_float16, __fp16, | |||||
| PostprocessMode::NO_PROCESS, | |||||
| "NoPackStrategyType::FLOAT_FP16"_hash); | |||||
| break; | |||||
| #endif | #endif | ||||
| case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
| cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | ||||
| @@ -169,6 +169,10 @@ INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | ||||
| megdnn::PostprocessMode::NO_PROCESS) | megdnn::PostprocessMode::NO_PROCESS) | ||||
| #endif | #endif | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| #endif | |||||
| #undef INSTANTIAL_CLASS | #undef INSTANTIAL_CLASS | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -67,7 +67,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseMultiType) | |||||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdate) | MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdate) | ||||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward) | MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward) | ||||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize) | MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize) | ||||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMul) | |||||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward) | |||||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias) | MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias) | ||||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC) | MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC) | ||||
| @@ -10,13 +10,18 @@ | |||||
| */ | */ | ||||
| #include "src/fallback/matrix_mul/algos.h" | #include "src/fallback/matrix_mul/algos.h" | ||||
| #include "megdnn/opr_param_defs.h" | |||||
| #include "src/fallback/matrix_mul/gemm_impl.h" | #include "src/fallback/matrix_mul/gemm_impl.h" | ||||
| #include "src/fallback/matrix_mul/gemv.h" | #include "src/fallback/matrix_mul/gemv.h" | ||||
| #include "src/fallback/matrix_mul/generic_strategy.h" | #include "src/fallback/matrix_mul/generic_strategy.h" | ||||
| #include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
| #include "midout.h" | #include "midout.h" | ||||
| MIDOUT_DECL(megdnn_fb_matmul_f32_kern) | MIDOUT_DECL(megdnn_fb_matmul_f32_kern) | ||||
| MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) | MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) | ||||
| MIDOUT_DECL(megdnn_fb_matmul_naive) | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace fallback; | using namespace fallback; | ||||
| @@ -39,6 +44,32 @@ void f32_8x12x1_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| } | } | ||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| } | } | ||||
| void kern_naive(const MatrixMulImpl::KernParam& kern_param) { | |||||
| MIDOUT_BEGIN(megdnn_fb_matmul_naive, void) { | |||||
| size_t M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
| size_t LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
| #define DISPATCH(TA, TB) \ | |||||
| if (kern_param.trA == TA && kern_param.trB == TB) { \ | |||||
| naive::dispatch_ta_tb<TA, TB>( \ | |||||
| kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \ | |||||
| kern_param.workspace_ptr, M, N, K, LDA, LDB, LDC, \ | |||||
| kern_param.A_type, kern_param.B_type, kern_param.C_type, \ | |||||
| kern_param.format, kern_param.compute_mode); \ | |||||
| return; \ | |||||
| } | |||||
| DISPATCH(true, true); | |||||
| DISPATCH(true, false); | |||||
| DISPATCH(false, true); | |||||
| DISPATCH(false, false); | |||||
| #undef DISPATCH | |||||
| megdnn_assert_internal(0); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| } | |||||
| } // anonymous namespace | } // anonymous namespace | ||||
| ////////////////////// AlgoF32K8x12x1 /////////////////////////// | ////////////////////// AlgoF32K8x12x1 /////////////////////////// | ||||
| @@ -84,11 +115,14 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern, | |||||
| bool MatrixMulImpl::AlgoGemv::usable( | bool MatrixMulImpl::AlgoGemv::usable( | ||||
| const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
| return !kern_size_param.trA && !kern_size_param.trB && | return !kern_size_param.trA && !kern_size_param.trB && | ||||
| kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||||
| !((kern_size_param.A_type.enumv() == | |||||
| kern_size_param.B_type.enumv()) && | |||||
| (kern_size_param.A_type.enumv() == DTypeEnum::Int16) && | |||||
| (kern_size_param.C_type.enumv() == DTypeEnum::Int32)); | |||||
| kern_size_param.format == | |||||
| param::MatrixMul::Format::DEFAULT && | |||||
| kern_size_param.compute_mode == | |||||
| param::MatrixMul::ComputeMode::DEFAULT && | |||||
| !((kern_size_param.A_type.enumv() == | |||||
| kern_size_param.B_type.enumv()) && | |||||
| (kern_size_param.A_type.enumv() == DTypeEnum::Int16) && | |||||
| (kern_size_param.C_type.enumv() == DTypeEnum::Int32)); | |||||
| } | } | ||||
| bool MatrixMulImpl::AlgoGemv::preferred( | bool MatrixMulImpl::AlgoGemv::preferred( | ||||
| @@ -128,4 +162,44 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoGemv::get_kern( | |||||
| megdnn_assert(0); | megdnn_assert(0); | ||||
| } | } | ||||
| /* ===================== naive algo ===================== */ | |||||
| bool MatrixMulImpl::AlgoNaive::usable(const KernSizeParam&) const { | |||||
| return true; | |||||
| } | |||||
| bool MatrixMulImpl::AlgoNaive::preferred(const KernSizeParam&) const { | |||||
| return false; | |||||
| } | |||||
| size_t MatrixMulImpl::AlgoNaive::get_workspace( | |||||
| const KernSizeParam& kern_param) const { | |||||
| MIDOUT_BEGIN( | |||||
| megdnn_fb_matmul_naive, | |||||
| midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) { | |||||
| if (kern_param.A_type.enumv() == DTypeEnum::Quantized4Asymm || | |||||
| kern_param.A_type.enumv() == DTypeEnum::QuantizedS4) { | |||||
| size_t ret = 0; | |||||
| if (kern_param.trA) { | |||||
| ret += kern_param.LDA * kern_param.K; | |||||
| } else { | |||||
| ret += kern_param.LDA * kern_param.M; | |||||
| } | |||||
| if (kern_param.trB) { | |||||
| ret += kern_param.LDB * kern_param.N; | |||||
| } else { | |||||
| ret += kern_param.LDB * kern_param.K; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| MIDOUT_END(); | |||||
| } | |||||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern( | |||||
| const KernSizeParam&) const { | |||||
| return kern_naive; | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -52,6 +52,28 @@ public: | |||||
| DEFAULT) | DEFAULT) | ||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoNaive final : public AlgoBase { | |||||
| public: | |||||
| bool is_reproducible() const override { return true; } | |||||
| const char* name() const override { return "FB_NAIVE"; } | |||||
| bool usable(const KernSizeParam&) const override; | |||||
| bool preferred(const KernSizeParam&) const override; | |||||
| size_t get_workspace(const KernSizeParam&) const override; | |||||
| kern_t get_kern(const KernSizeParam&) const override; | |||||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMM; } | |||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
| MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) | |||||
| MEGDNN_OVERRIDE_MATMUL_DESC( | |||||
| 8, 16, 1, 4, | |||||
| static_cast<AlgoDataType>( | |||||
| static_cast<uint32_t>(AlgoDataType::FLOAT16) | | |||||
| static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||||
| static_cast<uint32_t>(AlgoDataType::INT8X8X16) | | |||||
| static_cast<uint32_t>(AlgoDataType::QINT8X8X32) | | |||||
| static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)), | |||||
| DEFAULT) | |||||
| }; | |||||
| } // namespace fallback | } // namespace fallback | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -35,6 +35,7 @@ using namespace fallback; | |||||
| class MatrixMulImpl::AlgoPack : NonCopyableObj { | class MatrixMulImpl::AlgoPack : NonCopyableObj { | ||||
| AlgoF32K8x12x1 f32_k8x12x1; | AlgoF32K8x12x1 f32_k8x12x1; | ||||
| AlgoGemv gemv; | AlgoGemv gemv; | ||||
| AlgoNaive naive; | |||||
| SmallVector<AlgoBase*> m_all_algos; | SmallVector<AlgoBase*> m_all_algos; | ||||
| AlgoBase::Mapper m_all_algos_map; | AlgoBase::Mapper m_all_algos_map; | ||||
| @@ -42,6 +43,7 @@ public: | |||||
| AlgoPack() { | AlgoPack() { | ||||
| m_all_algos.emplace_back(&gemv); | m_all_algos.emplace_back(&gemv); | ||||
| m_all_algos.emplace_back(&f32_k8x12x1); | m_all_algos.emplace_back(&f32_k8x12x1); | ||||
| m_all_algos.emplace_back(&naive); | |||||
| for (auto&& algo : m_all_algos) { | for (auto&& algo : m_all_algos) { | ||||
| m_all_algos_map.emplace(algo->info().desc, algo); | m_all_algos_map.emplace(algo->info().desc, algo); | ||||
| } | } | ||||
| @@ -147,19 +149,26 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | |||||
| algo_type.format = kern_size_param.format; | algo_type.format = kern_size_param.format; | ||||
| auto algos = select_algo_type(algo_type); | auto algos = select_algo_type(algo_type); | ||||
| Algorithm *heuristic_algo = nullptr; | Algorithm *heuristic_algo = nullptr; | ||||
| Algorithm *usable_algo = nullptr; | |||||
| for (auto&& algo : algos) { | for (auto&& algo : algos) { | ||||
| if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) && | if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) && | ||||
| static_cast<AlgoBase*>(algo)->preferred_reproducible( | |||||
| kern_size_param, reproducible) && | |||||
| static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <= | static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <= | ||||
| workspace_limit_in_bytes) { | workspace_limit_in_bytes) { | ||||
| if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | |||||
| return algo; | |||||
| } else if (!heuristic_algo) { | |||||
| heuristic_algo = algo; | |||||
| if (static_cast<AlgoBase*>(algo)->preferred_reproducible( | |||||
| kern_size_param, reproducible)) { | |||||
| //! use gemv algo if it's prefered | |||||
| if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | |||||
| return algo; | |||||
| } else if (!heuristic_algo) { | |||||
| heuristic_algo = algo; | |||||
| } | |||||
| } else if (!usable_algo) { | |||||
| usable_algo = algo; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (!heuristic_algo) heuristic_algo = usable_algo; | |||||
| megdnn_assert(heuristic_algo, "No usable algorithm found"); | |||||
| return heuristic_algo; | return heuristic_algo; | ||||
| } | } | ||||
| @@ -110,6 +110,7 @@ public: | |||||
| //! fallback | //! fallback | ||||
| FB_F32K8x12x1 = 1 << 0, | FB_F32K8x12x1 = 1 << 0, | ||||
| FB_GEMV, | FB_GEMV, | ||||
| FB_NAIVE, | |||||
| #if MEGDNN_X86 | #if MEGDNN_X86 | ||||
| //! x86 | //! x86 | ||||
| @@ -233,6 +234,7 @@ public: | |||||
| private: | private: | ||||
| class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | ||||
| class AlgoGemv; | class AlgoGemv; | ||||
| class AlgoNaive; | |||||
| class AlgoPack; | class AlgoPack; | ||||
| //! maintain all the algos of in the opr of fallback | //! maintain all the algos of in the opr of fallback | ||||
| static const AlgoPack& algo_pack(); | static const AlgoPack& algo_pack(); | ||||
| @@ -141,20 +141,39 @@ void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M, | |||||
| } | } | ||||
| template <bool transA, bool transB> | template <bool transA, bool transB> | ||||
| void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace, | |||||
| const param::MatrixMul& param) { | |||||
| void exec_matrix_mul_quint4x4x32_helper( | |||||
| const void* A, const void* B, void* C, void* workspace, size_t M, | |||||
| size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC, | |||||
| DType A_type, DType B_type, DType C_type, | |||||
| const MatrixMul::Param::Format& format, | |||||
| const MatrixMul::Param::ComputeMode& compute_mode) { | |||||
| MEGDNN_MARK_USED_VAR(C_type); | |||||
| MEGDNN_MARK_USED_VAR(format); | |||||
| MEGDNN_MARK_USED_VAR(compute_mode); | |||||
| auto convert_layout = [](const TensorLayout& layout) { | auto convert_layout = [](const TensorLayout& layout) { | ||||
| auto ret = layout; | auto ret = layout; | ||||
| auto param = layout.dtype.param<dtype::Quantized4Asymm>(); | auto param = layout.dtype.param<dtype::Quantized4Asymm>(); | ||||
| ret.dtype = dtype::Quantized8Asymm(param.scale, param.zero_point); | ret.dtype = dtype::Quantized8Asymm(param.scale, param.zero_point); | ||||
| return ret; | return ret; | ||||
| }; | }; | ||||
| TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)}; | |||||
| TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(), | |||||
| convert_layout(B.layout)}; | |||||
| TensorLayout A_layout, B_layout; | |||||
| if (transA) { | |||||
| A_layout = TensorLayout({K, M}, {LDA, 1}, A_type); | |||||
| } else { | |||||
| A_layout = TensorLayout({M, K}, {LDA, 1}, A_type); | |||||
| } | |||||
| if (transB) { | |||||
| B_layout = TensorLayout({N, K}, {LDB, 1}, B_type); | |||||
| } else { | |||||
| B_layout = TensorLayout({K, N}, {LDB, 1}, B_type); | |||||
| } | |||||
| TensorND tensorA{const_cast<void*>(A), A_layout}; | |||||
| TensorND tensorB{const_cast<void*>(B), B_layout}; | |||||
| TensorND nA = {workspace, convert_layout(A_layout)}; | |||||
| TensorND nB = { | |||||
| static_cast<uint8_t*>(workspace) + nA.layout.span().dist_byte(), | |||||
| convert_layout(B_layout)}; | |||||
| auto convert_4to8 = [](const TensorND& in, const TensorND& out) { | auto convert_4to8 = [](const TensorND& in, const TensorND& out) { | ||||
| auto ptr = | auto ptr = | ||||
| static_cast<uint8_t*>(in.raw_ptr) + in.layout.span().low_byte; | static_cast<uint8_t*>(in.raw_ptr) + in.layout.span().low_byte; | ||||
| @@ -168,31 +187,48 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A, | |||||
| out_ptr[i + 1] = val1; | out_ptr[i + 1] = val1; | ||||
| } | } | ||||
| }; | }; | ||||
| convert_4to8(A, nA); | |||||
| convert_4to8(B, nB); | |||||
| auto M = C.layout.shape[0], N = C.layout.shape[1]; | |||||
| auto K = A.layout.shape[param.transposeA ? 0 : 1]; | |||||
| auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | |||||
| LDC = C.layout.stride[0]; | |||||
| convert_4to8(tensorA, nA); | |||||
| convert_4to8(tensorB, nB); | |||||
| run_matrix_mul_tpl<uint8_t, dt_int32, transA, transB, dt_int32>( | run_matrix_mul_tpl<uint8_t, dt_int32, transA, transB, dt_int32>( | ||||
| nA.compatible_ptr<uint8_t>(), nB.compatible_ptr<uint8_t>(), | nA.compatible_ptr<uint8_t>(), nB.compatible_ptr<uint8_t>(), | ||||
| C.compatible_ptr<dt_int32>(), M, N, K, LDA, LDB, LDC, | |||||
| nA.layout.dtype, nB.layout.dtype); | |||||
| static_cast<dt_int32*>(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype, | |||||
| nB.layout.dtype); | |||||
| } | } | ||||
| template <bool transA, bool transB> | template <bool transA, bool transB> | ||||
| void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace, | |||||
| const param::MatrixMul& param) { | |||||
| void exec_matrix_mul_qint4x4x16_helper( | |||||
| const void* A, const void* B, void* C, void* workspace, size_t M, | |||||
| size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC, | |||||
| DType A_type, DType B_type, DType C_type, | |||||
| const MatrixMul::Param::Format& format, | |||||
| const MatrixMul::Param::ComputeMode& compute_mode) { | |||||
| MEGDNN_MARK_USED_VAR(C_type); | |||||
| MEGDNN_MARK_USED_VAR(format); | |||||
| MEGDNN_MARK_USED_VAR(compute_mode); | |||||
| auto convert_layout = [](const TensorLayout& layout) { | auto convert_layout = [](const TensorLayout& layout) { | ||||
| auto ret = layout; | auto ret = layout; | ||||
| auto param = layout.dtype.param<dtype::QuantizedS4>(); | auto param = layout.dtype.param<dtype::QuantizedS4>(); | ||||
| ret.dtype = dtype::QuantizedS8(param.scale); | ret.dtype = dtype::QuantizedS8(param.scale); | ||||
| return ret; | return ret; | ||||
| }; | }; | ||||
| TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)}; | |||||
| TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(), | |||||
| convert_layout(B.layout)}; | |||||
| TensorLayout A_layout, B_layout; | |||||
| if (transA) { | |||||
| A_layout = TensorLayout({K, M}, {LDA, 1}, A_type); | |||||
| } else { | |||||
| A_layout = TensorLayout({M, K}, {LDA, 1}, A_type); | |||||
| } | |||||
| if (transB) { | |||||
| B_layout = TensorLayout({N, K}, {LDB, 1}, B_type); | |||||
| } else { | |||||
| B_layout = TensorLayout({K, N}, {LDB, 1}, B_type); | |||||
| } | |||||
| TensorND tensorA{const_cast<void*>(A), A_layout}; | |||||
| TensorND tensorB{const_cast<void*>(B), B_layout}; | |||||
| TensorND nA = {workspace, convert_layout(A_layout)}; | |||||
| TensorND nB = { | |||||
| static_cast<uint8_t*>(workspace) + nA.layout.span().dist_byte(), | |||||
| convert_layout(B_layout)}; | |||||
| auto convert_4to8 = [](const TensorND& in, const TensorND& out) { | auto convert_4to8 = [](const TensorND& in, const TensorND& out) { | ||||
| auto ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte; | auto ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte; | ||||
| auto out_ptr = | auto out_ptr = | ||||
| @@ -204,18 +240,98 @@ void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| out_ptr[i + 1] = cur >> 4; | out_ptr[i + 1] = cur >> 4; | ||||
| } | } | ||||
| }; | }; | ||||
| convert_4to8(A, nA); | |||||
| convert_4to8(B, nB); | |||||
| auto M = C.layout.shape[0], N = C.layout.shape[1]; | |||||
| auto K = A.layout.shape[param.transposeA ? 0 : 1]; | |||||
| auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | |||||
| LDC = C.layout.stride[0]; | |||||
| convert_4to8(tensorA, nA); | |||||
| convert_4to8(tensorB, nB); | |||||
| run_matrix_mul_tpl<int8_t, dt_int16, transA, transB, dt_int16>( | run_matrix_mul_tpl<int8_t, dt_int16, transA, transB, dt_int16>( | ||||
| nA.compatible_ptr<int8_t>(), nB.compatible_ptr<int8_t>(), | nA.compatible_ptr<int8_t>(), nB.compatible_ptr<int8_t>(), | ||||
| C.compatible_ptr<dt_int16>(), M, N, K, LDA, LDB, LDC, | |||||
| nA.layout.dtype, nB.layout.dtype); | |||||
| static_cast<dt_int16*>(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype, | |||||
| nB.layout.dtype); | |||||
| } | } | ||||
| template <bool TA, bool TB> | |||||
| void dispatch_ta_tb(const void* A, const void* B, void* C, void* workspace, | |||||
| size_t M, size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, | |||||
| ptrdiff_t LDC, DType A_type, DType B_type, DType C_type, | |||||
| const MatrixMul::Param::Format& format, | |||||
| const MatrixMul::Param::ComputeMode& compute_mode) { | |||||
| #define cb(_itype, _otype, _comp_type) \ | |||||
| if (format == param::MatrixMul::Format::DEFAULT) { \ | |||||
| return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ | |||||
| B_type); \ | |||||
| } else if (format == param::MatrixMul::Format::MK4) { \ | |||||
| return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ | |||||
| B_type); \ | |||||
| } else if (format == param::MatrixMul::Format::MK4_DOT) { \ | |||||
| return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ | |||||
| B_type); \ | |||||
| } else if (format == param::MatrixMul::Format::MK8) { \ | |||||
| return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ | |||||
| B_type); \ | |||||
| } | |||||
| if (A_type == dtype::Float32()) { | |||||
| cb(dt_float32, dt_float32, dt_float32); | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| } else if (A_type == dtype::Float16()) { | |||||
| using Param = MatrixMul::Param; | |||||
| if (compute_mode == Param::ComputeMode::DEFAULT) { | |||||
| cb(dt_float16, dt_float16, dt_float16); | |||||
| } else if (compute_mode == Param::ComputeMode::FLOAT32) { | |||||
| cb(dt_float16, dt_float16, dt_float32); | |||||
| } | |||||
| } else if (A_type == dtype::BFloat16()) { | |||||
| using Param = MatrixMul::Param; | |||||
| if (compute_mode == Param::ComputeMode::DEFAULT) { | |||||
| cb(dt_bfloat16, dt_bfloat16, dt_bfloat16); | |||||
| } else if (compute_mode == Param::ComputeMode::FLOAT32) { | |||||
| cb(dt_bfloat16, dt_bfloat16, dt_float32); | |||||
| } | |||||
| #endif | |||||
| } else if (A_type == dtype::Int8() && | |||||
| C_type == dtype::Int16()) { | |||||
| cb(dt_int8, dt_int16, dt_int16); | |||||
| } else if (A_type == dtype::Int16() && | |||||
| C_type == dtype::Int32()) { | |||||
| cb(dt_int16, dt_int32, dt_int32); | |||||
| } else if ((A_type == dtype::Int8() || | |||||
| A_type.enumv() == DTypeEnum::QuantizedS8) && | |||||
| (C_type == dtype::Int32() || | |||||
| C_type.enumv() == DTypeEnum::QuantizedS32)) { | |||||
| cb(dt_int8, dt_int32, dt_int32); | |||||
| } else if (A_type.enumv() == DTypeEnum::Quantized8Asymm && | |||||
| C_type.enumv() == DTypeEnum::QuantizedS32) { | |||||
| cb(uint8_t, dt_int32, dt_int32); | |||||
| } else if (A_type.enumv() == DTypeEnum::Quantized4Asymm && | |||||
| C_type.enumv() == DTypeEnum::QuantizedS32 && | |||||
| format == param::MatrixMul::Format::DEFAULT) { | |||||
| exec_matrix_mul_quint4x4x32_helper<TA, TB>( | |||||
| A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type, | |||||
| C_type, format, compute_mode); | |||||
| return; | |||||
| } else if (A_type.enumv() == DTypeEnum::QuantizedS4 && | |||||
| C_type.enumv() == DTypeEnum::QuantizedS16 && | |||||
| format == param::MatrixMul::Format::DEFAULT) { | |||||
| exec_matrix_mul_qint4x4x16_helper<TA, TB>( | |||||
| A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type, | |||||
| C_type, format, compute_mode); | |||||
| return; | |||||
| } | |||||
| #undef cb | |||||
| megdnn_throw( | |||||
| ssprintf("unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)", | |||||
| A_type.name(), B_type.name(), C_type.name(), | |||||
| static_cast<int>(compute_mode))); | |||||
| } | |||||
| } // namespace naive | } // namespace naive | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -45,77 +45,10 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | ||||
| LDC = C.layout.stride[0]; | LDC = C.layout.stride[0]; | ||||
| #define cb(_itype, _otype, _comp_type) \ | |||||
| if (param.format == param::MatrixMul::Format::DEFAULT) { \ | |||||
| return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } else if (param.format == param::MatrixMul::Format::MK4) { \ | |||||
| return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } else if (param.format == param::MatrixMul::Format::MK4_DOT) { \ | |||||
| return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } else if (param.format == param::MatrixMul::Format::MK8) { \ | |||||
| return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } | |||||
| if (A.layout.dtype == dtype::Float32()) { | |||||
| cb(dt_float32, dt_float32, dt_float32); | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| } else if (A.layout.dtype == dtype::Float16()) { | |||||
| using Param = MatrixMul::Param; | |||||
| if (param.compute_mode == Param::ComputeMode::DEFAULT) { | |||||
| cb(dt_float16, dt_float16, dt_float16); | |||||
| } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { | |||||
| cb(dt_float16, dt_float16, dt_float32); | |||||
| } | |||||
| } else if (A.layout.dtype == dtype::BFloat16()) { | |||||
| using Param = MatrixMul::Param; | |||||
| if (param.compute_mode == Param::ComputeMode::DEFAULT) { | |||||
| cb(dt_bfloat16, dt_bfloat16, dt_bfloat16); | |||||
| } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { | |||||
| cb(dt_bfloat16, dt_bfloat16, dt_float32); | |||||
| } | |||||
| #endif | |||||
| } else if (A.layout.dtype == dtype::Int8() && | |||||
| C.layout.dtype == dtype::Int16()) { | |||||
| cb(dt_int8, dt_int16, dt_int16); | |||||
| } else if (A.layout.dtype == dtype::Int16() && | |||||
| C.layout.dtype == dtype::Int32()) { | |||||
| cb(dt_int16, dt_int32, dt_int32); | |||||
| } else if ((A.layout.dtype == dtype::Int8() || | |||||
| A.layout.dtype.enumv() == DTypeEnum::QuantizedS8) && | |||||
| (C.layout.dtype == dtype::Int32() || | |||||
| C.layout.dtype.enumv() == DTypeEnum::QuantizedS32)) { | |||||
| cb(dt_int8, dt_int32, dt_int32); | |||||
| } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && | |||||
| C.layout.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||||
| cb(uint8_t, dt_int32, dt_int32); | |||||
| } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm && | |||||
| C.layout.dtype.enumv() == DTypeEnum::QuantizedS32 && | |||||
| param.format == param::MatrixMul::Format::DEFAULT) { | |||||
| exec_matrix_mul_quint4x4x32_helper<TA, TB>(A, B, C, workspace, param); | |||||
| return; | |||||
| } else if (A.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
| C.layout.dtype.enumv() == DTypeEnum::QuantizedS16 && | |||||
| param.format == param::MatrixMul::Format::DEFAULT) { | |||||
| exec_matrix_mul_qint4x4x16_helper<TA, TB>(A, B, C, workspace, param); | |||||
| return; | |||||
| } | |||||
| #undef cb | |||||
| megdnn_throw(ssprintf( | |||||
| "unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)", | |||||
| A.layout.dtype.name(), B.layout.dtype.name(), C.layout.dtype.name(), | |||||
| static_cast<int>(param.compute_mode))); | |||||
| dispatch_ta_tb<TA, TB>(A.raw_ptr, B.raw_ptr, C.raw_ptr, workspace.raw_ptr, | |||||
| M, N, K, LDA, LDB, LDC, A.layout.dtype, | |||||
| B.layout.dtype, C.layout.dtype, param.format, | |||||
| param.compute_mode); | |||||
| } | } | ||||
| void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A, | void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A, | ||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * \file dnn/src/rocm/batched_matrix_mul/algos.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/rocm/batched_matrix_mul/algos.h" | |||||
| #include "src/common/algo_base.h" | |||||
| using namespace megdnn; | |||||
| using namespace rocm; | |||||
| BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(&blas); | |||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack; | |||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) | |||||
| BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||||
| BatchedMatrixMulForwardImpl* o, const TensorLayout& A, | |||||
| const TensorLayout& B, const TensorLayout& C) | |||||
| : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} | |||||
| BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( | |||||
| BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) | |||||
| : SizeArgs(opr, A.layout, B.layout, C.layout), | |||||
| tensor_a{A}, | |||||
| tensor_b{B}, | |||||
| tensor_c{C}, | |||||
| workspace{workspace} {} | |||||
| std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
| auto&& param = opr->param(); | |||||
| size_t m = layout_a.shape[0], n = layout_b.shape[1], | |||||
| k = layout_a.shape[param.transposeA ? 0 : 1]; | |||||
| MEGDNN_MARK_USED_VAR(m); | |||||
| MEGDNN_MARK_USED_VAR(n); | |||||
| MEGDNN_MARK_USED_VAR(k); | |||||
| return megdnn_mangle(ssprintf( | |||||
| "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " | |||||
| "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", | |||||
| m, k, k, n, m, n, param.transposeA, param.transposeB, | |||||
| layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,118 @@ | |||||
| /** | |||||
| * \file dnn/src/rocm/batched_matrix_mul/algos.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/rocm/batched_matrix_mul/opr_impl.h" | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| namespace megdnn { | |||||
| namespace rocm { | |||||
| /*! | |||||
| * \brief base class for matrix mul algos | |||||
| * | |||||
| */ | |||||
| class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| enum class AlgoType : uint32_t { | |||||
| ROCM_BLAS, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||||
| struct SizeArgs { | |||||
| BatchedMatrixMulForwardImpl* opr; | |||||
| TensorLayout layout_a, layout_b, layout_c; | |||||
| std::string to_string() const; | |||||
| SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A, | |||||
| const TensorLayout& B, const TensorLayout& C); | |||||
| bool can_be_treated_as_int8x8x32() const { | |||||
| return layout_a.dtype.enumv() == layout_b.dtype.enumv() && | |||||
| (layout_a.dtype.enumv() == DTypeEnum::Int8 || | |||||
| layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) && | |||||
| (layout_c.dtype.enumv() == DTypeEnum::Int32 || | |||||
| layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) && | |||||
| opr->param().format == param::MatrixMul::Format::DEFAULT; | |||||
| } | |||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| TensorND tensor_a, tensor_b, tensor_c; | |||||
| Workspace workspace; | |||||
| ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs& args) const = 0; | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) const { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert( | |||||
| req <= workspace.size, | |||||
| "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| }; | |||||
| class BatchedMatrixMulForwardImpl::AlgoBlas final : public AlgoBase { | |||||
| public: | |||||
| AlgoBlas() = default; | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { | |||||
| return 0_z; | |||||
| } | |||||
| const char* name() const override { return "BLAS"; } | |||||
| void exec(const ExecArgs& args) const override; | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | |||||
| }; | |||||
| class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| AlgoBlas blas; | |||||
| std::vector<AlgoBase*> all_algos; | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | |||||
| } // namespace rocm | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,140 @@ | |||||
| /** | |||||
| * \file dnn/src/rocm/batched_matrix_mul/Blas.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/rocm/batched_matrix_mul/algos.h" | |||||
| #include "hcc_detail/hcc_defs_prologue.h" | |||||
| #include "src/rocm/handle.h" | |||||
| #include "src/rocm/utils.h" | |||||
| using namespace megdnn; | |||||
| using namespace rocm; | |||||
| bool BatchedMatrixMulForwardImpl::AlgoBlas::is_available( | |||||
| const SizeArgs& args) const { | |||||
| if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) | |||||
| return false; | |||||
| if (args.layout_a.dtype == dtype::Float32() || | |||||
| args.layout_a.dtype == dtype::Float16()) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void BatchedMatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const { | |||||
| auto batch = args.layout_a.shape[0]; | |||||
| auto m = args.layout_c.shape[1], n = args.layout_c.shape[2]; | |||||
| auto k = args.layout_a.shape[args.opr->param().transposeA ? 1 : 2]; | |||||
| auto&& handle = concrete_handle(args.opr->handle()); | |||||
| auto rocblas_handle_ = handle->get_rocblas_handle(); | |||||
| auto sgemm = [&]() { | |||||
| auto zero = handle->zero_device(); | |||||
| auto one = handle->one_device(); | |||||
| rocblas_check(rocblas_sgemm_strided_batched( | |||||
| rocblas_handle_, | |||||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, args.tensor_b.ptr<dt_float32>(), | |||||
| (rocblas_int)(args.layout_b.stride[1]), | |||||
| (rocblas_int)(args.layout_b.stride[0]), | |||||
| args.tensor_a.ptr<dt_float32>(), | |||||
| (rocblas_int)(args.layout_a.stride[1]), | |||||
| (rocblas_int)(args.layout_a.stride[0]), zero, | |||||
| args.tensor_c.ptr<dt_float32>(), | |||||
| (rocblas_int)(args.layout_c.stride[1]), | |||||
| (rocblas_int)(args.layout_c.stride[0]), (rocblas_int)(batch))); | |||||
| }; | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| //! used for FLOAT_IO16xC32, not tested | |||||
| auto gemm_ex = [&]() { | |||||
| auto zero = handle->zero_device(); | |||||
| auto one = handle->one_device(); | |||||
| //! These two arguments for future use, see | |||||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp | |||||
| int32_t solution_index = 0; | |||||
| uint32_t flags = 1; | |||||
| size_t ws_size = 0; | |||||
| rocblas_check(rocblas_gemm_strided_batched_ex( | |||||
| rocblas_handle_, | |||||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r, | |||||
| args.layout_b.stride[1], args.layout_b.stride[0], | |||||
| args.tensor_a.raw_ptr, rocblas_datatype_i8_r, | |||||
| args.layout_a.stride[1], args.layout_a.stride[0], zero, | |||||
| args.tensor_c.raw_ptr, rocblas_datatype_i32_r, | |||||
| args.layout_c.stride[1], args.layout_c.stride[0], | |||||
| args.tensor_c.raw_ptr, rocblas_datatype_i32_r, | |||||
| args.layout_c.stride[1], args.layout_c.stride[0], batch, | |||||
| rocblas_datatype_i32_r, rocblas_gemm_algo_standard, | |||||
| solution_index, flags, &ws_size, nullptr)); | |||||
| MEGDNN_MARK_USED_VAR(ws_size); | |||||
| }; | |||||
| auto hgemm = [&]() { | |||||
| auto one_half = handle->one_device_h(); | |||||
| auto zero_half = handle->zero_device_h(); | |||||
| rocblas_check(rocblas_hgemm_strided_batched( | |||||
| rocblas_handle_, | |||||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||||
| static_cast<const rocblas_half*>(args.tensor_b.raw_ptr), | |||||
| args.layout_b.stride[1], args.layout_b.stride[0], | |||||
| static_cast<const rocblas_half*>(args.tensor_a.raw_ptr), | |||||
| args.layout_a.stride[1], args.layout_a.stride[0], | |||||
| reinterpret_cast<const rocblas_half*>(zero_half), | |||||
| static_cast<rocblas_half*>(args.tensor_c.raw_ptr), | |||||
| args.layout_c.stride[1], args.layout_c.stride[0], batch)); | |||||
| }; | |||||
| #endif | |||||
| if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) { | |||||
| if (args.layout_a.dtype == dtype::Float32()) { | |||||
| sgemm(); | |||||
| } | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| else { | |||||
| megdnn_assert(args.layout_a.dtype == dtype::Float16(), | |||||
| "invalid matmul data type"); | |||||
| hgemm(); | |||||
| } | |||||
| #endif | |||||
| } | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) { | |||||
| megdnn_assert(args.layout_b.dtype == dtype::Float16() && | |||||
| args.layout_c.dtype == dtype::Float16() && | |||||
| args.layout_a.dtype == dtype::Float16(), | |||||
| "DataType::FLOAT_IO16xC32 is supported, when dtype of A, " | |||||
| "B, C are all Float16"); | |||||
| gemm_ex(); | |||||
| } | |||||
| #endif | |||||
| else { | |||||
| megdnn_throw("Unsupported data_type of matrix mul on rocm."); | |||||
| } | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -10,111 +10,58 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "./opr_impl.h" | #include "./opr_impl.h" | ||||
| #include "./algos.h" | |||||
| #include "hcc_detail/hcc_defs_prologue.h" | #include "hcc_detail/hcc_defs_prologue.h" | ||||
| #include "src/common/algo_chooser.h" | |||||
| #include "src/common/utils.cuh" | #include "src/common/utils.cuh" | ||||
| #include "src/rocm/handle.h" | #include "src/rocm/handle.h" | ||||
| #include "src/rocm/utils.h" | #include "src/rocm/utils.h" | ||||
| namespace megdnn { | |||||
| namespace rocm { | |||||
| using namespace megdnn; | |||||
| using namespace rocm; | |||||
| std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||||
| BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | |||||
| } | |||||
| BatchedMatrixMulForwardImpl::Algorithm* | |||||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| if (sm_algo_pack.blas.is_available_reproducible(args, reproducible, | |||||
| workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.blas; | |||||
| } | |||||
| if (reproducible) { | |||||
| return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward"); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward"); | |||||
| } | |||||
| } | |||||
| size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); | |||||
| } | |||||
| void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | ||||
| _megdnn_tensor_out C, | _megdnn_tensor_out C, | ||||
| _megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | check_exec(A.layout, B.layout, C.layout, workspace.size); | ||||
| auto dtype = A.layout.dtype; | |||||
| megdnn_assert(dtype.category() == DTypeCategory::FLOAT && | |||||
| param().format == param::MatrixMul::Format::DEFAULT); | |||||
| if (dtype == dtype::Float32() || | |||||
| MEGDNN_FLOAT16_SELECT(dtype == dtype::Float16(), false)) { | |||||
| auto batch = A.layout.shape[0]; | |||||
| auto m = C.layout.shape[1], n = C.layout.shape[2]; | |||||
| auto k = A.layout.shape[param().transposeA ? 1 : 2]; | |||||
| auto handle = concrete_handle(this->handle()); | |||||
| auto rocblas_handle_ = handle->get_rocblas_handle(); | |||||
| auto io32_c32 = [&]() { | |||||
| auto zero = handle->zero_device(); | |||||
| auto one = handle->one_device(); | |||||
| rocblas_check(rocblas_sgemm_strided_batched( | |||||
| rocblas_handle_, | |||||
| param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, B.ptr<dt_float32>(), | |||||
| (rocblas_int)(B.layout.stride[1]), | |||||
| (rocblas_int)(B.layout.stride[0]), A.ptr<dt_float32>(), | |||||
| (rocblas_int)(A.layout.stride[1]), | |||||
| (rocblas_int)(A.layout.stride[0]), zero, | |||||
| C.ptr<dt_float32>(), (rocblas_int)(C.layout.stride[1]), | |||||
| (rocblas_int)(C.layout.stride[0]), (rocblas_int)(batch))); | |||||
| }; | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| auto io16_c32 = [&]() { | |||||
| auto zero = handle->zero_device(); | |||||
| auto one = handle->one_device(); | |||||
| int32_t solution_index = 0; | |||||
| uint32_t flags = 1; | |||||
| size_t ws_size = 0; | |||||
| rocblas_check(rocblas_gemm_strided_batched_ex( | |||||
| rocblas_handle_, | |||||
| param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r, | |||||
| B.layout.stride[1], B.layout.stride[0], A.raw_ptr, | |||||
| rocblas_datatype_i8_r, A.layout.stride[1], | |||||
| A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r, | |||||
| C.layout.stride[1], C.layout.stride[0], C.raw_ptr, | |||||
| rocblas_datatype_i32_r, C.layout.stride[1], | |||||
| C.layout.stride[0], batch, rocblas_datatype_i32_r, | |||||
| rocblas_gemm_algo_standard, solution_index, flags, &ws_size, | |||||
| nullptr)); | |||||
| }; | |||||
| auto io16_c16 = [&]() { | |||||
| auto zero_half = handle->zero_device_h(); | |||||
| auto one_half = handle->one_device_h(); | |||||
| rocblas_check(rocblas_hgemm_strided_batched( | |||||
| rocblas_handle_, | |||||
| param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||||
| static_cast<const rocblas_half*>(B.raw_ptr), | |||||
| B.layout.stride[1], B.layout.stride[0], | |||||
| static_cast<const rocblas_half*>(A.raw_ptr), | |||||
| A.layout.stride[1], A.layout.stride[0], | |||||
| reinterpret_cast<const rocblas_half*>(zero_half), | |||||
| static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[1], | |||||
| C.layout.stride[0], batch)); | |||||
| }; | |||||
| #endif | |||||
| if (dtype == dtype::Float32()) { | |||||
| io32_c32(); | |||||
| } | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| else { | |||||
| if (param().compute_mode == Param::ComputeMode::FLOAT32) { | |||||
| io16_c32(); | |||||
| } else { | |||||
| io16_c16(); | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
| AlgoBase::ExecArgs args(this, A, B, C, workspace); | |||||
| auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); | |||||
| algo->check_workspace(args, workspace).exec(args); | |||||
| } | } | ||||
| } // namespace rocm | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| @@ -17,36 +18,35 @@ namespace rocm { | |||||
| class BatchedMatrixMulForwardImpl : public BatchedMatrixMulForward { | class BatchedMatrixMulForwardImpl : public BatchedMatrixMulForward { | ||||
| public: | public: | ||||
| using BatchedMatrixMulForward::BatchedMatrixMulForward; | using BatchedMatrixMulForward::BatchedMatrixMulForward; | ||||
| BatchedMatrixMulForwardImpl(Handle* handle) | |||||
| : BatchedMatrixMul(handle), | |||||
| m_opr(handle->create_operator<MatrixMul>()) {} | |||||
| void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | ||||
| _megdnn_workspace workspace) override; | _megdnn_workspace workspace) override; | ||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| const TensorLayout&) override; | |||||
| bool is_thread_safe() const override { return true; } | |||||
| class AlgoBase; | |||||
| class AlgoBlas; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override { | |||||
| return {}; | |||||
| } | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | ||||
| const TensorLayout& /*B*/, | const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/, | const TensorLayout& /*C*/, | ||||
| size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
| bool /* reproducible */) override { | |||||
| return nullptr; | |||||
| } | |||||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||||
| bool /*reproducible*/) override; | |||||
| bool is_thread_safe() const override { return true; } | |||||
| const char* get_algorithm_set_name() const override { | |||||
| return "ROCM BATCHED MATMUL"; | |||||
| } | |||||
| private: | |||||
| std::unique_ptr<MatrixMul> m_opr; | |||||
| static AlgoPack sm_algo_pack; | |||||
| }; | }; | ||||
| } // namespace rocm | } // namespace rocm | ||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * \file dnn/src/rocm/matrix_mul/algos.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/rocm/matrix_mul/algos.h" | |||||
| #include "src/common/algo_base.h" | |||||
| using namespace megdnn; | |||||
| using namespace rocm; | |||||
| MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||||
| all_algos.push_back(&blas); | |||||
| for (auto&& algo : all_algos) { | |||||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||||
| } | |||||
| } | |||||
| MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | |||||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl) | |||||
| MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, | |||||
| const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) | |||||
| : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} | |||||
| MatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(MatrixMulForwardImpl* opr, | |||||
| _megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) | |||||
| : SizeArgs(opr, A.layout, B.layout, C.layout), | |||||
| tensor_a{A}, | |||||
| tensor_b{B}, | |||||
| tensor_c{C}, | |||||
| workspace{workspace} {} | |||||
| std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||||
| auto&& param = opr->param(); | |||||
| size_t m = layout_a.shape[0], n = layout_b.shape[1], | |||||
| k = layout_a.shape[param.transposeA ? 0 : 1]; | |||||
| MEGDNN_MARK_USED_VAR(m); | |||||
| MEGDNN_MARK_USED_VAR(n); | |||||
| MEGDNN_MARK_USED_VAR(k); | |||||
| return megdnn_mangle(ssprintf( | |||||
| "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " | |||||
| "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", | |||||
| m, k, k, n, m, n, param.transposeA, param.transposeB, | |||||
| layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,118 @@ | |||||
| /** | |||||
| * \file dnn/src/rocm/matrix_mul/algos.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/metahelper.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/rocm/matrix_mul/opr_impl.h" | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| namespace megdnn { | |||||
| namespace rocm { | |||||
| /*! | |||||
| * \brief base class for matrix mul algos | |||||
| * | |||||
| */ | |||||
| class MatrixMulForwardImpl::AlgoBase : public Algorithm { | |||||
| protected: | |||||
| ~AlgoBase() = default; | |||||
| public: | |||||
| enum class AlgoType : uint32_t { | |||||
| ROCM_BLAS, | |||||
| }; | |||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||||
| struct SizeArgs { | |||||
| MatrixMulForwardImpl* opr; | |||||
| TensorLayout layout_a, layout_b, layout_c; | |||||
| std::string to_string() const; | |||||
| SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, | |||||
| const TensorLayout& B, const TensorLayout& C); | |||||
| bool can_be_treated_as_int8x8x32() const { | |||||
| return layout_a.dtype.enumv() == layout_b.dtype.enumv() && | |||||
| (layout_a.dtype.enumv() == DTypeEnum::Int8 || | |||||
| layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) && | |||||
| (layout_c.dtype.enumv() == DTypeEnum::Int32 || | |||||
| layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) && | |||||
| opr->param().format == param::MatrixMul::Format::DEFAULT; | |||||
| } | |||||
| }; | |||||
| struct ExecArgs : public SizeArgs { | |||||
| TensorND tensor_a, tensor_b, tensor_c; | |||||
| Workspace workspace; | |||||
| ExecArgs(MatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace); | |||||
| }; | |||||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||||
| virtual void exec(const ExecArgs& args) const = 0; | |||||
| bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||||
| } | |||||
| bool is_available_reproducible( | |||||
| const SizeArgs& args, bool reproducible = true, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) const { | |||||
| return (!reproducible || is_reproducible()) && | |||||
| is_available_wk(args, limit); | |||||
| } | |||||
| AlgoBase& check_workspace(const SizeArgs& args, | |||||
| const Workspace& workspace) { | |||||
| auto req = get_workspace_in_bytes(args); | |||||
| megdnn_assert( | |||||
| req <= workspace.size, | |||||
| "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", | |||||
| name(), req, workspace.size); | |||||
| return *this; | |||||
| } | |||||
| }; | |||||
| class MatrixMulForwardImpl::AlgoBlas final : public AlgoBase { | |||||
| public: | |||||
| AlgoBlas() = default; | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { | |||||
| return 0_z; | |||||
| } | |||||
| const char* name() const override { return "BLAS"; } | |||||
| void exec(const ExecArgs& args) const override; | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | |||||
| }; | |||||
| class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||||
| private: | |||||
| AlgoBase::Mapper m_all_algos_map; | |||||
| public: | |||||
| AlgoPack(); | |||||
| AlgoBlas blas; | |||||
| std::vector<AlgoBase*> all_algos; | |||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||||
| }; | |||||
| } // namespace rocm | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,162 @@ | |||||
| /** | |||||
| * \file dnn/src/rocm/matrix_mul/Blas.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "src/rocm/matrix_mul/algos.h" | |||||
| #include "hcc_detail/hcc_defs_prologue.h" | |||||
| #include "src/rocm/handle.h" | |||||
| #include "src/rocm/utils.h" | |||||
| using namespace megdnn; | |||||
| using namespace rocm; | |||||
| bool MatrixMulForwardImpl::AlgoBlas::is_available( | |||||
| const SizeArgs& args) const { | |||||
| if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) | |||||
| return false; | |||||
| if (args.layout_a.dtype == dtype::Float32() || | |||||
| args.layout_a.dtype == dtype::Float16()) { | |||||
| return true; | |||||
| } else if (args.layout_a.dtype.enumv() == DTypeEnum::Int8 || | |||||
| args.layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) { | |||||
| auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1]; | |||||
| //! see | |||||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470 | |||||
| bool rocblas_int8x8x32_valid = true; | |||||
| rocblas_int8x8x32_valid &= (k % 4 == 0); | |||||
| rocblas_int8x8x32_valid &= (!args.opr->param().transposeB || | |||||
| args.layout_b.stride[0] % 4 == 0); | |||||
| rocblas_int8x8x32_valid &= (!args.opr->param().transposeA || | |||||
| args.layout_a.stride[0] % 4 == 0); | |||||
| return rocblas_int8x8x32_valid; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void MatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const { | |||||
| auto m = args.layout_c.shape[0], n = args.layout_c.shape[1]; | |||||
| auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1]; | |||||
| auto&& handle = concrete_handle(args.opr->handle()); | |||||
| auto rocblas_handle_ = handle->get_rocblas_handle(); | |||||
| auto sgemm = [&]() { | |||||
| auto zero = handle->zero_device(); | |||||
| auto one = handle->one_device(); | |||||
| rocblas_check(rocblas_sgemm( | |||||
| rocblas_handle_, | |||||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, args.tensor_b.ptr<dt_float32>(), | |||||
| args.layout_b.stride[0], args.tensor_a.ptr<dt_float32>(), | |||||
| args.layout_a.stride[0], zero, args.tensor_c.ptr<dt_float32>(), | |||||
| args.layout_c.stride[0])); | |||||
| }; | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| //! used for FLOAT_IO16xC32, not tested | |||||
| auto gemm_ex = [&]() { | |||||
| auto zero = handle->zero_device(); | |||||
| auto one = handle->one_device(); | |||||
| //! These two arguments for future use, see | |||||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp | |||||
| int32_t solution_index = 0; | |||||
| uint32_t flags = 1; | |||||
| size_t ws_size = 0; | |||||
| auto gemm_ex_err = rocblas_gemm_ex( | |||||
| rocblas_handle_, | |||||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_f16_r, | |||||
| args.layout_b.stride[0], args.tensor_a.raw_ptr, | |||||
| rocblas_datatype_f16_r, args.layout_a.stride[0], zero, | |||||
| args.tensor_c.raw_ptr, rocblas_datatype_f16_r, | |||||
| args.layout_c.stride[0], args.tensor_c.raw_ptr, | |||||
| rocblas_datatype_f16_r, args.layout_c.stride[0], | |||||
| rocblas_datatype_f32_r, rocblas_gemm_algo_standard, | |||||
| solution_index, flags, &ws_size, nullptr); | |||||
| rocblas_check(gemm_ex_err); | |||||
| MEGDNN_MARK_USED_VAR(ws_size); | |||||
| }; | |||||
| auto hgemm = [&]() { | |||||
| auto one_half = handle->one_device_h(); | |||||
| auto zero_half = handle->zero_device_h(); | |||||
| auto hgemm_err = rocblas_hgemm( | |||||
| rocblas_handle_, | |||||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||||
| static_cast<const rocblas_half*>(args.tensor_b.raw_ptr), | |||||
| args.layout_b.stride[0], | |||||
| static_cast<const rocblas_half*>(args.tensor_a.raw_ptr), | |||||
| args.layout_a.stride[0], | |||||
| reinterpret_cast<const rocblas_half*>(zero_half), | |||||
| static_cast<rocblas_half*>(args.tensor_c.raw_ptr), | |||||
| args.layout_c.stride[0]); | |||||
| rocblas_check(hgemm_err); | |||||
| }; | |||||
| #endif | |||||
| if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) { | |||||
| if (args.layout_a.dtype == dtype::Float32()) { | |||||
| sgemm(); | |||||
| } | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| else { | |||||
| megdnn_assert(args.layout_a.dtype == dtype::Float16(), | |||||
| "invalid matmul data type"); | |||||
| hgemm(); | |||||
| } | |||||
| #endif | |||||
| } | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) { | |||||
| megdnn_assert(args.layout_b.dtype == dtype::Float16() && | |||||
| args.layout_c.dtype == dtype::Float16() && | |||||
| args.layout_a.dtype == dtype::Float16(), | |||||
| "DataType::FLOAT_IO16xC32 is supported, when dtype of A, " | |||||
| "B, C are all Float16"); | |||||
| gemm_ex(); | |||||
| } | |||||
| #endif | |||||
| else { | |||||
| megdnn_assert(args.can_be_treated_as_int8x8x32()); | |||||
| int32_t solution_index = 0; | |||||
| uint32_t flags = 1; | |||||
| size_t ws_size = 0; | |||||
| auto zero = handle->zero_device_i32(); | |||||
| auto one = handle->one_device_i32(); | |||||
| rocblas_check(rocblas_gemm_ex( | |||||
| rocblas_handle_, | |||||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r, | |||||
| args.layout_b.stride[0], args.tensor_a.raw_ptr, | |||||
| rocblas_datatype_i8_r, args.layout_a.stride[0], zero, | |||||
| args.tensor_c.raw_ptr, rocblas_datatype_i32_r, | |||||
| args.layout_c.stride[0], args.tensor_c.raw_ptr, | |||||
| rocblas_datatype_i32_r, args.layout_c.stride[0], | |||||
| rocblas_datatype_i32_r, rocblas_gemm_algo_standard, | |||||
| solution_index, flags, &ws_size, nullptr)); | |||||
| MEGDNN_MARK_USED_VAR(ws_size); | |||||
| } | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -13,147 +13,53 @@ | |||||
| #include "src/rocm/utils.h" | #include "src/rocm/utils.h" | ||||
| #include "src/rocm/handle.h" | #include "src/rocm/handle.h" | ||||
| #include "./algos.h" | |||||
| #include "src/common/algo_chooser.h" | |||||
| namespace megdnn { | |||||
| namespace rocm { | |||||
| using namespace megdnn; | |||||
| using namespace rocm; | |||||
| void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, | |||||
| _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) | |||||
| { | |||||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | |||||
| auto m = C.layout.shape[0], n = C.layout.shape[1]; | |||||
| auto k = A.layout.shape[param().transposeA ? 0 : 1]; | |||||
| auto handle = concrete_handle(this->handle()); | |||||
| auto rocblas_handle_ = handle->get_rocblas_handle(); | |||||
| auto sgemm = [&]() { | |||||
| auto zero = handle->zero_device(); | |||||
| auto one = handle->one_device(); | |||||
| rocblas_check(rocblas_sgemm( | |||||
| rocblas_handle_, | |||||
| param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, B.ptr<dt_float32>(), B.layout.stride[0], | |||||
| A.ptr<dt_float32>(), A.layout.stride[0], zero, | |||||
| C.ptr<dt_float32>(), C.layout.stride[0])); | |||||
| }; | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| //! used for FLOAT_IO16xC32, not tested | |||||
| auto gemm_ex = [&]() { | |||||
| auto zero = handle->zero_device(); | |||||
| auto one = handle->one_device(); | |||||
| //! These two arguments for future use, see | |||||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp | |||||
| int32_t solution_index = 0; | |||||
| uint32_t flags = 1; | |||||
| size_t ws_size = 0; | |||||
| auto gemm_ex_err = rocblas_gemm_ex( | |||||
| rocblas_handle_, | |||||
| param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, B.raw_ptr, rocblas_datatype_f16_r, | |||||
| B.layout.stride[0], A.raw_ptr, rocblas_datatype_f16_r, | |||||
| A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_f16_r, | |||||
| C.layout.stride[0], C.raw_ptr, rocblas_datatype_f16_r, | |||||
| C.layout.stride[0], rocblas_datatype_f32_r, | |||||
| rocblas_gemm_algo_standard, solution_index, flags, &ws_size, | |||||
| nullptr); | |||||
| rocblas_check(gemm_ex_err); | |||||
| }; | |||||
| auto hgemm = [&]() { | |||||
| auto one_half = handle->one_device_h(); | |||||
| auto zero_half = handle->zero_device_h(); | |||||
| auto hgemm_err = rocblas_hgemm( | |||||
| rocblas_handle_, | |||||
| param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||||
| static_cast<const rocblas_half*>(B.raw_ptr), B.layout.stride[0], | |||||
| static_cast<const rocblas_half*>(A.raw_ptr), A.layout.stride[0], | |||||
| reinterpret_cast<const rocblas_half*>(zero_half), | |||||
| static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[0]); | |||||
| rocblas_check(hgemm_err); | |||||
| }; | |||||
| #endif | |||||
| std::vector<MatrixMulForwardImpl::Algorithm*> | |||||
| MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | |||||
| } | |||||
| if (param().compute_mode == Param::ComputeMode::DEFAULT) { | |||||
| if (A.layout.dtype == dtype::Float32()) { | |||||
| sgemm(); | |||||
| } | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| else { | |||||
| megdnn_assert(A.layout.dtype == dtype::Float16(), | |||||
| "invalid matmul data type"); | |||||
| hgemm(); | |||||
| } | |||||
| #endif | |||||
| MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| if (sm_algo_pack.blas.is_available_reproducible( | |||||
| args, reproducible, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.blas; | |||||
| } | } | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| else if (param().compute_mode == Param::ComputeMode::FLOAT32) { | |||||
| megdnn_assert(B.layout.dtype == dtype::Float16() && | |||||
| C.layout.dtype == dtype::Float16() && | |||||
| A.layout.dtype == dtype::Float16(), | |||||
| "DataType::FLOAT_IO16xC32 is supported, when dtype of A, " | |||||
| "B, C are all Float16"); | |||||
| gemm_ex(); | |||||
| } | |||||
| #endif | |||||
| else if (A.layout.dtype == dtype::Int8() && | |||||
| B.layout.dtype == dtype::Int8() && | |||||
| C.layout.dtype == dtype::Int32()) { | |||||
| //! see | |||||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470 | |||||
| bool rocblas_int8x8x32_valid = true; | |||||
| rocblas_int8x8x32_valid &= (k % 4 == 0); | |||||
| rocblas_int8x8x32_valid &= | |||||
| (!param().transposeB || B.layout.stride[0] % 4 == 0); | |||||
| rocblas_int8x8x32_valid &= | |||||
| (!param().transposeA || A.layout.stride[0] % 4 == 0); | |||||
| megdnn_assert(rocblas_int8x8x32_valid, | |||||
| "rocblas int8x8x32 matmul requires K must be a multiple " | |||||
| "of 4, and/or LDA/LDB based on transpose mode" | |||||
| "get: %zu, is_trans_b = %d, %zu, is_trans_a = %d, %zu", | |||||
| k, param().transposeB, B.layout.stride[0], | |||||
| param().transposeA, A.layout.stride[0]); | |||||
| int32_t solution_index = 0; | |||||
| uint32_t flags = 1; | |||||
| size_t ws_size = 0; | |||||
| auto zero = handle->zero_device_i32(); | |||||
| auto one = handle->one_device_i32(); | |||||
| rocblas_check(rocblas_gemm_ex( | |||||
| rocblas_handle_, | |||||
| param().transposeB ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| param().transposeA ? rocblas_operation_transpose | |||||
| : rocblas_operation_none, | |||||
| n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r, | |||||
| B.layout.stride[0], A.raw_ptr, rocblas_datatype_i8_r, | |||||
| A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r, | |||||
| C.layout.stride[0], C.raw_ptr, rocblas_datatype_i32_r, | |||||
| C.layout.stride[0], rocblas_datatype_i32_r, | |||||
| rocblas_gemm_algo_standard, solution_index, flags, &ws_size, | |||||
| nullptr)); | |||||
| if (reproducible) { | |||||
| return megdnn::get_reproducible_algo<MatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "matrix mul forward"); | |||||
| } else { | } else { | ||||
| megdnn_assert((A.layout.dtype == dtype::Int8() && | |||||
| B.layout.dtype == dtype::Int8() && | |||||
| C.layout.dtype == dtype::Int16()), | |||||
| "invalid matmul data type"); | |||||
| megdnn_throw("cuda matmul does not support INT8x8x16 now"); | |||||
| return megdnn::get_usable_algo<MatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "matrix mul forward"); | |||||
| } | } | ||||
| } | } | ||||
| } // namespace rocm | |||||
| } // namespace megdnn | |||||
| size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); | |||||
| } | |||||
| void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | |||||
| AlgoBase::ExecArgs args(this, A, B, C, workspace); | |||||
| auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); | |||||
| algo->check_workspace(args, workspace).exec(args); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -20,29 +20,32 @@ public: | |||||
| void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | ||||
| _megdnn_workspace workspace) override; | _megdnn_workspace workspace) override; | ||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| const TensorLayout&) override; | |||||
| bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
| class AlgoBase; | |||||
| class AlgoBlas; | |||||
| class AlgoPack; | |||||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
| private: | private: | ||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override { | |||||
| return {}; | |||||
| } | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | ||||
| const TensorLayout& /*B*/, | const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/, | const TensorLayout& /*C*/, | ||||
| size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
| bool /*reproducible*/) override { | |||||
| return nullptr; | |||||
| } | |||||
| bool /*reproducible*/) override; | |||||
| const char* get_algorithm_set_name() const override { | const char* get_algorithm_set_name() const override { | ||||
| return "ROCM MATMUL"; | return "ROCM MATMUL"; | ||||
| } | } | ||||
| static AlgoPack sm_algo_pack; | |||||
| }; | }; | ||||
| } // namespace rocm | } // namespace rocm | ||||
| @@ -46,6 +46,37 @@ TEST_F(FALLBACK, MATRIX_MUL) { | |||||
| } | } | ||||
| } | } | ||||
| TEST_F(FALLBACK, MATRIX_MUL_NAIVE) { | |||||
| Checker<MatrixMul> checker(handle()); | |||||
| checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE")); | |||||
| using Param = MatrixMul::Param; | |||||
| auto args = matrix_mul::get_matmul_args(); | |||||
| for (auto arg : args) { | |||||
| auto m = arg.m, n = arg.n, k = arg.k; | |||||
| auto mask = arg.mask; | |||||
| Param param; | |||||
| param.transposeA = mask & 1; | |||||
| param.transposeB = mask & 2; | |||||
| TensorShape AS, BS, CS; | |||||
| if (param.transposeA) | |||||
| AS = TensorShape{k, m}; | |||||
| else | |||||
| AS = TensorShape{m, k}; | |||||
| if (param.transposeB) | |||||
| BS = TensorShape{n, k}; | |||||
| else | |||||
| BS = TensorShape{k, n}; | |||||
| CS = TensorShape{m, n}; | |||||
| TensorLayout AL, BL, CL; | |||||
| AL = TensorLayout(AS, dtype::Float32()); | |||||
| BL = TensorLayout(BS, dtype::Float32()); | |||||
| CL = TensorLayout(CS, dtype::Float32()); | |||||
| checker.set_param(param); | |||||
| checker.execl({AL, BL, CL}); | |||||
| } | |||||
| } | |||||
| TEST_F(FALLBACK, BATCHED_MATRIX_MUL) { | TEST_F(FALLBACK, BATCHED_MATRIX_MUL) { | ||||
| Checker<BatchedMatrixMul> checker(handle()); | Checker<BatchedMatrixMul> checker(handle()); | ||||
| @@ -232,7 +232,7 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) { | |||||
| 2, 5, 3, 3, 7, 4, -7, 1, | 2, 5, 3, 3, 7, 4, -7, 1, | ||||
| -5, 7, -4, -1, -1, 2, 4, 1, | -5, 7, -4, -1, -1, 2, 4, 1, | ||||
| 7, 2, -6, -2, -6, 3, 4, 4, | 7, 2, -6, -2, -6, 3, 4, 4, | ||||
| -2, 2, 3, 0, 6, 5, 3, 4, | |||||
| -2, 2, 3, 0, 6, 5, 3, 4, | |||||
| -1, -1, -5, 5, 2, 5, 1, 4, | -1, -1, -5, 5, 2, 5, 1, 4, | ||||
| 6, 2, 0, 0, 3, 2, 2, 1, | 6, 2, 0, 0, 3, 2, 2, 1, | ||||
| -4, -3, 7, 5, 0, 3, 2, 3}), | -4, -3, 7, 5, 0, 3, 2, 3}), | ||||
| @@ -243,7 +243,7 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) { | |||||
| 3, -1, 2, 2, 7, 3, 6, 0, | 3, -1, 2, 2, 7, 3, 6, 0, | ||||
| 5, 4, 0, 2, 2, 3, 3, 2, | 5, 4, 0, 2, 2, 3, 3, 2, | ||||
| 1, -8, -7, -6, 0, -5, -4, 4, | 1, -8, -7, -6, 0, -5, -4, 4, | ||||
| -3, 7, 1, 6, -2, 2, -1, 5, | |||||
| -3, 7, 1, 6, -2, 2, -1, 5, | |||||
| 2, 0, 7, 6, 5, 4, 3, 2, | 2, 0, 7, 6, 5, 4, 3, 2, | ||||
| 0, 0, 1, 0, 5, 2, 2, 6}), | 0, 0, 1, 0, 5, 2, 2, 6}), | ||||
| {}}, | {}}, | ||||