GitOrigin-RevId: 51702c4e79
tags/v1.3.0
| @@ -9,7 +9,7 @@ ELEMWISE_IMPL := ../src/cuda/cond_take/kimpl \ | |||||
| ../src/cuda/elemwise_multi_type/kimpl | ../src/cuda/elemwise_multi_type/kimpl | ||||
| CUDA_CONV_IMPL := ../src/cuda/conv_bias/int8/kimpl ../src/cuda/conv_bias/int8_imma/kimpl ../src/cuda/batch_conv_bias/int8/kimpl | CUDA_CONV_IMPL := ../src/cuda/conv_bias/int8/kimpl ../src/cuda/conv_bias/int8_imma/kimpl ../src/cuda/batch_conv_bias/int8/kimpl | ||||
| CUDA_MATMUL_IMPL := ../src/cuda/matrix_mul/fp32_simt/kimpl | |||||
| CUDA_MATMUL_IMPL := ../src/cuda/matrix_mul/fp32_simt/kimpl ../src/cuda/matrix_mul/fp32_simt_gemv/kimpl | |||||
| all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL) | all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL) | ||||
| @@ -51,4 +51,7 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL) | |||||
| ../src/cuda/matrix_mul/fp32_simt/kimpl: gen_cutlass_matmul_kern_impls.py | ../src/cuda/matrix_mul/fp32_simt/kimpl: gen_cutlass_matmul_kern_impls.py | ||||
| ./$^ $@ | ./$^ $@ | ||||
| ../src/cuda/matrix_mul/fp32_simt_gemv/kimpl: gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| ./$^ $@ | |||||
| .PHONY: all | .PHONY: all | ||||
| @@ -33,6 +33,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| all_algos.push_back(&bfloat16); | all_algos.push_back(&bfloat16); | ||||
| #endif | #endif | ||||
| #if CUDA_VERSION >= 9020 | |||||
| fill_cutlass_algos(); | fill_cutlass_algos(); | ||||
| for (auto&& algo : simt_float32) { | for (auto&& algo : simt_float32) { | ||||
| all_algos.push_back(&algo); | all_algos.push_back(&algo); | ||||
| @@ -40,12 +41,17 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||||
| for (auto&& algo : simt_float32_split_k) { | for (auto&& algo : simt_float32_split_k) { | ||||
| all_algos.push_back(&algo); | all_algos.push_back(&algo); | ||||
| } | } | ||||
| for (auto&& algo : simt_float32_gemv_batched_strided) { | |||||
| all_algos.push_back(&algo); | |||||
| } | |||||
| #endif | |||||
| for (auto&& algo : all_algos) { | for (auto&& algo : all_algos) { | ||||
| m_all_algos_map.emplace(algo->info().desc, algo); | m_all_algos_map.emplace(algo->info().desc, algo); | ||||
| } | } | ||||
| } | } | ||||
| #if CUDA_VERSION >= 9020 | |||||
| void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() { | void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() { | ||||
| using AlgoParam = AlgoFloat32SIMT::AlgoParam; | using AlgoParam = AlgoFloat32SIMT::AlgoParam; | ||||
| simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8}); | simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8}); | ||||
| @@ -82,7 +88,11 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() { | |||||
| simt_float32_split_k.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8}); | simt_float32_split_k.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8}); | ||||
| simt_float32_split_k.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8}); | simt_float32_split_k.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8}); | ||||
| simt_float32_split_k.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8}); | simt_float32_split_k.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8}); | ||||
| simt_float32_gemv_batched_strided.emplace_back(128); | |||||
| simt_float32_gemv_batched_strided.emplace_back(64); | |||||
| simt_float32_gemv_batched_strided.emplace_back(32); | |||||
| } | } | ||||
| #endif | |||||
| MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | ||||
| @@ -42,8 +42,11 @@ public: | |||||
| CUDA_CUBLASLT, | CUDA_CUBLASLT, | ||||
| CUDA_NAIVE, | CUDA_NAIVE, | ||||
| CUDA_BFLOAT16, | CUDA_BFLOAT16, | ||||
| #if CUDA_VERSION >= 9020 | |||||
| CUDA_FLOAT32_SIMT, | CUDA_FLOAT32_SIMT, | ||||
| CUDA_FLOAT32_SIMT_SPLIT_K, | CUDA_FLOAT32_SIMT_SPLIT_K, | ||||
| CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED, | |||||
| #endif | |||||
| }; | }; | ||||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | ||||
| @@ -167,6 +170,7 @@ private: | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| #if CUDA_VERSION >= 9020 | |||||
| class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase { | class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase { | ||||
| public: | public: | ||||
| struct AlgoParam { | struct AlgoParam { | ||||
| @@ -224,6 +228,32 @@ private: | |||||
| std::string m_name; | std::string m_name; | ||||
| }; | }; | ||||
| class MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided final | |||||
| : public AlgoBase { | |||||
| public: | |||||
| AlgoFloat32SIMTGemvBatchedStrided(int threadblock_n) | |||||
| : m_threadblock_n{threadblock_n}, | |||||
| m_name{ssprintf("CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_%d", | |||||
| m_threadblock_n)} {} | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| const char* name() const override { return m_name.c_str(); } | |||||
| void exec(const ExecArgs& args) const override; | |||||
| bool is_reproducible() const override { return true; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED) | |||||
| std::string param() const override { | |||||
| std::string ret; | |||||
| serialize_write_pod(m_threadblock_n, ret); | |||||
| return ret; | |||||
| } | |||||
| private: | |||||
| int m_threadblock_n; | |||||
| std::string m_name; | |||||
| }; | |||||
| #endif | |||||
| class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { | class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { | ||||
| private: | private: | ||||
| AlgoBase::Mapper m_all_algos_map; | AlgoBase::Mapper m_all_algos_map; | ||||
| @@ -241,8 +271,12 @@ public: | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| AlgoBFloat16 bfloat16; | AlgoBFloat16 bfloat16; | ||||
| #endif | #endif | ||||
| #if CUDA_VERSION >= 9020 | |||||
| std::vector<AlgoFloat32SIMT> simt_float32; | std::vector<AlgoFloat32SIMT> simt_float32; | ||||
| std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k; | std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k; | ||||
| std::vector<AlgoFloat32SIMTGemvBatchedStrided> | |||||
| simt_float32_gemv_batched_strided; | |||||
| #endif | |||||
| std::vector<AlgoBase*> all_algos; | std::vector<AlgoBase*> all_algos; | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | ||||
| @@ -15,20 +15,17 @@ | |||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| #if CUDA_VERSION >= 9020 | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| using namespace cutlass_wrapper; | using namespace cutlass_wrapper; | ||||
| bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available( | bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available( | ||||
| const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
| #if CUDA_VERSION >= 9200 | |||||
| return args.opr->param().format == param::MatrixMul::Format::DEFAULT && | return args.opr->param().format == param::MatrixMul::Format::DEFAULT && | ||||
| args.layout_a.dtype == dtype::Float32() && | args.layout_a.dtype == dtype::Float32() && | ||||
| args.layout_b.dtype == dtype::Float32() && | args.layout_b.dtype == dtype::Float32() && | ||||
| args.layout_c.dtype == dtype::Float32(); | args.layout_c.dtype == dtype::Float32(); | ||||
| #else | |||||
| return false; | |||||
| #endif | |||||
| } | } | ||||
| size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( | size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( | ||||
| @@ -69,5 +66,6 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const { | |||||
| m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
| stream); | stream); | ||||
| } | } | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/handle.h" | |||||
| #include "src/cuda/matrix_mul/algos.h" | |||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
| #include "src/cuda/utils.h" | |||||
| #if CUDA_VERSION >= 9020 | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
| bool MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::is_available( | |||||
| const SizeArgs& args) const { | |||||
| auto&& param = args.opr->param(); | |||||
| bool ta = param.transposeA, tb = param.transposeB; | |||||
| return args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||||
| args.layout_a.dtype == dtype::Float32() && | |||||
| args.layout_b.dtype == dtype::Float32() && | |||||
| args.layout_c.dtype == dtype::Float32() && ((!ta) && (!tb)); | |||||
| } | |||||
| size_t | |||||
| MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::get_workspace_in_bytes( | |||||
| const SizeArgs& /* args */) const { | |||||
| return 0; | |||||
| } | |||||
| void MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::exec( | |||||
| const ExecArgs& args) const { | |||||
| size_t lda = args.tensor_a.layout.stride[0], | |||||
| ldb = args.tensor_b.layout.stride[0], | |||||
| ldc = args.tensor_c.layout.stride[0]; | |||||
| auto&& param = args.opr->param(); | |||||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | |||||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | |||||
| // m is always 1 in gemv batched strided case | |||||
| BatchedGemmCoord problem_size{1, n, k, m}; | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | |||||
| return cutlass_matrix_mul_float32_simt_gemv_batched_strided( | |||||
| args.tensor_a.ptr<dt_float32>(), lda, lda, | |||||
| args.tensor_b.ptr<dt_float32>(), ldb, 0, | |||||
| args.tensor_c.ptr<dt_float32>(), ldc, ldc, problem_size, | |||||
| m_threadblock_n, stream); | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -15,6 +15,7 @@ | |||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| #if CUDA_VERSION >= 9020 | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| using namespace cutlass_wrapper; | using namespace cutlass_wrapper; | ||||
| @@ -22,12 +23,12 @@ using namespace cutlass_wrapper; | |||||
| bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | ||||
| const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
| auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
| int n = args.layout_c.shape[1], | |||||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | k = args.layout_a.shape[param.transposeA ? 0 : 1]; | ||||
| return args.opr->param().format == param::MatrixMul::Format::DEFAULT && | return args.opr->param().format == param::MatrixMul::Format::DEFAULT && | ||||
| args.layout_a.dtype == dtype::Float32() && | args.layout_a.dtype == dtype::Float32() && | ||||
| args.layout_b.dtype == dtype::Float32() && | args.layout_b.dtype == dtype::Float32() && | ||||
| args.layout_c.dtype == dtype::Float32() && k > std::max(m, n); | |||||
| args.layout_c.dtype == dtype::Float32() && k > n; | |||||
| } | } | ||||
| size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | ||||
| @@ -38,7 +39,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | |||||
| int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | ||||
| k = args.layout_a.shape[param.transposeA ? 0 : 1]; | k = args.layout_a.shape[param.transposeA ? 0 : 1]; | ||||
| GemmCoord problem_size{m, n, k}; | GemmCoord problem_size{m, n, k}; | ||||
| int split_k_slices = k / std::max(m, n); | |||||
| int split_k_slices = k / n; | |||||
| return cutlass_matrix_mul_float32_simt_get_workspace_size( | return cutlass_matrix_mul_float32_simt_get_workspace_size( | ||||
| param.transposeA, lda, param.transposeB, ldb, ldc, problem_size, | param.transposeA, lda, param.transposeB, ldb, ldc, problem_size, | ||||
| 1.f, 0.f, | 1.f, 0.f, | ||||
| @@ -58,7 +59,7 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||||
| int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | ||||
| k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | ||||
| GemmCoord problem_size{m, n, k}; | GemmCoord problem_size{m, n, k}; | ||||
| int split_k_slices = k / std::max(m, n); | |||||
| int split_k_slices = k / n; | |||||
| auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
| int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | ||||
| return cutlass_matrix_mul_float32_simt( | return cutlass_matrix_mul_float32_simt( | ||||
| @@ -72,5 +73,6 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||||
| m_algo_param.warp_k}, | m_algo_param.warp_k}, | ||||
| stream, split_k_slices); | stream, split_k_slices); | ||||
| } | } | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,16 +10,16 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| // ignore warning of cutlass | // ignore warning of cutlass | ||||
| #include "cuda.h" | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || \ | |||||
| (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| #pragma GCC diagnostic push | #pragma GCC diagnostic push | ||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | #pragma GCC diagnostic ignored "-Wstrict-aliasing" | ||||
| #include "cuda.h" | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || \ | |||||
| (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| #include "cutlass/gemm/device/gemm.h" | #include "cutlass/gemm/device/gemm.h" | ||||
| #include "cutlass/gemm/device/gemm_splitk_parallel.h" | #include "cutlass/gemm/device/gemm_splitk_parallel.h" | ||||
| #endif | |||||
| #include "cutlass/gemm/kernel/default_gemv.h" | |||||
| #include "src/common/opr_param_defs_enumv.cuh" | #include "src/common/opr_param_defs_enumv.cuh" | ||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | ||||
| #pragma GCC diagnostic pop | #pragma GCC diagnostic pop | ||||
| @@ -54,18 +54,6 @@ using namespace cutlass_wrapper; | |||||
| threadblock_shape.m(), threadblock_shape.n(), \ | threadblock_shape.m(), threadblock_shape.n(), \ | ||||
| threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ | ||||
| warp_shape.k()); | warp_shape.k()); | ||||
| #if __CUDACC_VER_MAJOR__ < 9 || \ | |||||
| (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2) | |||||
| void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt( | |||||
| const float* /* d_A */, bool /* transpose_A */, size_t /* lda */, | |||||
| const float* /* d_B */, bool /* transpose_B */, size_t /* ldb */, | |||||
| float* /* d_C */, size_t /* ldc */, int* /* workspace */, | |||||
| GemmCoord const& /* problem_size */, float /* alpha */, | |||||
| float /* beta */, const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, cudaStream_t /* stream */, | |||||
| int /* split_k_slices */) {} | |||||
| #else | |||||
| void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt( | void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt( | ||||
| const float* d_A, bool transpose_A, size_t lda, const float* d_B, | const float* d_A, bool transpose_A, size_t lda, const float* d_B, | ||||
| bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace, | bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace, | ||||
| @@ -162,20 +150,7 @@ void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt( | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| #if __CUDACC_VER_MAJOR__ < 9 || \ | |||||
| (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2) | |||||
| size_t megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_matrix_mul_float32_simt_get_workspace_size( | |||||
| bool /* transpose_A */, size_t /* lda */, | |||||
| bool /* transpose_B */, size_t /* ldb */, size_t /* ldc */, | |||||
| GemmCoord const& /* problem_size */, float /* alpha */, | |||||
| float /* beta */, const GemmCoord& /* threadblock_shape */, | |||||
| const GemmCoord& /* warp_shape */, int /* split_k_slices */) { | |||||
| return 0; | |||||
| } | |||||
| #else | |||||
| size_t megdnn::cuda::cutlass_wrapper:: | size_t megdnn::cuda::cutlass_wrapper:: | ||||
| cutlass_matrix_mul_float32_simt_get_workspace_size( | cutlass_matrix_mul_float32_simt_get_workspace_size( | ||||
| bool transpose_A, size_t lda, bool transpose_B, size_t ldb, | bool transpose_A, size_t lda, bool transpose_B, size_t ldb, | ||||
| @@ -294,7 +269,86 @@ size_t megdnn::cuda::cutlass_wrapper:: | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| #undef DISPATCH | |||||
| /* ============ cutlass kernel wrapper for f32 vector-matrix mul batched strided | |||||
| * =========== | |||||
| */ | |||||
| #define DISPATCH(cb) \ | |||||
| cb(128, 4, 4); \ | |||||
| cb(128, 4, 2); \ | |||||
| cb(128, 4, 1); \ | |||||
| cb(128, 2, 4); \ | |||||
| cb(128, 1, 4); \ | |||||
| cb(128, 2, 2); \ | |||||
| cb(128, 1, 2); \ | |||||
| cb(128, 2, 1); \ | |||||
| cb(128, 1, 1); \ | |||||
| cb(64, 4, 4); \ | |||||
| cb(64, 4, 2); \ | |||||
| cb(64, 4, 1); \ | |||||
| cb(64, 2, 4); \ | |||||
| cb(64, 1, 4); \ | |||||
| cb(64, 2, 2); \ | |||||
| cb(64, 1, 2); \ | |||||
| cb(64, 2, 1); \ | |||||
| cb(64, 1, 1); \ | |||||
| cb(32, 4, 4); \ | |||||
| cb(32, 4, 2); \ | |||||
| cb(32, 4, 1); \ | |||||
| cb(32, 2, 4); \ | |||||
| cb(32, 1, 4); \ | |||||
| cb(32, 2, 2); \ | |||||
| cb(32, 1, 2); \ | |||||
| cb(32, 2, 1); \ | |||||
| cb(32, 1, 1); \ | |||||
| megdnn_assert(false, \ | |||||
| "unsupported gemv batched strided A=%dX%dX%d, B=%dX%dX%d", \ | |||||
| problem_size.batch(), problem_size.m(), problem_size.k(), \ | |||||
| problem_size.batch(), problem_size.k(), problem_size.n()); | |||||
| void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_matrix_mul_float32_simt_gemv_batched_strided( | |||||
| const float* d_A, size_t lda, size_t batch_stride_a, | |||||
| const float* d_B, size_t ldb, size_t batch_stride_b, float* d_C, | |||||
| size_t ldc, size_t batch_stride_c, | |||||
| BatchedGemmCoord const& problem_size, int threadblock_n, | |||||
| cudaStream_t stream) { | |||||
| int LDG_K, LDG_N; | |||||
| if (lda % 4 == 0) | |||||
| LDG_K = 4; | |||||
| else if (lda % 2 == 0) | |||||
| LDG_K = 2; | |||||
| else | |||||
| LDG_K = 1; | |||||
| if (ldb % 4 == 0) | |||||
| LDG_N = 4; | |||||
| else if (ldb % 2 == 0) | |||||
| LDG_N = 2; | |||||
| else | |||||
| LDG_N = 1; | |||||
| #define cb(threadblock_n_, LDG_K_, LDG_N_) \ | |||||
| if (threadblock_n == threadblock_n_ && LDG_K == LDG_K_ && \ | |||||
| LDG_N == LDG_N_) { \ | |||||
| using ThreadBlockShape = \ | |||||
| cutlass::gemm::GemmShape<1, threadblock_n_, \ | |||||
| (256 * LDG_K_) / \ | |||||
| (threadblock_n_ / LDG_N_)>; \ | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N_, LDG_K_>; \ | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< \ | |||||
| ThreadBlockShape, ThreadShape, float, \ | |||||
| cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, \ | |||||
| float, cutlass::layout::RowMajor>; \ | |||||
| return cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( \ | |||||
| problem_size, d_A, lda, batch_stride_a, d_B, ldb, \ | |||||
| batch_stride_b, d_C, ldc, batch_stride_c, stream); \ | |||||
| } | |||||
| DISPATCH(cb) | |||||
| #undef cb | |||||
| } | |||||
| #undef DISPATCH | #undef DISPATCH | ||||
| #endif | |||||
| // vim: syntax=cuda.doxygen | // vim: syntax=cuda.doxygen | ||||
| @@ -13,11 +13,13 @@ | |||||
| #include "cutlass/gemm/gemm.h" | #include "cutlass/gemm/gemm.h" | ||||
| #include "src/cuda/utils.cuh" | #include "src/cuda/utils.cuh" | ||||
| #if CUDA_VERSION >= 9020 | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| namespace cutlass_wrapper { | namespace cutlass_wrapper { | ||||
| using GemmCoord = cutlass::gemm::GemmCoord; | using GemmCoord = cutlass::gemm::GemmCoord; | ||||
| using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord; | |||||
| template <typename Gemm> | template <typename Gemm> | ||||
| void cutlass_matrix_mul_wrapper( | void cutlass_matrix_mul_wrapper( | ||||
| @@ -38,10 +40,26 @@ void cutlass_matrix_mul_float32_simt( | |||||
| size_t cutlass_matrix_mul_float32_simt_get_workspace_size( | size_t cutlass_matrix_mul_float32_simt_get_workspace_size( | ||||
| bool transpose_A, size_t lda, bool transpose_B, size_t ldb, size_t ldc, | bool transpose_A, size_t lda, bool transpose_B, size_t ldb, size_t ldc, | ||||
| GemmCoord const& problem_size, float alpha, float beta, | GemmCoord const& problem_size, float alpha, float beta, | ||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, int split_k_slices = 1); | |||||
| const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, | |||||
| int split_k_slices = 1); | |||||
| template <typename GemvKernel> | |||||
| void cutlass_vector_matrix_mul_batched_strided_wrapper( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, | |||||
| size_t batch_stride_a, const typename GemvKernel::ElementB* d_B, | |||||
| size_t ldb, size_t batch_stride_b, typename GemvKernel::ElementCD* d_C, | |||||
| size_t ldc, size_t batch_stride_c, cudaStream_t stream); | |||||
| void cutlass_matrix_mul_float32_simt_gemv_batched_strided( | |||||
| const float* d_A, size_t lda, size_t batch_stride_a, const float* d_B, | |||||
| size_t ldb, size_t batch_stride_b, float* d_C, size_t ldc, | |||||
| size_t batch_stride_c, BatchedGemmCoord const& problem_size, | |||||
| int threadblock_n, cudaStream_t stream); | |||||
| } // namespace cutlass_wrapper | } // namespace cutlass_wrapper | ||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | |||||
| // vim: syntax=cuda.doxygen | // vim: syntax=cuda.doxygen | ||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 2>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 32>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 128>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 8>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 64>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,26 @@ | |||||
| #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) | |||||
| // generated by gen_cutlass_gemv_batched_strided_kern_impls.py | |||||
| // ignore warning of cutlass | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
| #pragma GCC diagnostic ignored "-Wstrict-aliasing" | |||||
| #include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" | |||||
| using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>; | |||||
| using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>; | |||||
| using GemvKernel = cutlass::gemm::kernel::DefaultGemv< | |||||
| ThreadBlockShape, | |||||
| ThreadShape, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor, | |||||
| float, cutlass::layout::RowMajor>; | |||||
| template void megdnn::cuda::cutlass_wrapper:: | |||||
| cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( | |||||
| BatchedGemmCoord const& problem_size, | |||||
| const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, | |||||
| const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, | |||||
| typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, | |||||
| cudaStream_t stream); | |||||
| #pragma GCC diagnostic pop | |||||
| #endif | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/cuda/matrix_mul/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "cutlass/gemm/kernel/default_gemv.h" | |||||
| #include "cutlass/gemm/kernel/gemv_batched_strided.h" | |||||
| #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh" | |||||
| #include "src/cuda/query_blocksize.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace cutlass_wrapper; | |||||
// Host-side wrapper that launches a CUTLASS batched-strided GEMV kernel
// (cutlass::gemm::kernel::GemvBatchedStrided<GemvKernel>) on `stream`.
//
// Parameters:
//   problem_size    - (m, n, k, batch) coordinates of the batched GEMV.
//   d_A/lda/batch_stride_a - device pointer, leading dimension and per-batch
//                            stride of operand A (cast away const to build the
//                            CUTLASS TensorRef, which takes a mutable pointer).
//   d_B/ldb/batch_stride_b - same for operand B.
//   d_C/ldc/batch_stride_c - same for the output C.
//   stream          - CUDA stream the kernel is enqueued on.
//
// NOTE(review): layouts are constructed from a single leading dimension, so
// this presumably assumes the canonical row/column-major layouts configured in
// GemvKernel — confirm against the kernel definitions.
template <typename GemvKernel>
void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper(
                BatchedGemmCoord const& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a, const typename GemvKernel::ElementB* d_B,
                size_t ldb, size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream) {
    // Wrap the raw device pointers into CUTLASS TensorRefs; the layout objects
    // only carry the leading dimension (narrowed to int, as CUTLASS expects).
    typename GemvKernel::IteratorA::TensorRef tensor_a{
            const_cast<typename GemvKernel::ElementA*>(d_A),
            typename GemvKernel::LayoutA{static_cast<int>(lda)}};
    typename GemvKernel::IteratorB::TensorRef tensor_b{
            const_cast<typename GemvKernel::ElementB*>(d_B),
            typename GemvKernel::LayoutB{static_cast<int>(ldb)}};
    typename GemvKernel::IteratorCD::TensorRef tensor_c{
            d_C, typename GemvKernel::LayoutCD{static_cast<int>(ldc)}};
    // Thread-block shape along N and K, fixed at compile time by the kernel.
    static int constexpr kThreadsPerN = GemvKernel::Core::kThreadsPerN;
    static int constexpr kThreadsPerK = GemvKernel::Core::kThreadsPerK;
    // Take the address of the instantiated __global__ kernel through an
    // explicitly-typed function pointer so the launch signature is checked.
    void (*kern)(BatchedGemmCoord, typename GemvKernel::IteratorA::TensorRef,
                 typename GemvKernel::IteratorA::TensorRef::LongIndex,
                 typename GemvKernel::IteratorB::TensorRef,
                 typename GemvKernel::IteratorB::TensorRef::LongIndex,
                 typename GemvKernel::IteratorCD::TensorRef,
                 typename GemvKernel::IteratorCD::TensorRef::LongIndex);
    kern = cutlass::gemm::kernel::GemvBatchedStrided<GemvKernel>;
    // Occupancy-driven batching of thread blocks is currently disabled; the
    // launch shape below comes solely from the kernel's ThreadBlockShape.
    // int nr_threads = static_cast<int>(
    //         query_blocksize_for_kernel(reinterpret_cast<const void*>(kern)));
    // nr_threads = std::max(nr_threads, kThreadsPerN);
    // megdnn_assert(nr_threads % kThreadsPerN == 0);
    // int batch = nr_threads / kThreadsPerN;
    // batch = std::min(batch, problem_size.batch());
    // Ask the kernel's swizzler how many tiles cover the problem, and derive
    // the grid from that tiling (batch dimension of the tile is 1).
    auto tile_size = BatchedGemmCoord(GemvKernel::ThreadBlockShape::kM,
                                      GemvKernel::ThreadBlockShape::kN,
                                      GemvKernel::ThreadBlockShape::kK, 1);
    typename GemvKernel::ThreadBlockSwizzle swizzler;
    auto tiled_shape = swizzler.get_tiled_shape(problem_size, tile_size);
    dim3 grid = swizzler.get_grid_shape(tiled_shape);
    dim3 block(kThreadsPerN, kThreadsPerK, 1);
    int smem_size =
            int(sizeof(typename GemvKernel::ThreadBlockGemv::SharedStorage));
    // Stay under the 48KB static shared-memory limit; exceeding it would need
    // an explicit cudaFuncSetAttribute opt-in, which this wrapper does not do.
    megdnn_assert(smem_size < (48 << 10));
    kern<<<grid, block, smem_size, stream>>>(
            problem_size, tensor_a, batch_stride_a, tensor_b, batch_stride_b,
            tensor_c, batch_stride_c);
    // Surfaces any launch-configuration error immediately after enqueue.
    after_kernel_launch();
}
| // vim: syntax=cuda.doxygen | |||||
| @@ -41,8 +41,11 @@ public: | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| class AlgoBFloat16; | class AlgoBFloat16; | ||||
| #endif | #endif | ||||
| #if CUDA_VERSION >= 9020 | |||||
| class AlgoFloat32SIMT; | class AlgoFloat32SIMT; | ||||
| class AlgoFloat32SIMTSplitK; | class AlgoFloat32SIMTSplitK; | ||||
| class AlgoFloat32SIMTGemvBatchedStrided; | |||||
| #endif | |||||
| class AlgoPack; | class AlgoPack; | ||||
| static const AlgoPack& algo_pack() { | static const AlgoPack& algo_pack() { | ||||
| @@ -90,7 +90,7 @@ void test_multibatchsize( | |||||
| if (std::regex_match( | if (std::regex_match( | ||||
| i.name.c_str(), | i.name.c_str(), | ||||
| std::regex("(" + std::string(algo) + ")(.*)"))) { | std::regex("(" + std::string(algo) + ")(.*)"))) { | ||||
| opr_reference->execution_policy().algo = i; | |||||
| opr_reference->execution_policy().algo = i.desc; | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -119,7 +119,7 @@ void test_multibatchsize( | |||||
| if (std::regex_match( | if (std::regex_match( | ||||
| i.name.c_str(), | i.name.c_str(), | ||||
| std::regex("(" + std::string(algo) + ")(.*)"))) { | std::regex("(" + std::string(algo) + ")(.*)"))) { | ||||
| opr_reference->execution_policy().algo = i; | |||||
| opr_reference->execution_policy().algo = i.desc; | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -292,6 +292,30 @@ TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_MULTI_BATCHSIZE) { | |||||
| [](const matrix_mul::TestArg& arg) { return arg.k <= arg.n; }); | [](const matrix_mul::TestArg& arg) { return arg.k <= arg.n; }); | ||||
| } | } | ||||
// Runs the multi-batch-size consistency check on the CUTLASS float32 SIMT
// batched-strided GEMV algorithm whose name ends in "_128" (selected by name
// prefix inside test_multibatchsize), on fp32 inputs in DEFAULT format.
TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_128_MULTI_BATCHSIZE) {
    auto args = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
                        dtype::Float32(),
                        "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_128", args,
                        param::MatrixMul::Format::DEFAULT);
}
// Same multi-batch-size consistency check as the "_128" case, exercising the
// "_64" variant of the CUTLASS float32 SIMT batched-strided GEMV algorithm.
TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_64_MULTI_BATCHSIZE) {
    auto args = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
                        dtype::Float32(),
                        "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_64", args,
                        param::MatrixMul::Format::DEFAULT);
}
// Same multi-batch-size consistency check as the "_128" case, exercising the
// "_32" variant of the CUTLASS float32 SIMT batched-strided GEMV algorithm.
TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_32_MULTI_BATCHSIZE) {
    auto args = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
                        dtype::Float32(),
                        "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_32", args,
                        param::MatrixMul::Format::DEFAULT);
}
| #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ | #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \ | ||||
| cb(1, 64, 256, 8, 32, 64, 8); \ | cb(1, 64, 256, 8, 32, 64, 8); \ | ||||
| cb(2, 256, 64, 8, 64, 32, 8); \ | cb(2, 256, 64, 8, 64, 32, 8); \ | ||||