GitOrigin-RevId: dea03a0f7a
tags/v1.3.0
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * \file dnn/src/fallback/batched_matrix_mul/algos.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/fallback/batched_matrix_mul/algos.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/naive/handle.h" | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&algo_default); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack; | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) | |||
| BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| BatchedMatrixMulForwardImpl* o, const TensorLayout& A, | |||
| const TensorLayout& B, const TensorLayout& C) | |||
| : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} | |||
| BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( | |||
| BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) | |||
| : SizeArgs(opr, A.layout, B.layout, C.layout), | |||
| tensor_a{A}, | |||
| tensor_b{B}, | |||
| tensor_c{C}, | |||
| workspace{workspace} {} | |||
| std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
| auto&& param = opr->param(); | |||
| size_t m = layout_a.shape[param.transposeA ? 2 : 1], | |||
| n = layout_b.shape[param.transposeB ? 1 : 2], | |||
| k = layout_a.shape[param.transposeA ? 1 : 2]; | |||
| MEGDNN_MARK_USED_VAR(m); | |||
| MEGDNN_MARK_USED_VAR(n); | |||
| MEGDNN_MARK_USED_VAR(k); | |||
| return megdnn_mangle(ssprintf( | |||
| "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " | |||
| "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", | |||
| m, k, k, n, m, n, param.transposeA, param.transposeB, | |||
| layout_a.stride[1], layout_b.stride[1], layout_c.stride[1])); | |||
| } | |||
| /* ===================== default algo ===================== */ | |||
| size_t BatchedMatrixMulForwardImpl::AlgoDefault::get_workspace_in_bytes( | |||
| const SizeArgs& args) const { | |||
| auto opr = inplace_cpu_handle()->create_operator<MatrixMul>(); | |||
| auto A_ = args.layout_a.remove_axis(0), B_ = args.layout_b.remove_axis(0), | |||
| C_ = args.layout_c.remove_axis(0); | |||
| opr->param() = args.opr->param(); | |||
| return opr->get_workspace_in_bytes(A_, B_, C_); | |||
| } | |||
| void BatchedMatrixMulForwardImpl::AlgoDefault::exec( | |||
| const ExecArgs& args) const { | |||
| //! As megbrain may modify param while checking all transpose situations, | |||
| //! we copy the param by value when dispatching the kern | |||
| auto param = args.opr->param(); | |||
| auto kern = [args, param]() { | |||
| auto N = args.layout_a.shape[0]; | |||
| TensorND A_, B_, C_; | |||
| A_.raw_ptr = args.tensor_a.raw_ptr; | |||
| A_.layout = args.layout_a.remove_axis(0); | |||
| B_.raw_ptr = args.tensor_b.raw_ptr; | |||
| B_.layout = args.layout_b.remove_axis(0); | |||
| C_.raw_ptr = args.tensor_c.raw_ptr; | |||
| C_.layout = args.layout_c.remove_axis(0); | |||
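| //! layout.stride[0] is the batch stride in elements; scaling it by | |||
| //! dtype.size() below yields the byte offset between consecutive | |||
| //! matrices in the batch. | |||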
| auto Astrd = args.layout_a.dtype.size() * args.layout_a.stride[0], | |||
| Bstrd = args.layout_b.dtype.size() * args.layout_b.stride[0], | |||
| Cstrd = args.layout_c.dtype.size() * args.layout_c.stride[0]; | |||
| auto advance_ptr = [](TensorND& dest, ptrdiff_t d) { | |||
| dest.raw_ptr = | |||
| static_cast<void*>(static_cast<dt_byte*>(dest.raw_ptr) + d); | |||
| }; | |||
| auto opr = inplace_cpu_handle()->create_operator<MatrixMul>(); | |||
| opr->param() = param; | |||
| rep(n, N) { | |||
| opr->exec(A_, B_, C_, args.workspace); | |||
| advance_ptr(A_, Astrd); | |||
| advance_ptr(B_, Bstrd); | |||
| advance_ptr(C_, Cstrd); | |||
| } | |||
| }; | |||
| static_cast<naive::HandleImpl*>(args.opr->handle())->dispatch_kern(kern); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,109 @@ | |||
| /** | |||
| * \file dnn/src/fallback/batched_matrix_mul/algos.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/batched_matrix_mul/opr_impl.h" | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| /*! | |||
| * \brief base class for batched matrix mul algos | |||
| */ | |||
| class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| fallback_BLAS, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::FALLBACK; } | |||
| struct SizeArgs { | |||
| BatchedMatrixMulForwardImpl* opr; | |||
| TensorLayout layout_a, layout_b, layout_c; | |||
| std::string to_string() const; | |||
| SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A, | |||
| const TensorLayout& B, const TensorLayout& C); | |||
| }; | |||
| struct ExecArgs : public SizeArgs { | |||
| TensorND tensor_a, tensor_b, tensor_c; | |||
| Workspace workspace; | |||
| ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
| virtual void exec(const ExecArgs&) const = 0; | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) const { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace(const SizeArgs& args, | |||
| const Workspace& workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert( | |||
| req <= workspace.size, | |||
| "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| }; | |||
| class BatchedMatrixMulForwardImpl::AlgoDefault final : public AlgoBase { | |||
| public: | |||
| AlgoDefault() = default; | |||
| bool is_available(const SizeArgs&) const override { return true; } | |||
| size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override; | |||
| const char* name() const override { return "DEFAULT"; } | |||
| virtual void exec(const ExecArgs&) const override; | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(fallback_BLAS) | |||
| }; | |||
| class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||
| private: | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| AlgoDefault algo_default; | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
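| //! Extension sketch (hypothetical AlgoFoo, not part of this diff): a new | |||
| //! algorithm subclasses AlgoBase, implements is_available / | |||
| //! get_workspace_in_bytes / exec, adds its own AlgoType entry, declares a | |||
| //! member in AlgoPack, and registers itself in AlgoPack's constructor: | |||
| //!     all_algos.push_back(&algo_foo); | |||
| //! after which it is picked up by m_all_algos_map automatically. | |||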
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,67 +6,61 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "src/naive/handle.h" | |||
| #include "./algos.h" | |||
| #include "hcc_detail/hcc_defs_prologue.h" | |||
| #include "src/common/algo_chooser.h" | |||
| #include "src/common/utils.cuh" | |||
| #include "src/fallback/handle.h" | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| BatchedMatrixMulImpl::BatchedMatrixMulImpl(Handle *handle): | |||
| BatchedMatrixMulForwardImpl(handle), | |||
| m_storage(new CpuOprDelegationStorage<>), | |||
| m_opr(m_storage->get<MatrixMul>()) | |||
| { | |||
| std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||
| BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | |||
| } | |||
| size_t BatchedMatrixMulImpl::get_workspace_in_bytes( | |||
| const TensorLayout &A, const TensorLayout &B, | |||
| const TensorLayout &C) { | |||
| auto A_ = A.remove_axis(0), B_ = B.remove_axis(0), C_ = C.remove_axis(0); | |||
| m_opr->param() = param(); | |||
| return m_opr->get_workspace_in_bytes(A_, B_, C_); | |||
| BatchedMatrixMulForwardImpl::Algorithm* | |||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| if (sm_algo_pack.algo_default.is_available_reproducible( | |||
| args, reproducible, workspace_limit_in_bytes)) { | |||
| return &sm_algo_pack.algo_default; | |||
| } | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( | |||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
| "batched matrix mul forward"); | |||
| } else { | |||
| return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | |||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
| "batched matrix mul forward"); | |||
| } | |||
| } | |||
| void BatchedMatrixMulImpl::exec(_megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
| m_opr->param() = this->param(); | |||
| auto kern = [this, A, B, C, workspace]() { | |||
| auto N = A.layout.shape[0]; | |||
| TensorND A_, B_, C_; | |||
| A_.raw_ptr = A.raw_ptr; | |||
| A_.layout = A.layout.remove_axis(0); | |||
| B_.raw_ptr = B.raw_ptr; | |||
| B_.layout = B.layout.remove_axis(0); | |||
| C_.raw_ptr = C.raw_ptr; | |||
| C_.layout = C.layout.remove_axis(0); | |||
| auto Astrd = A.layout.dtype.size() * A.layout.stride[0], | |||
| Bstrd = B.layout.dtype.size() * B.layout.stride[0], | |||
| Cstrd = C.layout.dtype.size() * C.layout.stride[0]; | |||
| auto advance_ptr = [](TensorND &dest, ptrdiff_t d) { | |||
| dest.raw_ptr = static_cast<void*>( | |||
| static_cast<dt_byte*>(dest.raw_ptr) + d); | |||
| }; | |||
| rep(n, N) { | |||
| m_opr->exec(A_, B_, C_, workspace); | |||
| advance_ptr(A_, Astrd); | |||
| advance_ptr(B_, Bstrd); | |||
| advance_ptr(C_, Cstrd); | |||
| } | |||
| }; | |||
| static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern); | |||
| size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); | |||
| } | |||
| void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
| AlgoBase::ExecArgs args(this, A, B, C, workspace); | |||
| auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); | |||
| algo->check_workspace(args, workspace).exec(args); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -15,26 +15,42 @@ | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| class BatchedMatrixMulImpl: public naive::BatchedMatrixMulForwardImpl { | |||
| public: | |||
| BatchedMatrixMulImpl(Handle *handle); | |||
| void exec( | |||
| _megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout &A, | |||
| const TensorLayout &B, | |||
| const TensorLayout &C) override; | |||
| private: | |||
| std::unique_ptr<CpuOprDelegationStorage<>> m_storage; | |||
| MatrixMulForward* m_opr; | |||
| class BatchedMatrixMulForwardImpl: public naive::BatchedMatrixMulForwardImpl { | |||
| public: | |||
| using naive::BatchedMatrixMulForwardImpl::BatchedMatrixMulForwardImpl; | |||
| void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override; | |||
| bool is_thread_safe() const override { return true; } | |||
| class AlgoBase; | |||
| class AlgoDefault; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /*reproducible*/) override; | |||
| const char* get_algorithm_set_name() const override { | |||
| return "FALLBACK BATCHED MATMUL"; | |||
| } | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -473,6 +473,13 @@ public: | |||
| PostprocessMode::NO_PROCESS, | |||
| "NoPackStrategyType::FLOAT16_FLOAT16"_hash); | |||
| break; | |||
| #endif | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| case StrategyType::FLOAT_FP16: | |||
| cb1(NCHW, NO_PACK, dt_float16, __fp16, | |||
| PostprocessMode::NO_PROCESS, | |||
| "NoPackStrategyType::FLOAT_FP16"_hash); | |||
| break; | |||
| #endif | |||
| case StrategyType::INT8x8x16: | |||
| cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| @@ -169,6 +169,10 @@ INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #endif | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #endif | |||
| #undef INSTANTIAL_CLASS | |||
| } // namespace megdnn | |||
| @@ -67,7 +67,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseMultiType) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdate) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMul) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias) | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC) | |||
| @@ -10,13 +10,18 @@ | |||
| */ | |||
| #include "src/fallback/matrix_mul/algos.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "src/fallback/matrix_mul/gemm_impl.h" | |||
| #include "src/fallback/matrix_mul/gemv.h" | |||
| #include "src/fallback/matrix_mul/generic_strategy.h" | |||
| #include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_fb_matmul_f32_kern) | |||
| MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) | |||
| MIDOUT_DECL(megdnn_fb_matmul_naive) | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| @@ -39,6 +44,32 @@ void f32_8x12x1_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| void kern_naive(const MatrixMulImpl::KernParam& kern_param) { | |||
| MIDOUT_BEGIN(megdnn_fb_matmul_naive, void) { | |||
| size_t M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| size_t LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
| #define DISPATCH(TA, TB) \ | |||
| if (kern_param.trA == TA && kern_param.trB == TB) { \ | |||
| naive::dispatch_ta_tb<TA, TB>( \ | |||
| kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr, \ | |||
| kern_param.workspace_ptr, M, N, K, LDA, LDB, LDC, \ | |||
| kern_param.A_type, kern_param.B_type, kern_param.C_type, \ | |||
| kern_param.format, kern_param.compute_mode); \ | |||
| return; \ | |||
| } | |||
| DISPATCH(true, true); | |||
| DISPATCH(true, false); | |||
| DISPATCH(false, true); | |||
| DISPATCH(false, false); | |||
| #undef DISPATCH | |||
| megdnn_assert_internal(0); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } // anonymous namespace | |||
| ////////////////////// AlgoF32K8x12x1 /////////////////////////// | |||
| @@ -84,11 +115,14 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern, | |||
| bool MatrixMulImpl::AlgoGemv::usable( | |||
| const KernSizeParam& kern_size_param) const { | |||
| return !kern_size_param.trA && !kern_size_param.trB && | |||
| kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||
| !((kern_size_param.A_type.enumv() == | |||
| kern_size_param.B_type.enumv()) && | |||
| (kern_size_param.A_type.enumv() == DTypeEnum::Int16) && | |||
| (kern_size_param.C_type.enumv() == DTypeEnum::Int32)); | |||
| kern_size_param.format == | |||
| param::MatrixMul::Format::DEFAULT && | |||
| kern_size_param.compute_mode == | |||
| param::MatrixMul::ComputeMode::DEFAULT && | |||
| !((kern_size_param.A_type.enumv() == | |||
| kern_size_param.B_type.enumv()) && | |||
| (kern_size_param.A_type.enumv() == DTypeEnum::Int16) && | |||
| (kern_size_param.C_type.enumv() == DTypeEnum::Int32)); | |||
| } | |||
| bool MatrixMulImpl::AlgoGemv::preferred( | |||
| @@ -128,4 +162,44 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoGemv::get_kern( | |||
| megdnn_assert(0); | |||
| } | |||
| /* ===================== naive algo ===================== */ | |||
| bool MatrixMulImpl::AlgoNaive::usable(const KernSizeParam&) const { | |||
| return true; | |||
| } | |||
| bool MatrixMulImpl::AlgoNaive::preferred(const KernSizeParam&) const { | |||
| return false; | |||
| } | |||
| size_t MatrixMulImpl::AlgoNaive::get_workspace( | |||
| const KernSizeParam& kern_param) const { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fb_matmul_naive, | |||
| midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) { | |||
| if (kern_param.A_type.enumv() == DTypeEnum::Quantized4Asymm || | |||
| kern_param.A_type.enumv() == DTypeEnum::QuantizedS4) { | |||
| size_t ret = 0; | |||
| if (kern_param.trA) { | |||
| ret += kern_param.LDA * kern_param.K; | |||
| } else { | |||
| ret += kern_param.LDA * kern_param.M; | |||
| } | |||
| if (kern_param.trB) { | |||
| ret += kern_param.LDB * kern_param.N; | |||
| } else { | |||
| ret += kern_param.LDB * kern_param.K; | |||
| } | |||
| return ret; | |||
| } | |||
| return 0; | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern( | |||
| const KernSizeParam&) const { | |||
| return kern_naive; | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -52,6 +52,28 @@ public: | |||
| DEFAULT) | |||
| }; | |||
| class MatrixMulImpl::AlgoNaive final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "FB_NAIVE"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMM; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) | |||
| MEGDNN_OVERRIDE_MATMUL_DESC( | |||
| 8, 16, 1, 4, | |||
| static_cast<AlgoDataType>( | |||
| static_cast<uint32_t>(AlgoDataType::FLOAT16) | | |||
| static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||
| static_cast<uint32_t>(AlgoDataType::INT8X8X16) | | |||
| static_cast<uint32_t>(AlgoDataType::QINT8X8X32) | | |||
| static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)), | |||
| DEFAULT) | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| @@ -35,6 +35,7 @@ using namespace fallback; | |||
| class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| AlgoF32K8x12x1 f32_k8x12x1; | |||
| AlgoGemv gemv; | |||
| AlgoNaive naive; | |||
| SmallVector<AlgoBase*> m_all_algos; | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| @@ -42,6 +43,7 @@ public: | |||
| AlgoPack() { | |||
| m_all_algos.emplace_back(&gemv); | |||
| m_all_algos.emplace_back(&f32_k8x12x1); | |||
| m_all_algos.emplace_back(&naive); | |||
| for (auto&& algo : m_all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| @@ -147,19 +149,26 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | |||
| algo_type.format = kern_size_param.format; | |||
| auto algos = select_algo_type(algo_type); | |||
| Algorithm *heuristic_algo = nullptr; | |||
| Algorithm *usable_algo = nullptr; | |||
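| //! Selection priority: a preferred GEMV algo wins immediately; otherwise | |||
| //! the first preferred algo is chosen, and a merely usable algo is kept | |||
| //! as a last resort (see the fallback after the loop). | |||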
| for (auto&& algo : algos) { | |||
| if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) && | |||
| static_cast<AlgoBase*>(algo)->preferred_reproducible( | |||
| kern_size_param, reproducible) && | |||
| static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <= | |||
| workspace_limit_in_bytes) { | |||
| if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | |||
| return algo; | |||
| } else if (!heuristic_algo) { | |||
| heuristic_algo = algo; | |||
| if (static_cast<AlgoBase*>(algo)->preferred_reproducible( | |||
| kern_size_param, reproducible)) { | |||
| //! use the gemv algo if it is preferred | |||
| if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | |||
| return algo; | |||
| } else if (!heuristic_algo) { | |||
| heuristic_algo = algo; | |||
| } | |||
| } else if (!usable_algo) { | |||
| usable_algo = algo; | |||
| } | |||
| } | |||
| } | |||
| if (!heuristic_algo) heuristic_algo = usable_algo; | |||
| megdnn_assert(heuristic_algo, "No usable algorithm found"); | |||
| return heuristic_algo; | |||
| } | |||
| @@ -110,6 +110,7 @@ public: | |||
| //! fallback | |||
| FB_F32K8x12x1 = 1 << 0, | |||
| FB_GEMV, | |||
| FB_NAIVE, | |||
| #if MEGDNN_X86 | |||
| //! x86 | |||
| @@ -233,6 +234,7 @@ public: | |||
| private: | |||
| class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||
| class AlgoGemv; | |||
| class AlgoNaive; | |||
| class AlgoPack; | |||
| //! maintain all the algos of the fallback opr | |||
| static const AlgoPack& algo_pack(); | |||
| @@ -141,20 +141,39 @@ void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M, | |||
| } | |||
| template <bool transA, bool transB> | |||
| void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace, | |||
| const param::MatrixMul& param) { | |||
| void exec_matrix_mul_quint4x4x32_helper( | |||
| const void* A, const void* B, void* C, void* workspace, size_t M, | |||
| size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC, | |||
| DType A_type, DType B_type, DType C_type, | |||
| const MatrixMul::Param::Format& format, | |||
| const MatrixMul::Param::ComputeMode& compute_mode) { | |||
| MEGDNN_MARK_USED_VAR(C_type); | |||
| MEGDNN_MARK_USED_VAR(format); | |||
| MEGDNN_MARK_USED_VAR(compute_mode); | |||
| auto convert_layout = [](const TensorLayout& layout) { | |||
| auto ret = layout; | |||
| auto param = layout.dtype.param<dtype::Quantized4Asymm>(); | |||
| ret.dtype = dtype::Quantized8Asymm(param.scale, param.zero_point); | |||
| return ret; | |||
| }; | |||
| TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)}; | |||
| TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(), | |||
| convert_layout(B.layout)}; | |||
| TensorLayout A_layout, B_layout; | |||
| if (transA) { | |||
| A_layout = TensorLayout({K, M}, {LDA, 1}, A_type); | |||
| } else { | |||
| A_layout = TensorLayout({M, K}, {LDA, 1}, A_type); | |||
| } | |||
| if (transB) { | |||
| B_layout = TensorLayout({N, K}, {LDB, 1}, B_type); | |||
| } else { | |||
| B_layout = TensorLayout({K, N}, {LDB, 1}, B_type); | |||
| } | |||
| TensorND tensorA{const_cast<void*>(A), A_layout}; | |||
| TensorND tensorB{const_cast<void*>(B), B_layout}; | |||
| TensorND nA = {workspace, convert_layout(A_layout)}; | |||
| TensorND nB = { | |||
| static_cast<uint8_t*>(workspace) + nA.layout.span().dist_byte(), | |||
| convert_layout(B_layout)}; | |||
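| //! each input byte packs two unsigned 4-bit values; unpack them into two | |||
| //! bytes of the widened Quantized8Asymm copies held in workspace. | |||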
| auto convert_4to8 = [](const TensorND& in, const TensorND& out) { | |||
| auto ptr = | |||
| static_cast<uint8_t*>(in.raw_ptr) + in.layout.span().low_byte; | |||
| @@ -168,31 +187,48 @@ void exec_matrix_mul_quint4x4x32_helper(_megdnn_tensor_in A, | |||
| out_ptr[i + 1] = val1; | |||
| } | |||
| }; | |||
| convert_4to8(A, nA); | |||
| convert_4to8(B, nB); | |||
| auto M = C.layout.shape[0], N = C.layout.shape[1]; | |||
| auto K = A.layout.shape[param.transposeA ? 0 : 1]; | |||
| auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | |||
| LDC = C.layout.stride[0]; | |||
| convert_4to8(tensorA, nA); | |||
| convert_4to8(tensorB, nB); | |||
| run_matrix_mul_tpl<uint8_t, dt_int32, transA, transB, dt_int32>( | |||
| nA.compatible_ptr<uint8_t>(), nB.compatible_ptr<uint8_t>(), | |||
| C.compatible_ptr<dt_int32>(), M, N, K, LDA, LDB, LDC, | |||
| nA.layout.dtype, nB.layout.dtype); | |||
| static_cast<dt_int32*>(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype, | |||
| nB.layout.dtype); | |||
| } | |||
| template <bool transA, bool transB> | |||
| void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace, | |||
| const param::MatrixMul& param) { | |||
| void exec_matrix_mul_qint4x4x16_helper( | |||
| const void* A, const void* B, void* C, void* workspace, size_t M, | |||
| size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, ptrdiff_t LDC, | |||
| DType A_type, DType B_type, DType C_type, | |||
| const MatrixMul::Param::Format& format, | |||
| const MatrixMul::Param::ComputeMode& compute_mode) { | |||
| MEGDNN_MARK_USED_VAR(C_type); | |||
| MEGDNN_MARK_USED_VAR(format); | |||
| MEGDNN_MARK_USED_VAR(compute_mode); | |||
| auto convert_layout = [](const TensorLayout& layout) { | |||
| auto ret = layout; | |||
| auto param = layout.dtype.param<dtype::QuantizedS4>(); | |||
| ret.dtype = dtype::QuantizedS8(param.scale); | |||
| return ret; | |||
| }; | |||
| TensorND nA = {workspace.raw_ptr, convert_layout(A.layout)}; | |||
| TensorND nB = {workspace.raw_ptr + nA.layout.span().dist_byte(), | |||
| convert_layout(B.layout)}; | |||
| TensorLayout A_layout, B_layout; | |||
| if (transA) { | |||
| A_layout = TensorLayout({K, M}, {LDA, 1}, A_type); | |||
| } else { | |||
| A_layout = TensorLayout({M, K}, {LDA, 1}, A_type); | |||
| } | |||
| if (transB) { | |||
| B_layout = TensorLayout({N, K}, {LDB, 1}, B_type); | |||
| } else { | |||
| B_layout = TensorLayout({K, N}, {LDB, 1}, B_type); | |||
| } | |||
| TensorND tensorA{const_cast<void*>(A), A_layout}; | |||
| TensorND tensorB{const_cast<void*>(B), B_layout}; | |||
| TensorND nA = {workspace, convert_layout(A_layout)}; | |||
| TensorND nB = { | |||
| static_cast<uint8_t*>(workspace) + nA.layout.span().dist_byte(), | |||
| convert_layout(B_layout)}; | |||
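| //! signed variant: each input byte packs two 4-bit values, and the | |||
| //! arithmetic shifts below sign-extend each nibble to a full int8. | |||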
| auto convert_4to8 = [](const TensorND& in, const TensorND& out) { | |||
| auto ptr = static_cast<int8_t*>(in.raw_ptr) + in.layout.span().low_byte; | |||
| auto out_ptr = | |||
| @@ -204,18 +240,98 @@ void exec_matrix_mul_qint4x4x16_helper(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| out_ptr[i + 1] = cur >> 4; | |||
| } | |||
| }; | |||
| convert_4to8(A, nA); | |||
| convert_4to8(B, nB); | |||
| auto M = C.layout.shape[0], N = C.layout.shape[1]; | |||
| auto K = A.layout.shape[param.transposeA ? 0 : 1]; | |||
| auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | |||
| LDC = C.layout.stride[0]; | |||
| convert_4to8(tensorA, nA); | |||
| convert_4to8(tensorB, nB); | |||
| run_matrix_mul_tpl<int8_t, dt_int16, transA, transB, dt_int16>( | |||
| nA.compatible_ptr<int8_t>(), nB.compatible_ptr<int8_t>(), | |||
| C.compatible_ptr<dt_int16>(), M, N, K, LDA, LDB, LDC, | |||
| nA.layout.dtype, nB.layout.dtype); | |||
| static_cast<dt_int16*>(C), M, N, K, LDA, LDB, LDC, nA.layout.dtype, | |||
| nB.layout.dtype); | |||
| } | |||
| template <bool TA, bool TB> | |||
| void dispatch_ta_tb(const void* A, const void* B, void* C, void* workspace, | |||
| size_t M, size_t N, size_t K, ptrdiff_t LDA, ptrdiff_t LDB, | |||
| ptrdiff_t LDC, DType A_type, DType B_type, DType C_type, | |||
| const MatrixMul::Param::Format& format, | |||
| const MatrixMul::Param::ComputeMode& compute_mode) { | |||
| #define cb(_itype, _otype, _comp_type) \ | |||
| if (format == param::MatrixMul::Format::DEFAULT) { \ | |||
| return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ | |||
| B_type); \ | |||
| } else if (format == param::MatrixMul::Format::MK4) { \ | |||
| return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ | |||
| B_type); \ | |||
| } else if (format == param::MatrixMul::Format::MK4_DOT) { \ | |||
| return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ | |||
| B_type); \ | |||
| } else if (format == param::MatrixMul::Format::MK8) { \ | |||
| return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, \ | |||
| B_type); \ | |||
| } | |||
| if (A_type == dtype::Float32()) { | |||
| cb(dt_float32, dt_float32, dt_float32); | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| } else if (A_type == dtype::Float16()) { | |||
| using Param = MatrixMul::Param; | |||
| if (compute_mode == Param::ComputeMode::DEFAULT) { | |||
| cb(dt_float16, dt_float16, dt_float16); | |||
| } else if (compute_mode == Param::ComputeMode::FLOAT32) { | |||
| cb(dt_float16, dt_float16, dt_float32); | |||
| } | |||
| } else if (A_type == dtype::BFloat16()) { | |||
| using Param = MatrixMul::Param; | |||
| if (compute_mode == Param::ComputeMode::DEFAULT) { | |||
| cb(dt_bfloat16, dt_bfloat16, dt_bfloat16); | |||
| } else if (compute_mode == Param::ComputeMode::FLOAT32) { | |||
| cb(dt_bfloat16, dt_bfloat16, dt_float32); | |||
| } | |||
| #endif | |||
| } else if (A_type == dtype::Int8() && | |||
| C_type == dtype::Int16()) { | |||
| cb(dt_int8, dt_int16, dt_int16); | |||
| } else if (A_type == dtype::Int16() && | |||
| C_type == dtype::Int32()) { | |||
| cb(dt_int16, dt_int32, dt_int32); | |||
| } else if ((A_type == dtype::Int8() || | |||
| A_type.enumv() == DTypeEnum::QuantizedS8) && | |||
| (C_type == dtype::Int32() || | |||
| C_type.enumv() == DTypeEnum::QuantizedS32)) { | |||
| cb(dt_int8, dt_int32, dt_int32); | |||
| } else if (A_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
| C_type.enumv() == DTypeEnum::QuantizedS32) { | |||
| cb(uint8_t, dt_int32, dt_int32); | |||
| } else if (A_type.enumv() == DTypeEnum::Quantized4Asymm && | |||
| C_type.enumv() == DTypeEnum::QuantizedS32 && | |||
| format == param::MatrixMul::Format::DEFAULT) { | |||
| exec_matrix_mul_quint4x4x32_helper<TA, TB>( | |||
| A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type, | |||
| C_type, format, compute_mode); | |||
| return; | |||
| } else if (A_type.enumv() == DTypeEnum::QuantizedS4 && | |||
| C_type.enumv() == DTypeEnum::QuantizedS16 && | |||
| format == param::MatrixMul::Format::DEFAULT) { | |||
| exec_matrix_mul_qint4x4x16_helper<TA, TB>( | |||
| A, B, C, workspace, M, N, K, LDA, LDB, LDC, A_type, B_type, | |||
| C_type, format, compute_mode); | |||
| return; | |||
| } | |||
| #undef cb | |||
| megdnn_throw( | |||
| ssprintf("unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)", | |||
| A_type.name(), B_type.name(), C_type.name(), | |||
| static_cast<int>(compute_mode))); | |||
| } | |||
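| //! Usage sketch (hypothetical sizes M/N/K; row-major, no transpose): | |||
| //!     std::vector<float> A(M * K), B(K * N), C(M * N); | |||
| //!     dispatch_ta_tb<false, false>( | |||
| //!             A.data(), B.data(), C.data(), /*workspace=*/nullptr, | |||
| //!             M, N, K, /*LDA=*/K, /*LDB=*/N, /*LDC=*/N, | |||
| //!             dtype::Float32(), dtype::Float32(), dtype::Float32(), | |||
| //!             param::MatrixMul::Format::DEFAULT, | |||
| //!             param::MatrixMul::ComputeMode::DEFAULT); | |||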
| } // namespace naive | |||
| } // namespace megdnn | |||
| @@ -45,77 +45,10 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | |||
| LDC = C.layout.stride[0]; | |||
| #define cb(_itype, _otype, _comp_type) \ | |||
| if (param.format == param::MatrixMul::Format::DEFAULT) { \ | |||
| return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||
| A.layout.dtype, B.layout.dtype); \ | |||
| } else if (param.format == param::MatrixMul::Format::MK4) { \ | |||
| return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||
| A.layout.dtype, B.layout.dtype); \ | |||
| } else if (param.format == param::MatrixMul::Format::MK4_DOT) { \ | |||
| return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||
| A.layout.dtype, B.layout.dtype); \ | |||
| } else if (param.format == param::MatrixMul::Format::MK8) { \ | |||
| return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||
| A.layout.dtype, B.layout.dtype); \ | |||
| } | |||
| if (A.layout.dtype == dtype::Float32()) { | |||
| cb(dt_float32, dt_float32, dt_float32); | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| } else if (A.layout.dtype == dtype::Float16()) { | |||
| using Param = MatrixMul::Param; | |||
| if (param.compute_mode == Param::ComputeMode::DEFAULT) { | |||
| cb(dt_float16, dt_float16, dt_float16); | |||
| } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { | |||
| cb(dt_float16, dt_float16, dt_float32); | |||
| } | |||
| } else if (A.layout.dtype == dtype::BFloat16()) { | |||
| using Param = MatrixMul::Param; | |||
| if (param.compute_mode == Param::ComputeMode::DEFAULT) { | |||
| cb(dt_bfloat16, dt_bfloat16, dt_bfloat16); | |||
| } else if (param.compute_mode == Param::ComputeMode::FLOAT32) { | |||
| cb(dt_bfloat16, dt_bfloat16, dt_float32); | |||
| } | |||
| #endif | |||
| } else if (A.layout.dtype == dtype::Int8() && | |||
| C.layout.dtype == dtype::Int16()) { | |||
| cb(dt_int8, dt_int16, dt_int16); | |||
| } else if (A.layout.dtype == dtype::Int16() && | |||
| C.layout.dtype == dtype::Int32()) { | |||
| cb(dt_int16, dt_int32, dt_int32); | |||
| } else if ((A.layout.dtype == dtype::Int8() || | |||
| A.layout.dtype.enumv() == DTypeEnum::QuantizedS8) && | |||
| (C.layout.dtype == dtype::Int32() || | |||
| C.layout.dtype.enumv() == DTypeEnum::QuantizedS32)) { | |||
| cb(dt_int8, dt_int32, dt_int32); | |||
| } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && | |||
| C.layout.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||
| cb(uint8_t, dt_int32, dt_int32); | |||
| } else if (A.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm && | |||
| C.layout.dtype.enumv() == DTypeEnum::QuantizedS32 && | |||
| param.format == param::MatrixMul::Format::DEFAULT) { | |||
| exec_matrix_mul_quint4x4x32_helper<TA, TB>(A, B, C, workspace, param); | |||
| return; | |||
| } else if (A.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && | |||
| C.layout.dtype.enumv() == DTypeEnum::QuantizedS16 && | |||
| param.format == param::MatrixMul::Format::DEFAULT) { | |||
| exec_matrix_mul_qint4x4x16_helper<TA, TB>(A, B, C, workspace, param); | |||
| return; | |||
| } | |||
| #undef cb | |||
| megdnn_throw(ssprintf( | |||
| "unsupported naive MatrixMul(%s, %s) -> %s (cmode = %d)", | |||
| A.layout.dtype.name(), B.layout.dtype.name(), C.layout.dtype.name(), | |||
| static_cast<int>(param.compute_mode))); | |||
| dispatch_ta_tb<TA, TB>(A.raw_ptr, B.raw_ptr, C.raw_ptr, workspace.raw_ptr, | |||
| M, N, K, LDA, LDB, LDC, A.layout.dtype, | |||
| B.layout.dtype, C.layout.dtype, param.format, | |||
| param.compute_mode); | |||
| } | |||
| void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A, | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * \file dnn/src/rocm/batched_matrix_mul/algos.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/rocm/batched_matrix_mul/algos.h" | |||
| #include "src/common/algo_base.h" | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&blas); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| BatchedMatrixMulForwardImpl::AlgoPack BatchedMatrixMulForwardImpl::sm_algo_pack; | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) | |||
| BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| BatchedMatrixMulForwardImpl* o, const TensorLayout& A, | |||
| const TensorLayout& B, const TensorLayout& C) | |||
| : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} | |||
| BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( | |||
| BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, _megdnn_tensor_out C, _megdnn_workspace workspace) | |||
| : SizeArgs(opr, A.layout, B.layout, C.layout), | |||
| tensor_a{A}, | |||
| tensor_b{B}, | |||
| tensor_c{C}, | |||
| workspace{workspace} {} | |||
| std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
| auto&& param = opr->param(); | |||
| size_t m = layout_a.shape[param.transposeA ? 2 : 1], | |||
| n = layout_b.shape[param.transposeB ? 1 : 2], | |||
| k = layout_a.shape[param.transposeA ? 1 : 2]; | |||
| MEGDNN_MARK_USED_VAR(m); | |||
| MEGDNN_MARK_USED_VAR(n); | |||
| MEGDNN_MARK_USED_VAR(k); | |||
| return megdnn_mangle(ssprintf( | |||
| "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " | |||
| "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", | |||
| m, k, k, n, m, n, param.transposeA, param.transposeB, | |||
| layout_a.stride[1], layout_b.stride[1], layout_c.stride[1])); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,118 @@ | |||
| /** | |||
| * \file dnn/src/rocm/batched_matrix_mul/algos.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/rocm/batched_matrix_mul/opr_impl.h" | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| /*! | |||
| * \brief base class for batched matrix mul algos | |||
| */ | |||
| class BatchedMatrixMulForwardImpl::AlgoBase : public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| ROCM_BLAS, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||
| struct SizeArgs { | |||
| BatchedMatrixMulForwardImpl* opr; | |||
| TensorLayout layout_a, layout_b, layout_c; | |||
| std::string to_string() const; | |||
| SizeArgs(BatchedMatrixMulForwardImpl* opr, const TensorLayout& A, | |||
| const TensorLayout& B, const TensorLayout& C); | |||
| bool can_be_treated_as_int8x8x32() const { | |||
| return layout_a.dtype.enumv() == layout_b.dtype.enumv() && | |||
| (layout_a.dtype.enumv() == DTypeEnum::Int8 || | |||
| layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) && | |||
| (layout_c.dtype.enumv() == DTypeEnum::Int32 || | |||
| layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) && | |||
| opr->param().format == param::MatrixMul::Format::DEFAULT; | |||
| } | |||
| }; | |||
| struct ExecArgs : public SizeArgs { | |||
| TensorND tensor_a, tensor_b, tensor_c; | |||
| Workspace workspace; | |||
| ExecArgs(BatchedMatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
| virtual void exec(const ExecArgs& args) const = 0; | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) const { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace(const SizeArgs& args, | |||
| const Workspace& workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert( | |||
| req <= workspace.size, | |||
| "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| }; | |||
| class BatchedMatrixMulForwardImpl::AlgoBlas final : public AlgoBase { | |||
| public: | |||
| AlgoBlas() = default; | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { | |||
| return 0_z; | |||
| } | |||
| const char* name() const override { return "BLAS"; } | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | |||
| }; | |||
| class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||
| private: | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| AlgoBlas blas; | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace rocm | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,140 @@ | |||
| /** | |||
| * \file dnn/src/rocm/batched_matrix_mul/Blas.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/rocm/batched_matrix_mul/algos.h" | |||
| #include "hcc_detail/hcc_defs_prologue.h" | |||
| #include "src/rocm/handle.h" | |||
| #include "src/rocm/utils.h" | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| bool BatchedMatrixMulForwardImpl::AlgoBlas::is_available( | |||
| const SizeArgs& args) const { | |||
| if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) | |||
| return false; | |||
| if (args.layout_a.dtype == dtype::Float32() || | |||
| args.layout_a.dtype == dtype::Float16()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| void BatchedMatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const { | |||
| auto batch = args.layout_a.shape[0]; | |||
| auto m = args.layout_c.shape[1], n = args.layout_c.shape[2]; | |||
| auto k = args.layout_a.shape[args.opr->param().transposeA ? 1 : 2]; | |||
| auto&& handle = concrete_handle(args.opr->handle()); | |||
| auto rocblas_handle_ = handle->get_rocblas_handle(); | |||
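| //! rocBLAS is column-major while megdnn tensors are row-major, so the | |||
| //! lambdas below compute C^T = B^T * A^T: B is passed before A and the | |||
| //! m/n dimensions are swapped in each call. | |||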
| auto sgemm = [&]() { | |||
| auto zero = handle->zero_device(); | |||
| auto one = handle->one_device(); | |||
| rocblas_check(rocblas_sgemm_strided_batched( | |||
| rocblas_handle_, | |||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, args.tensor_b.ptr<dt_float32>(), | |||
| (rocblas_int)(args.layout_b.stride[1]), | |||
| (rocblas_int)(args.layout_b.stride[0]), | |||
| args.tensor_a.ptr<dt_float32>(), | |||
| (rocblas_int)(args.layout_a.stride[1]), | |||
| (rocblas_int)(args.layout_a.stride[0]), zero, | |||
| args.tensor_c.ptr<dt_float32>(), | |||
| (rocblas_int)(args.layout_c.stride[1]), | |||
| (rocblas_int)(args.layout_c.stride[0]), (rocblas_int)(batch))); | |||
| }; | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| //! used for FLOAT_IO16xC32, not tested | |||
| auto gemm_ex = [&]() { | |||
| auto zero = handle->zero_device(); | |||
| auto one = handle->one_device(); | |||
| //! These two arguments are reserved for future use; see | |||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp | |||
| int32_t solution_index = 0; | |||
| uint32_t flags = 1; | |||
| size_t ws_size = 0; | |||
| rocblas_check(rocblas_gemm_strided_batched_ex( | |||
| rocblas_handle_, | |||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r, | |||
| args.layout_b.stride[1], args.layout_b.stride[0], | |||
| args.tensor_a.raw_ptr, rocblas_datatype_i8_r, | |||
| args.layout_a.stride[1], args.layout_a.stride[0], zero, | |||
| args.tensor_c.raw_ptr, rocblas_datatype_i32_r, | |||
| args.layout_c.stride[1], args.layout_c.stride[0], | |||
| args.tensor_c.raw_ptr, rocblas_datatype_i32_r, | |||
| args.layout_c.stride[1], args.layout_c.stride[0], batch, | |||
| rocblas_datatype_i32_r, rocblas_gemm_algo_standard, | |||
| solution_index, flags, &ws_size, nullptr)); | |||
| MEGDNN_MARK_USED_VAR(ws_size); | |||
| }; | |||
| auto hgemm = [&]() { | |||
| auto one_half = handle->one_device_h(); | |||
| auto zero_half = handle->zero_device_h(); | |||
| rocblas_check(rocblas_hgemm_strided_batched( | |||
| rocblas_handle_, | |||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||
| static_cast<const rocblas_half*>(args.tensor_b.raw_ptr), | |||
| args.layout_b.stride[1], args.layout_b.stride[0], | |||
| static_cast<const rocblas_half*>(args.tensor_a.raw_ptr), | |||
| args.layout_a.stride[1], args.layout_a.stride[0], | |||
| reinterpret_cast<const rocblas_half*>(zero_half), | |||
| static_cast<rocblas_half*>(args.tensor_c.raw_ptr), | |||
| args.layout_c.stride[1], args.layout_c.stride[0], batch)); | |||
| }; | |||
| #endif | |||
| if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) { | |||
| if (args.layout_a.dtype == dtype::Float32()) { | |||
| sgemm(); | |||
| } | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| else { | |||
| megdnn_assert(args.layout_a.dtype == dtype::Float16(), | |||
| "invalid matmul data type"); | |||
| hgemm(); | |||
| } | |||
| #endif | |||
| } | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) { | |||
| megdnn_assert(args.layout_b.dtype == dtype::Float16() && | |||
| args.layout_c.dtype == dtype::Float16() && | |||
| args.layout_a.dtype == dtype::Float16(), | |||
| "DataType::FLOAT_IO16xC32 is supported, when dtype of A, " | |||
| "B, C are all Float16"); | |||
| gemm_ex(); | |||
| } | |||
| #endif | |||
| else { | |||
| megdnn_throw("Unsupported data_type of matrix mul on rocm."); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -10,111 +10,58 @@ | |||
| * implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "./algos.h" | |||
| #include "hcc_detail/hcc_defs_prologue.h" | |||
| #include "src/common/algo_chooser.h" | |||
| #include "src/common/utils.cuh" | |||
| #include "src/rocm/handle.h" | |||
| #include "src/rocm/utils.h" | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||
| BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | |||
| } | |||
| BatchedMatrixMulForwardImpl::Algorithm* | |||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| if (sm_algo_pack.blas.is_available_reproducible(args, reproducible, | |||
| workspace_limit_in_bytes)) { | |||
| return &sm_algo_pack.blas; | |||
| } | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( | |||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
| "batched matrix mul forward"); | |||
| } else { | |||
| return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | |||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
| "batched matrix mul forward"); | |||
| } | |||
| } | |||
| size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); | |||
| } | |||
| void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
| auto dtype = A.layout.dtype; | |||
| megdnn_assert(dtype.category() == DTypeCategory::FLOAT && | |||
| param().format == param::MatrixMul::Format::DEFAULT); | |||
| if (dtype == dtype::Float32() || | |||
| MEGDNN_FLOAT16_SELECT(dtype == dtype::Float16(), false)) { | |||
| auto batch = A.layout.shape[0]; | |||
| auto m = C.layout.shape[1], n = C.layout.shape[2]; | |||
| auto k = A.layout.shape[param().transposeA ? 1 : 2]; | |||
| auto handle = concrete_handle(this->handle()); | |||
| auto rocblas_handle_ = handle->get_rocblas_handle(); | |||
| auto io32_c32 = [&]() { | |||
| auto zero = handle->zero_device(); | |||
| auto one = handle->one_device(); | |||
| rocblas_check(rocblas_sgemm_strided_batched( | |||
| rocblas_handle_, | |||
| param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, B.ptr<dt_float32>(), | |||
| (rocblas_int)(B.layout.stride[1]), | |||
| (rocblas_int)(B.layout.stride[0]), A.ptr<dt_float32>(), | |||
| (rocblas_int)(A.layout.stride[1]), | |||
| (rocblas_int)(A.layout.stride[0]), zero, | |||
| C.ptr<dt_float32>(), (rocblas_int)(C.layout.stride[1]), | |||
| (rocblas_int)(C.layout.stride[0]), (rocblas_int)(batch))); | |||
| }; | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| auto io16_c32 = [&]() { | |||
| auto zero = handle->zero_device(); | |||
| auto one = handle->one_device(); | |||
| int32_t solution_index = 0; | |||
| uint32_t flags = 1; | |||
| size_t ws_size = 0; | |||
| rocblas_check(rocblas_gemm_strided_batched_ex( | |||
| rocblas_handle_, | |||
| param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r, | |||
| B.layout.stride[1], B.layout.stride[0], A.raw_ptr, | |||
| rocblas_datatype_i8_r, A.layout.stride[1], | |||
| A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r, | |||
| C.layout.stride[1], C.layout.stride[0], C.raw_ptr, | |||
| rocblas_datatype_i32_r, C.layout.stride[1], | |||
| C.layout.stride[0], batch, rocblas_datatype_i32_r, | |||
| rocblas_gemm_algo_standard, solution_index, flags, &ws_size, | |||
| nullptr)); | |||
| }; | |||
| auto io16_c16 = [&]() { | |||
| auto zero_half = handle->zero_device_h(); | |||
| auto one_half = handle->one_device_h(); | |||
| rocblas_check(rocblas_hgemm_strided_batched( | |||
| rocblas_handle_, | |||
| param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||
| static_cast<const rocblas_half*>(B.raw_ptr), | |||
| B.layout.stride[1], B.layout.stride[0], | |||
| static_cast<const rocblas_half*>(A.raw_ptr), | |||
| A.layout.stride[1], A.layout.stride[0], | |||
| reinterpret_cast<const rocblas_half*>(zero_half), | |||
| static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[1], | |||
| C.layout.stride[0], batch)); | |||
| }; | |||
| #endif | |||
| if (dtype == dtype::Float32()) { | |||
| io32_c32(); | |||
| } | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| else { | |||
| if (param().compute_mode == Param::ComputeMode::FLOAT32) { | |||
| io16_c32(); | |||
| } else { | |||
| io16_c16(); | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| AlgoBase::ExecArgs args(this, A, B, C, workspace); | |||
| auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); | |||
| algo->check_workspace(args, workspace).exec(args); | |||
| } | |||
| } // namespace rocm | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| @@ -17,36 +18,35 @@ namespace rocm { | |||
| class BatchedMatrixMulForwardImpl : public BatchedMatrixMulForward { | |||
| public: | |||
| using BatchedMatrixMulForward::BatchedMatrixMulForward; | |||
| BatchedMatrixMulForwardImpl(Handle* handle) | |||
| : BatchedMatrixMul(handle), | |||
| m_opr(handle->create_operator<MatrixMul>()) {} | |||
| void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| const TensorLayout&) override; | |||
| bool is_thread_safe() const override { return true; } | |||
| class AlgoBase; | |||
| class AlgoBlas; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override { | |||
| return {}; | |||
| } | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /* reproducible */) override { | |||
| return nullptr; | |||
| } | |||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
| bool /*reproducible*/) override; | |||
| bool is_thread_safe() const override { return true; } | |||
| const char* get_algorithm_set_name() const override { | |||
| return "ROCM BATCHED MATMUL"; | |||
| } | |||
| private: | |||
| std::unique_ptr<MatrixMul> m_opr; | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
| } // namespace rocm | |||
| @@ -0,0 +1,62 @@ | |||
| /** | |||
| * \file dnn/src/rocm/matrix_mul/algos.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/rocm/matrix_mul/algos.h" | |||
| #include "src/common/algo_base.h" | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&blas); | |||
| for (auto&& algo : all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| } | |||
| MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | |||
| MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl) | |||
| MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, | |||
| const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) | |||
| : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} | |||
| MatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(MatrixMulForwardImpl* opr, | |||
| _megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) | |||
| : SizeArgs(opr, A.layout, B.layout, C.layout), | |||
| tensor_a{A}, | |||
| tensor_b{B}, | |||
| tensor_c{C}, | |||
| workspace{workspace} {} | |||
| std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
| auto&& param = opr->param(); | |||
| size_t m = layout_a.shape[0], n = layout_b.shape[1], | |||
| k = layout_a.shape[param.transposeA ? 0 : 1]; | |||
| MEGDNN_MARK_USED_VAR(m); | |||
| MEGDNN_MARK_USED_VAR(n); | |||
| MEGDNN_MARK_USED_VAR(k); | |||
| return megdnn_mangle(ssprintf( | |||
| "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " | |||
| "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", | |||
| m, k, k, n, m, n, param.transposeA, param.transposeB, | |||
| layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
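For a concrete sense of what SizeArgs::to_string emits: with transposeA = transposeB = 0 and contiguous layouts, a hypothetical 32x64 A, 64x48 B and 32x48 C give m=32, n=48, k=64 and row strides 64/48/48, so the logged string would read:

```
A={32x64},B={64x48},C={32x48},Transpose A=0,Transpose B=0,ldA=64,ldB=48,ldC=48
```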
| @@ -0,0 +1,118 @@ | |||
| /** | |||
| * \file dnn/src/rocm/matrix_mul/algos.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/algo_base.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/rocm/matrix_mul/opr_impl.h" | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| /*! | |||
| * \brief base class for matrix mul algos | |||
| * | |||
| */ | |||
| class MatrixMulForwardImpl::AlgoBase : public Algorithm { | |||
| protected: | |||
| ~AlgoBase() = default; | |||
| public: | |||
| enum class AlgoType : uint32_t { | |||
| ROCM_BLAS, | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||
| struct SizeArgs { | |||
| MatrixMulForwardImpl* opr; | |||
| TensorLayout layout_a, layout_b, layout_c; | |||
| std::string to_string() const; | |||
| SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, | |||
| const TensorLayout& B, const TensorLayout& C); | |||
| bool can_be_treated_as_int8x8x32() const { | |||
| return layout_a.dtype.enumv() == layout_b.dtype.enumv() && | |||
| (layout_a.dtype.enumv() == DTypeEnum::Int8 || | |||
| layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) && | |||
| (layout_c.dtype.enumv() == DTypeEnum::Int32 || | |||
| layout_c.dtype.enumv() == DTypeEnum::QuantizedS32) && | |||
| opr->param().format == param::MatrixMul::Format::DEFAULT; | |||
| } | |||
| }; | |||
| struct ExecArgs : public SizeArgs { | |||
| TensorND tensor_a, tensor_b, tensor_c; | |||
| Workspace workspace; | |||
| ExecArgs(MatrixMulForwardImpl* opr, _megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace); | |||
| }; | |||
| virtual bool is_available(const SizeArgs& args) const = 0; | |||
| virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
| virtual void exec(const ExecArgs& args) const = 0; | |||
| bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||
| return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
| } | |||
| bool is_available_reproducible( | |||
| const SizeArgs& args, bool reproducible = true, | |||
| size_t limit = std::numeric_limits<size_t>::max()) const { | |||
| return (!reproducible || is_reproducible()) && | |||
| is_available_wk(args, limit); | |||
| } | |||
| AlgoBase& check_workspace(const SizeArgs& args, | |||
| const Workspace& workspace) { | |||
| auto req = get_workspace_in_bytes(args); | |||
| megdnn_assert( | |||
| req <= workspace.size, | |||
| "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", | |||
| name(), req, workspace.size); | |||
| return *this; | |||
| } | |||
| }; | |||
| class MatrixMulForwardImpl::AlgoBlas final : public AlgoBase { | |||
| public: | |||
| AlgoBlas() = default; | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { | |||
| return 0_z; | |||
| } | |||
| const char* name() const override { return "BLAS"; } | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { return true; } | |||
| MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | |||
| }; | |||
| class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||
| private: | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| public: | |||
| AlgoPack(); | |||
| AlgoBlas blas; | |||
| std::vector<AlgoBase*> all_algos; | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| } // namespace rocm | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
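The pack above plugs into the generic chooser helpers: selection walks all_algos and takes the first entry that passes is_available_reproducible. A sketch of that loop (assuming algos.h is included; this mirrors, but is not, the actual src/common/algo_chooser.h implementation):

```cpp
#include "src/rocm/matrix_mul/algos.h"

using namespace megdnn;
using namespace rocm;

// First-fit selection over the algo pack; returns nullptr if nothing
// satisfies the reproducibility and workspace constraints.
MatrixMulForwardImpl::AlgoBase* pick_algo(
        const MatrixMulForwardImpl::AlgoBase::SizeArgs& args,
        bool reproducible, size_t workspace_limit) {
    for (auto algo : MatrixMulForwardImpl::algo_pack().all_algos) {
        if (algo->is_available_reproducible(args, reproducible,
                                            workspace_limit))
            return algo;
    }
    return nullptr;
}
```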
| @@ -0,0 +1,162 @@ | |||
| /** | |||
| * \file dnn/src/rocm/matrix_mul/Blas.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/rocm/matrix_mul/algos.h" | |||
| #include "hcc_detail/hcc_defs_prologue.h" | |||
| #include "src/rocm/handle.h" | |||
| #include "src/rocm/utils.h" | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| bool MatrixMulForwardImpl::AlgoBlas::is_available( | |||
| const SizeArgs& args) const { | |||
| if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) | |||
| return false; | |||
| if (args.layout_a.dtype == dtype::Float32() || | |||
| args.layout_a.dtype == dtype::Float16()) { | |||
| return true; | |||
| } else if (args.layout_a.dtype.enumv() == DTypeEnum::Int8 || | |||
| args.layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) { | |||
| auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1]; | |||
| //! see | |||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470 | |||
| bool rocblas_int8x8x32_valid = true; | |||
| rocblas_int8x8x32_valid &= (k % 4 == 0); | |||
| rocblas_int8x8x32_valid &= (!args.opr->param().transposeB || | |||
| args.layout_b.stride[0] % 4 == 0); | |||
| rocblas_int8x8x32_valid &= (!args.opr->param().transposeA || | |||
| args.layout_a.stride[0] % 4 == 0); | |||
| return rocblas_int8x8x32_valid; | |||
| } | |||
| return false; | |||
| } | |||
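The int8x8x32 constraints above are easy to misread, so here is a self-contained restatement with worked cases (a sketch, not MegEngine code): k must be a multiple of 4, and a transposed operand additionally needs its leading dimension (stride[0]) to be a multiple of 4.

```cpp
#include <cstddef>

struct Int8Case {
    std::size_t k, lda, ldb;
    bool trans_a, trans_b;
};

bool rocblas_int8_ok(const Int8Case& c) {
    return c.k % 4 == 0 && (!c.trans_a || c.lda % 4 == 0) &&
           (!c.trans_b || c.ldb % 4 == 0);
}

// rocblas_int8_ok({8, 6, 8, false, false}) -> true  (k = 8)
// rocblas_int8_ok({6, 8, 8, false, false}) -> false (k = 6)
// rocblas_int8_ok({8, 6, 8, true,  false}) -> false (lda = 6 with transposeA)
```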
| void MatrixMulForwardImpl::AlgoBlas::exec(const ExecArgs& args) const { | |||
| auto m = args.layout_c.shape[0], n = args.layout_c.shape[1]; | |||
| auto k = args.layout_a.shape[args.opr->param().transposeA ? 0 : 1]; | |||
| auto&& handle = concrete_handle(args.opr->handle()); | |||
| auto rocblas_handle_ = handle->get_rocblas_handle(); | |||
| auto sgemm = [&]() { | |||
| auto zero = handle->zero_device(); | |||
| auto one = handle->one_device(); | |||
| rocblas_check(rocblas_sgemm( | |||
| rocblas_handle_, | |||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, args.tensor_b.ptr<dt_float32>(), | |||
| args.layout_b.stride[0], args.tensor_a.ptr<dt_float32>(), | |||
| args.layout_a.stride[0], zero, args.tensor_c.ptr<dt_float32>(), | |||
| args.layout_c.stride[0])); | |||
| }; | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| //! used for FLOAT_IO16xC32, not tested | |||
| auto gemm_ex = [&]() { | |||
| auto zero = handle->zero_device(); | |||
| auto one = handle->one_device(); | |||
| //! These two arguments for future use, see | |||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp | |||
| int32_t solution_index = 0; | |||
| uint32_t flags = 1; | |||
| size_t ws_size = 0; | |||
| auto gemm_ex_err = rocblas_gemm_ex( | |||
| rocblas_handle_, | |||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_f16_r, | |||
| args.layout_b.stride[0], args.tensor_a.raw_ptr, | |||
| rocblas_datatype_f16_r, args.layout_a.stride[0], zero, | |||
| args.tensor_c.raw_ptr, rocblas_datatype_f16_r, | |||
| args.layout_c.stride[0], args.tensor_c.raw_ptr, | |||
| rocblas_datatype_f16_r, args.layout_c.stride[0], | |||
| rocblas_datatype_f32_r, rocblas_gemm_algo_standard, | |||
| solution_index, flags, &ws_size, nullptr); | |||
| rocblas_check(gemm_ex_err); | |||
| MEGDNN_MARK_USED_VAR(ws_size); | |||
| }; | |||
| auto hgemm = [&]() { | |||
| auto one_half = handle->one_device_h(); | |||
| auto zero_half = handle->zero_device_h(); | |||
| auto hgemm_err = rocblas_hgemm( | |||
| rocblas_handle_, | |||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||
| static_cast<const rocblas_half*>(args.tensor_b.raw_ptr), | |||
| args.layout_b.stride[0], | |||
| static_cast<const rocblas_half*>(args.tensor_a.raw_ptr), | |||
| args.layout_a.stride[0], | |||
| reinterpret_cast<const rocblas_half*>(zero_half), | |||
| static_cast<rocblas_half*>(args.tensor_c.raw_ptr), | |||
| args.layout_c.stride[0]); | |||
| rocblas_check(hgemm_err); | |||
| }; | |||
| #endif | |||
| if (args.opr->param().compute_mode == Param::ComputeMode::DEFAULT) { | |||
| if (args.layout_a.dtype == dtype::Float32()) { | |||
| sgemm(); | |||
| } | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| else { | |||
| megdnn_assert(args.layout_a.dtype == dtype::Float16(), | |||
| "invalid matmul data type"); | |||
| hgemm(); | |||
| } | |||
| #endif | |||
| } | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| else if (args.opr->param().compute_mode == Param::ComputeMode::FLOAT32) { | |||
| megdnn_assert(args.layout_b.dtype == dtype::Float16() && | |||
| args.layout_c.dtype == dtype::Float16() && | |||
| args.layout_a.dtype == dtype::Float16(), | |||
| "DataType::FLOAT_IO16xC32 is supported, when dtype of A, " | |||
| "B, C are all Float16"); | |||
| gemm_ex(); | |||
| } | |||
| #endif | |||
| else { | |||
| megdnn_assert(args.can_be_treated_as_int8x8x32()); | |||
| int32_t solution_index = 0; | |||
| uint32_t flags = 1; | |||
| size_t ws_size = 0; | |||
| auto zero = handle->zero_device_i32(); | |||
| auto one = handle->one_device_i32(); | |||
| rocblas_check(rocblas_gemm_ex( | |||
| rocblas_handle_, | |||
| args.opr->param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| args.opr->param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, args.tensor_b.raw_ptr, rocblas_datatype_i8_r, | |||
| args.layout_b.stride[0], args.tensor_a.raw_ptr, | |||
| rocblas_datatype_i8_r, args.layout_a.stride[0], zero, | |||
| args.tensor_c.raw_ptr, rocblas_datatype_i32_r, | |||
| args.layout_c.stride[0], args.tensor_c.raw_ptr, | |||
| rocblas_datatype_i32_r, args.layout_c.stride[0], | |||
| rocblas_datatype_i32_r, rocblas_gemm_algo_standard, | |||
| solution_index, flags, &ws_size, nullptr)); | |||
| MEGDNN_MARK_USED_VAR(ws_size); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
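Putting the branches of AlgoBlas::exec together, the dtype and compute-mode combinations map onto rocBLAS entry points as follows:

```
A/B dtype           C dtype             compute_mode  rocBLAS call
Float32             Float32             DEFAULT       rocblas_sgemm
Float16             Float16             DEFAULT       rocblas_hgemm
Float16             Float16             FLOAT32       rocblas_gemm_ex (f32 compute)
Int8/QuantizedS8    Int32/QuantizedS32  DEFAULT       rocblas_gemm_ex (i32 compute)
```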
| @@ -13,147 +13,53 @@ | |||
| #include "src/rocm/utils.h" | |||
| #include "src/rocm/handle.h" | |||
| #include "./algos.h" | |||
| #include "src/common/algo_chooser.h" | |||
| namespace megdnn { | |||
| namespace rocm { | |||
| using namespace megdnn; | |||
| using namespace rocm; | |||
| void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, | |||
| _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) | |||
| { | |||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
| auto m = C.layout.shape[0], n = C.layout.shape[1]; | |||
| auto k = A.layout.shape[param().transposeA ? 0 : 1]; | |||
| auto handle = concrete_handle(this->handle()); | |||
| auto rocblas_handle_ = handle->get_rocblas_handle(); | |||
| auto sgemm = [&]() { | |||
| auto zero = handle->zero_device(); | |||
| auto one = handle->one_device(); | |||
| rocblas_check(rocblas_sgemm( | |||
| rocblas_handle_, | |||
| param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, B.ptr<dt_float32>(), B.layout.stride[0], | |||
| A.ptr<dt_float32>(), A.layout.stride[0], zero, | |||
| C.ptr<dt_float32>(), C.layout.stride[0])); | |||
| }; | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| //! used for FLOAT_IO16xC32, not tested | |||
| auto gemm_ex = [&]() { | |||
| auto zero = handle->zero_device(); | |||
| auto one = handle->one_device(); | |||
| //! These two arguments for future use, see | |||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp | |||
| int32_t solution_index = 0; | |||
| uint32_t flags = 1; | |||
| size_t ws_size = 0; | |||
| auto gemm_ex_err = rocblas_gemm_ex( | |||
| rocblas_handle_, | |||
| param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, B.raw_ptr, rocblas_datatype_f16_r, | |||
| B.layout.stride[0], A.raw_ptr, rocblas_datatype_f16_r, | |||
| A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_f16_r, | |||
| C.layout.stride[0], C.raw_ptr, rocblas_datatype_f16_r, | |||
| C.layout.stride[0], rocblas_datatype_f32_r, | |||
| rocblas_gemm_algo_standard, solution_index, flags, &ws_size, | |||
| nullptr); | |||
| rocblas_check(gemm_ex_err); | |||
| }; | |||
| auto hgemm = [&]() { | |||
| auto one_half = handle->one_device_h(); | |||
| auto zero_half = handle->zero_device_h(); | |||
| auto hgemm_err = rocblas_hgemm( | |||
| rocblas_handle_, | |||
| param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, reinterpret_cast<const rocblas_half*>(one_half), | |||
| static_cast<const rocblas_half*>(B.raw_ptr), B.layout.stride[0], | |||
| static_cast<const rocblas_half*>(A.raw_ptr), A.layout.stride[0], | |||
| reinterpret_cast<const rocblas_half*>(zero_half), | |||
| static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[0]); | |||
| rocblas_check(hgemm_err); | |||
| }; | |||
| #endif | |||
| std::vector<MatrixMulForwardImpl::Algorithm*> | |||
| MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | |||
| } | |||
| if (param().compute_mode == Param::ComputeMode::DEFAULT) { | |||
| if (A.layout.dtype == dtype::Float32()) { | |||
| sgemm(); | |||
| } | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| else { | |||
| megdnn_assert(A.layout.dtype == dtype::Float16(), | |||
| "invalid matmul data type"); | |||
| hgemm(); | |||
| } | |||
| #endif | |||
| MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| if (sm_algo_pack.blas.is_available_reproducible( | |||
| args, reproducible, workspace_limit_in_bytes)) { | |||
| return &sm_algo_pack.blas; | |||
| } | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| else if (param().compute_mode == Param::ComputeMode::FLOAT32) { | |||
| megdnn_assert(B.layout.dtype == dtype::Float16() && | |||
| C.layout.dtype == dtype::Float16() && | |||
| A.layout.dtype == dtype::Float16(), | |||
| "DataType::FLOAT_IO16xC32 is supported, when dtype of A, " | |||
| "B, C are all Float16"); | |||
| gemm_ex(); | |||
| } | |||
| #endif | |||
| else if (A.layout.dtype == dtype::Int8() && | |||
| B.layout.dtype == dtype::Int8() && | |||
| C.layout.dtype == dtype::Int32()) { | |||
| //! see | |||
| //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470 | |||
| bool rocblas_int8x8x32_valid = true; | |||
| rocblas_int8x8x32_valid &= (k % 4 == 0); | |||
| rocblas_int8x8x32_valid &= | |||
| (!param().transposeB || B.layout.stride[0] % 4 == 0); | |||
| rocblas_int8x8x32_valid &= | |||
| (!param().transposeA || A.layout.stride[0] % 4 == 0); | |||
| megdnn_assert(rocblas_int8x8x32_valid, | |||
| "rocblas int8x8x32 matmul requires K must be a multiple " | |||
| "of 4, and/or LDA/LDB based on transpose mode" | |||
| "get: %zu, is_trans_b = %d, %zu, is_trans_a = %d, %zu", | |||
| k, param().transposeB, B.layout.stride[0], | |||
| param().transposeA, A.layout.stride[0]); | |||
| int32_t solution_index = 0; | |||
| uint32_t flags = 1; | |||
| size_t ws_size = 0; | |||
| auto zero = handle->zero_device_i32(); | |||
| auto one = handle->one_device_i32(); | |||
| rocblas_check(rocblas_gemm_ex( | |||
| rocblas_handle_, | |||
| param().transposeB ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| param().transposeA ? rocblas_operation_transpose | |||
| : rocblas_operation_none, | |||
| n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r, | |||
| B.layout.stride[0], A.raw_ptr, rocblas_datatype_i8_r, | |||
| A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r, | |||
| C.layout.stride[0], C.raw_ptr, rocblas_datatype_i32_r, | |||
| C.layout.stride[0], rocblas_datatype_i32_r, | |||
| rocblas_gemm_algo_standard, solution_index, flags, &ws_size, | |||
| nullptr)); | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<MatrixMulForwardImpl>( | |||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
| "matrix mul forward"); | |||
| } else { | |||
| megdnn_assert((A.layout.dtype == dtype::Int8() && | |||
| B.layout.dtype == dtype::Int8() && | |||
| C.layout.dtype == dtype::Int16()), | |||
| "invalid matmul data type"); | |||
| megdnn_throw("cuda matmul does not support INT8x8x16 now"); | |||
| return megdnn::get_usable_algo<MatrixMulForwardImpl>( | |||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||
| "matrix mul forward"); | |||
| } | |||
| } | |||
| } // namespace rocm | |||
| } // namespace megdnn | |||
| size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_algorithm(this, A, B, C)->get_workspace_in_bytes(args); | |||
| } | |||
| void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
| AlgoBase::ExecArgs args(this, A, B, C, workspace); | |||
| auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); | |||
| algo->check_workspace(args, workspace).exec(args); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
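With get_workspace_in_bytes now routed through the chosen algorithm instead of returning a hardcoded 0, callers must query before exec. A call-order sketch; `handle`, the tensors and `ws_ptr`/`ws_cap` are hypothetical, assumed to be prepared by the caller:

```cpp
#include <cassert>
#include "megdnn/oprs.h"

void run_matmul(megdnn::Handle* handle, const megdnn::TensorND& A,
                const megdnn::TensorND& B, const megdnn::TensorND& C,
                megdnn::dt_byte* ws_ptr, size_t ws_cap) {
    auto opr = handle->create_operator<megdnn::MatrixMulForward>();
    // The size comes from the heuristically chosen algorithm.
    size_t ws = opr->get_workspace_in_bytes(A.layout, B.layout, C.layout);
    assert(ws <= ws_cap);
    opr->exec(A, B, C, {ws_ptr, ws});
}
```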
| @@ -20,29 +20,32 @@ public: | |||
| void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| const TensorLayout&) override; | |||
| bool is_thread_safe() const override { return true; } | |||
| class AlgoBase; | |||
| class AlgoBlas; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override { | |||
| return {}; | |||
| } | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /*reproducible*/) override { | |||
| return nullptr; | |||
| } | |||
| bool /*reproducible*/) override; | |||
| const char* get_algorithm_set_name() const override { | |||
| return "ROCM MATMUL"; | |||
| } | |||
| static AlgoPack sm_algo_pack; | |||
| }; | |||
| } // namespace rocm | |||
| @@ -46,6 +46,37 @@ TEST_F(FALLBACK, MATRIX_MUL) { | |||
| } | |||
| } | |||
| TEST_F(FALLBACK, MATRIX_MUL_NAIVE) { | |||
| Checker<MatrixMul> checker(handle()); | |||
| checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_NAIVE")); | |||
| using Param = MatrixMul::Param; | |||
| auto args = matrix_mul::get_matmul_args(); | |||
| for (auto arg : args) { | |||
| auto m = arg.m, n = arg.n, k = arg.k; | |||
| auto mask = arg.mask;  //! bit 0 selects transposeA, bit 1 transposeB | |||
| Param param; | |||
| param.transposeA = mask & 1; | |||
| param.transposeB = mask & 2; | |||
| TensorShape AS, BS, CS; | |||
| if (param.transposeA) | |||
| AS = TensorShape{k, m}; | |||
| else | |||
| AS = TensorShape{m, k}; | |||
| if (param.transposeB) | |||
| BS = TensorShape{n, k}; | |||
| else | |||
| BS = TensorShape{k, n}; | |||
| CS = TensorShape{m, n}; | |||
| TensorLayout AL, BL, CL; | |||
| AL = TensorLayout(AS, dtype::Float32()); | |||
| BL = TensorLayout(BS, dtype::Float32()); | |||
| CL = TensorLayout(CS, dtype::Float32()); | |||
| checker.set_param(param); | |||
| checker.execl({AL, BL, CL}); | |||
| } | |||
| } | |||
| TEST_F(FALLBACK, BATCHED_MATRIX_MUL) { | |||
| Checker<BatchedMatrixMul> checker(handle()); | |||
| @@ -232,7 +232,7 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) { | |||
| 2, 5, 3, 3, 7, 4, -7, 1, | |||
| -5, 7, -4, -1, -1, 2, 4, 1, | |||
| 7, 2, -6, -2, -6, 3, 4, 4, | |||
| -2, 2, 3, 0, 6, 5, 3, 4, | |||
| -2, 2, 3, 0, 6, 5, 3, 4, | |||
| -1, -1, -5, 5, 2, 5, 1, 4, | |||
| 6, 2, 0, 0, 3, 2, 2, 1, | |||
| -4, -3, 7, 5, 0, 3, 2, 3}), | |||
| @@ -243,7 +243,7 @@ TEST_F(NAIVE, MATRIX_MUL_QUANTIZEDS4_4x4x16) { | |||
| 3, -1, 2, 2, 7, 3, 6, 0, | |||
| 5, 4, 0, 2, 2, 3, 3, 2, | |||
| 1, -8, -7, -6, 0, -5, -4, 4, | |||
| -3, 7, 1, 6, -2, 2, -1, 5, | |||
| -3, 7, 1, 6, -2, 2, -1, 5, | |||
| 2, 0, 7, 6, 5, 4, 3, 2, | |||
| 0, 0, 1, 0, 5, 2, 2, 6}), | |||
| {}}, | |||