You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

matrix_mul.cpp 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. /**
  2. * \file dnn/src/common/matrix_mul.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs.h"
  12. #include "src/common/utils.h"
  13. namespace megdnn {
  14. void MatrixMulForward::deduce_dtype(DType A, DType B, DType& C) {
  15. // Expect that the user specifies output dtype (C), we then do sanity
  16. // check on the dtype supplied by the user. C_dtype and C_dtype2 are the
  17. // expected dtypes. If the user does not specify an output dtype by setting
  18. // C = {}, we deduce one (C_dtype) and return it to the user.
  19. DType C_candi, C_candi2;
  20. if (A.category() == DTypeCategory::FLOAT) {
  21. C_candi = A;
  22. } else if (A.enumv() == DTypeEnum::Int8) {
  23. C_candi = dtype::Int32();
  24. C_candi2 = dtype::Int16();
  25. } else if (A.enumv() == DTypeEnum::Int16) {
  26. C_candi = dtype::Int32();
  27. } else if (A.enumv() == DTypeEnum::QuantizedS8) {
  28. C_candi = dtype::QuantizedS32(mul_scale(A, B));
  29. } else if (A.enumv() == DTypeEnum::Quantized8Asymm) {
  30. C_candi = dtype::QuantizedS32(mul_scale(A, B));
  31. } else if (A.enumv() == DTypeEnum::Quantized4Asymm) {
  32. C_candi = dtype::QuantizedS32(mul_scale(A, B));
  33. } else if (A.enumv() == DTypeEnum::QuantizedS4) {
  34. C_candi = dtype::QuantizedS16(mul_scale(A, B));
  35. }
  36. if (!C.valid()) {
  37. C = C_candi;
  38. }
  39. megdnn_assert(
  40. C.valid() && (C == C_candi || C == C_candi2),
  41. "unsupported MatMul(%s, %s) -> %s", A.name(), B.name(), C.name());
  42. }
  43. void MatrixMulForward::deduce_layout(
  44. const TensorLayout& A, const TensorLayout& B, TensorLayout& C) {
  45. megdnn_assert(
  46. A.dtype.enumv() == B.dtype.enumv(),
  47. "matmul input should be of same dtype, got %s and %s", A.dtype.name(),
  48. B.dtype.name());
  49. deduce_dtype(A.dtype, B.dtype, C.dtype);
  50. size_t A0, A1, B0, B1;
  51. if (param().format == param::MatrixMul::Format::DEFAULT) {
  52. megdnn_assert(
  53. A.ndim == 2 && B.ndim == 2,
  54. "matmul requires input to be 2-dimensional; get: %s %s",
  55. A.TensorShape::to_string().c_str(), B.TensorShape::to_string().c_str());
  56. A0 = A.shape[0];
  57. A1 = A.shape[1];
  58. B0 = B.shape[0];
  59. B1 = B.shape[1];
  60. if (m_param.transposeA)
  61. std::swap(A0, A1);
  62. if (m_param.transposeB)
  63. std::swap(B0, B1);
  64. megdnn_assert(
  65. A1 == B0,
  66. "shape mismatch in matmal: (transposed) A is (%zu,%zu), "
  67. "(transposed) B is (%zu,%zu)",
  68. A0, A1, B0, B1);
  69. C = TensorLayout(TensorShape({A0, B1}), C.dtype);
  70. } else {
  71. auto do_deduce = [&](size_t pack_size) {
  72. megdnn_assert(
  73. A.ndim == 4 && B.ndim == 3,
  74. "matmul requires input dimension to be A(4), B(3); "
  75. "get: %s %s",
  76. A.TensorShape::to_string().c_str(),
  77. B.TensorShape::to_string().c_str());
  78. A0 = A.shape[0];
  79. A1 = A.shape[1];
  80. B0 = B.shape[0];
  81. B1 = B.shape[1];
  82. if (m_param.transposeA)
  83. std::swap(A0, A1);
  84. if (m_param.transposeB)
  85. std::swap(B0, B1);
  86. megdnn_assert(
  87. A1 == B0,
  88. "shape mismatch in matmal: (transposed) A is "
  89. "(%zu,%zu,4,4), "
  90. "(transposed) B is (%zu,%zu,4)",
  91. A0, A1, B0, B1);
  92. C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype);
  93. };
  94. do_deduce(pack_size(param().format));
  95. }
  96. }
  97. void MatrixMulForward::check_exec(
  98. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
  99. size_t workspace_in_bytes) {
  100. auto errmsg = [&]() {
  101. std::string msg;
  102. msg.append("A=");
  103. msg.append(A.to_string());
  104. msg.append(", B=");
  105. msg.append(B.to_string());
  106. msg.append(", C=");
  107. msg.append(C.to_string());
  108. msg.append(", transposeA=");
  109. msg.append(std::to_string(param().transposeA));
  110. msg.append(", transposeB=");
  111. msg.append(std::to_string(param().transposeB));
  112. return msg;
  113. };
  114. MEGDNN_MARK_USED_VAR(errmsg);
  115. if (param().format == param::MatrixMul::Format::DEFAULT) {
  116. megdnn_assert_eq_size_t(A.ndim, 2_z);
  117. megdnn_assert_eq_size_t(B.ndim, 2_z);
  118. megdnn_assert_eq_size_t(C.ndim, 2_z);
  119. megdnn_assert(A.stride[1] == 1);
  120. megdnn_assert(A.stride[0] >= static_cast<ptrdiff_t>(A.shape[1]));
  121. megdnn_assert(B.stride[1] == 1);
  122. megdnn_assert(B.stride[0] >= static_cast<ptrdiff_t>(B.shape[1]));
  123. megdnn_assert(C.stride[1] == 1);
  124. megdnn_assert(C.stride[0] >= static_cast<ptrdiff_t>(C.shape[1]));
  125. size_t A0, A1, B0, B1, C0, C1;
  126. A0 = A.shape[0];
  127. A1 = A.shape[1];
  128. B0 = B.shape[0];
  129. B1 = B.shape[1];
  130. C0 = C.shape[0];
  131. C1 = C.shape[1];
  132. if (m_param.transposeA)
  133. std::swap(A0, A1);
  134. if (m_param.transposeB)
  135. std::swap(B0, B1);
  136. megdnn_assert(A0 == C0, "%s", errmsg().c_str());
  137. megdnn_assert(B1 == C1, "%s", errmsg().c_str());
  138. megdnn_assert(A1 == B0, "%s", errmsg().c_str());
  139. } else {
  140. megdnn_assert_eq_size_t(A.ndim, 4_z);
  141. megdnn_assert_eq_size_t(B.ndim, 3_z);
  142. megdnn_assert_eq_size_t(C.ndim, 3_z);
  143. megdnn_assert_contiguous(A);
  144. megdnn_assert_contiguous(B);
  145. megdnn_assert_contiguous(C);
  146. size_t A0, A1, B0, B1, C0, C1;
  147. A0 = A.shape[0];
  148. A1 = A.shape[1];
  149. B0 = B.shape[0];
  150. B1 = B.shape[1];
  151. C0 = C.shape[0];
  152. C1 = C.shape[1];
  153. if (m_param.transposeA)
  154. std::swap(A0, A1);
  155. if (m_param.transposeB)
  156. std::swap(B0, B1);
  157. megdnn_assert(A0 == C0, "%s", errmsg().c_str());
  158. megdnn_assert(B1 == C1, "%s", errmsg().c_str());
  159. megdnn_assert(A1 == B0, "%s", errmsg().c_str());
  160. }
  161. megdnn_assert(A.dtype.enumv() == B.dtype.enumv());
  162. if (A.dtype.category() == DTypeCategory::FLOAT) {
  163. megdnn_assert(A.dtype == C.dtype);
  164. } else if (A.dtype == dtype::Int8()) {
  165. megdnn_assert(C.dtype == dtype::Int16() || C.dtype == dtype::Int32());
  166. } else if (
  167. A.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  168. A.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
  169. A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
  170. megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS32);
  171. } else if (A.dtype.enumv() == DTypeEnum::QuantizedS4) {
  172. megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS16);
  173. }
  174. megdnn_assert(
  175. param().compute_mode != Param::ComputeMode::FLOAT32 DNN_INC_FLOAT16(
  176. || A.dtype == dtype::Float16() ||
  177. A.dtype == dtype::BFloat16()),
  178. "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
  179. "input / output.");
  180. auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C);
  181. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  182. }
  183. size_t MatrixMulForward::pack_size(const Param::Format format) {
  184. switch (format) {
  185. case Param::Format::DEFAULT:
  186. return 1;
  187. case Param::Format::MK4:
  188. return 4;
  189. case Param::Format::MK4_DOT:
  190. return 4;
  191. case Param::Format::MK8:
  192. return 8;
  193. default:
  194. megdnn_throw("Unknown matmul format.");
  195. }
  196. }
  197. } // namespace megdnn
  198. // vim: syntax=cpp.doxygen