
algos.cpp

/**
 * \file dnn/src/fallback/matrix_mul/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied.
 */
#include "src/fallback/matrix_mul/algos.h"
#include "megdnn/opr_param_defs.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/fallback/matrix_mul/gemv.h"
#include "src/fallback/matrix_mul/generic_strategy.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fb_matmul_f32_kern)
MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)
MIDOUT_DECL(megdnn_fb_matmul_naive)

using namespace megdnn;
using namespace fallback;
/* ===================== F32 8x12x1 algo ===================== */
namespace {
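
// Runs the generic 8x12 SGEMM strategy through GemmInterleaved, which packs
// tiles of A and B into the caller-provided workspace before invoking the
// inner kernel.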
void f32_8x12x1_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern, void) {
        size_t M = kern_param.M, N = kern_param.N, K = kern_param.K;
        matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_param.A_type,
                                              kern_param.B_type,
                                              kern_param.C_type);
        matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
                M, N, K, kern_param.trA, kern_param.trB, strategy)
                .execute(kern_param.A<float>(), kern_param.LDA,
                         kern_param.B<float>(), kern_param.LDB,
                         kern_param.C<float>(), kern_param.LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
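
// Reference kernel: forwards to naive::dispatch_ta_tb for every trA/trB
// combination. For the blocked MK4/MK4_DOT and MK8 layouts, M and K are
// expressed in units of packed blocks, hence the division by pack_size.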
void kern_naive(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_fb_matmul_naive, void) {
        size_t M = kern_param.M, N = kern_param.N, K = kern_param.K;
        size_t LDA = kern_param.LDA, LDB = kern_param.LDB,
               LDC = kern_param.LDC;
        auto get_pack_size = [kern_param]() -> size_t {
            switch (kern_param.format) {
                case param::MatrixMul::Format::MK4:
                case param::MatrixMul::Format::MK4_DOT:
                    return 4_z;
                case param::MatrixMul::Format::MK8:
                    return 8_z;
                default:
                    return 1_z;
            }
        };
        size_t pack_size = get_pack_size();
        megdnn_assert(
                (M % pack_size == 0 && K % pack_size == 0),
                "M and K must be multiples of pack_size. M: %zu K: %zu "
                "pack_size: %zu",
                M, K, pack_size);

#define DISPATCH(TA, TB)                                                   \
    if (kern_param.trA == TA && kern_param.trB == TB) {                    \
        naive::dispatch_ta_tb<TA, TB>(                                     \
                kern_param.A_ptr, kern_param.B_ptr, kern_param.C_ptr,      \
                kern_param.workspace_ptr, M / pack_size, N, K / pack_size, \
                LDA, LDB, LDC, kern_param.A_type, kern_param.B_type,       \
                kern_param.C_type, kern_param.format,                      \
                kern_param.compute_mode);                                  \
        return;                                                            \
    }
        DISPATCH(true, true);
        DISPATCH(true, false);
        DISPATCH(false, true);
        DISPATCH(false, false);
#undef DISPATCH
        megdnn_assert_internal(0);
    }
    MIDOUT_END();
}

}  // anonymous namespace

////////////////////// AlgoF32K8x12x1 ///////////////////////////

bool MatrixMulImpl::AlgoF32K8x12x1::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.compute_mode ==
                   param::MatrixMul::ComputeMode::DEFAULT &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.B_type == kern_size_param.A_type &&
           kern_size_param.C_type == kern_size_param.A_type &&
           kern_size_param.A_type == dtype::Float32{};
}

size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern,
                 midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type,
                                              kern_size_param.B_type,
                                              kern_size_param.C_type);
        return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
                       M, N, K, kern_size_param.trA, kern_size_param.trB,
                       strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
        const KernSizeParam&) const {
    return f32_8x12x1_kern;
}
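
// As the macro name indicates, this registers the 8x12 strategy so the
// im2col convolution implementation can reuse it as its inner GEMM.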
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern,
                                     5, matmul::fallback::sgemm_8x12, float,
                                     float, AlgoDataType::FLOAT32, DEFAULT);

/* ===================== gemv algo ===================== */
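// Matrix-vector-like algorithm: usable() excludes the int16x16x32 case, and
// preferred() only claims shapes with M <= 2 for non-float types.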
bool MatrixMulImpl::AlgoGemv::usable(
        const KernSizeParam& kern_size_param) const {
    return !kern_size_param.trA && !kern_size_param.trB &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode ==
                   param::MatrixMul::ComputeMode::DEFAULT &&
           !((kern_size_param.A_type.enumv() ==
              kern_size_param.B_type.enumv()) &&
             (kern_size_param.A_type.enumv() == DTypeEnum::Int16) &&
             (kern_size_param.C_type.enumv() == DTypeEnum::Int32));
}

bool MatrixMulImpl::AlgoGemv::preferred(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.M <= 2 &&
           kern_size_param.A_type.category() != DTypeCategory::FLOAT;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoGemv::get_kern(
        const KernSizeParam& kern_size_param) const {
#define DISPATCH(A, C, func, _midout_iv)                               \
    if (kern_size_param.A_type.enumv() == DTypeEnum::A &&              \
        kern_size_param.B_type.enumv() == DTypeEnum::A &&              \
        kern_size_param.C_type.enumv() == DTypeEnum::C &&              \
        kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && \
        kern_size_param.format == param::MatrixMul::Format::DEFAULT) { \
        MIDOUT_BEGIN(megdnn_fb_matmul_f32_gemm_gemv_like,              \
                     midout_iv(_midout_iv)) {                          \
            return func;                                               \
        }                                                              \
        MIDOUT_END();                                                  \
    }
    DISPATCH(Float32, Float32, (gemm_gemv_like<dt_float32, dt_float32>), 0);
    MEGDNN_INC_FLOAT16(DISPATCH(Float16, Float16,
                                (gemm_gemv_like<dt_float16, dt_float16>), 1));
    DISPATCH(Int8, Int16, (gemm_gemv_like<dt_int8, dt_int16>), 2);
    DISPATCH(Quantized8Asymm, QuantizedS32,
             (gemm_gemv_like<dt_uint8, dt_int32, true>), 3);
    if (can_be_treated_as_int8x8x32(kern_size_param)) {
        MIDOUT_BEGIN(megdnn_fb_matmul_f32_gemm_gemv_like, midout_iv(4)) {
            return gemm_gemv_like<dt_int8, dt_int32>;
        }
        MIDOUT_END();
    }
#undef DISPATCH
    megdnn_assert(0);
}

/* ===================== naive algo ===================== */
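// Catch-all fallback: always usable, never preferred. A workspace is only
// required for the 4-bit quantized types (see get_workspace below).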
bool MatrixMulImpl::AlgoNaive::usable(const KernSizeParam&) const {
    return true;
}

bool MatrixMulImpl::AlgoNaive::preferred(const KernSizeParam&) const {
    return false;
}

size_t MatrixMulImpl::AlgoNaive::get_workspace(
        const KernSizeParam& kern_param) const {
    MIDOUT_BEGIN(
            megdnn_fb_matmul_naive,
            midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) {
        if (kern_param.A_type.enumv() == DTypeEnum::Quantized4Asymm ||
            kern_param.A_type.enumv() == DTypeEnum::QuantizedS4) {
            size_t ret = 0;
            if (kern_param.trA) {
                ret += kern_param.LDA * kern_param.K;
            } else {
                ret += kern_param.LDA * kern_param.M;
            }
            if (kern_param.trB) {
                ret += kern_param.LDB * kern_param.N;
            } else {
                ret += kern_param.LDB * kern_param.K;
            }
            return ret;
        }
        return 0;
    }
    MIDOUT_END();
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(
        const KernSizeParam&) const {
    return kern_naive;
}

// vim: syntax=cpp.doxygen

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
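
As a quick sanity check after installation, a snippet like the one below reports whether the bundled CUDA runtime can actually reach a GPU. This is a minimal sketch: megengine.is_cuda_available() is the public helper in recent MegEngine releases, but verify the name against the version you install.

    # Minimal post-install check: prints the installed version and whether
    # a CUDA-capable GPU (with a working driver) is visible to MegEngine.
    import megengine

    print(megengine.__version__)
    print("CUDA available:", megengine.is_cuda_available())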