You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

matmul.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "matmul.h"
  15. namespace ncnn {
  16. MatMul::MatMul()
  17. {
  18. one_blob_only = false;
  19. support_inplace = false;
  20. }
  21. int MatMul::load_param(const ParamDict& pd)
  22. {
  23. transB = pd.get(0, 0);
  24. return 0;
  25. }
  26. static void transpose(const Mat& X, Mat& XT, const Option& opt)
  27. {
  28. const int w = X.w;
  29. const int h = X.h;
  30. const float* pX = X;
  31. float* pXT = XT;
  32. #pragma omp parallel for num_threads(opt.num_threads)
  33. for (int i = 0; i < w; i++)
  34. {
  35. float* ptr = pXT + i * h;
  36. for (int j = 0; j < h; j++)
  37. {
  38. ptr[j] = pX[j * w + i];
  39. }
  40. }
  41. }
  42. static void matmul_transb(const Mat& A, const Mat& B, Mat& top_blob, const Option& opt)
  43. {
  44. const int M = A.h;
  45. const int K = A.w; // assert A.w == B.w
  46. const int N = B.h;
  47. const float* pA = A;
  48. const float* pB = B;
  49. float* pOut = top_blob;
  50. #pragma omp parallel for num_threads(opt.num_threads)
  51. for (int i = 0; i < M; i++)
  52. {
  53. const float* ptrA = pA + i * K;
  54. float* outptr = pOut + i * N;
  55. for (int j = 0; j < N; j++)
  56. {
  57. const float* ptrB = pB + j * K;
  58. float sum = 0.f;
  59. for (int k = 0; k < K; k++)
  60. {
  61. sum += ptrA[k] * ptrB[k];
  62. }
  63. *outptr++ = sum;
  64. }
  65. }
  66. }
  67. int MatMul::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
  68. {
  69. const Mat& A = bottom_blobs[0];
  70. const Mat& B = bottom_blobs[1];
  71. Mat& top_blob = top_blobs[0];
  72. const int Adims = A.dims;
  73. const int Bdims = B.dims;
  74. const int max_ABdims = std::max(Adims, Bdims);
  75. const size_t elemsize = A.elemsize;
  76. if (Adims == 1 && Bdims == 1)
  77. {
  78. // dot product
  79. top_blob.create(1, elemsize, opt.blob_allocator);
  80. if (top_blob.empty())
  81. return -100;
  82. const int K = A.w; // assert A.w == B.w
  83. const float* ptrA = A;
  84. const float* ptrB = B;
  85. float sum = 0.f;
  86. for (int k = 0; k < K; k++)
  87. {
  88. sum += ptrA[k] * ptrB[k];
  89. }
  90. top_blob[0] = sum;
  91. }
  92. else if (Adims == 2 && Bdims == 2)
  93. {
  94. // matrix multiply
  95. const int M = A.h;
  96. const int N = transB == 0 ? B.w : B.h;
  97. top_blob.create(N, M, elemsize, opt.blob_allocator);
  98. if (top_blob.empty())
  99. return -100;
  100. Mat BT;
  101. if (transB == 0)
  102. {
  103. BT.create(B.h, B.w, elemsize, opt.workspace_allocator);
  104. if (BT.empty())
  105. return -100;
  106. transpose(B, BT, opt);
  107. }
  108. else
  109. {
  110. BT = B;
  111. }
  112. matmul_transb(A, BT, top_blob, opt);
  113. }
  114. else if (Adims == 1 && Bdims == 2)
  115. {
  116. // matrix multiply
  117. const int N = transB == 0 ? B.w : B.h;
  118. Mat top_blob1(N, 1, elemsize, opt.blob_allocator);
  119. if (top_blob1.empty())
  120. return -100;
  121. Mat A1 = A.reshape(A.w, 1);
  122. Mat BT;
  123. if (transB == 0)
  124. {
  125. BT.create(B.h, B.w, elemsize, opt.workspace_allocator);
  126. if (BT.empty())
  127. return -100;
  128. transpose(B, BT, opt);
  129. }
  130. else
  131. {
  132. BT = B;
  133. }
  134. matmul_transb(A1, BT, top_blob1, opt);
  135. top_blob = top_blob1.reshape(N);
  136. }
  137. else if (Adims == 2 && Bdims == 1)
  138. {
  139. // matrix multiply
  140. const int M = A.h;
  141. Mat top_blob1(1, M, elemsize, opt.blob_allocator);
  142. if (top_blob1.empty())
  143. return -100;
  144. Mat BT = B.reshape(B.w, 1);
  145. matmul_transb(A, BT, top_blob1, opt);
  146. top_blob = top_blob1.reshape(M);
  147. }
  148. else if (Adims == 1 && Bdims > 2)
  149. {
  150. // batched matrix multiply
  151. const int N = transB == 0 ? B.w : B.h;
  152. const int batch_size = B.d * B.c;
  153. Mat top_blob1(N, 1, batch_size, elemsize, opt.blob_allocator);
  154. if (top_blob1.empty())
  155. return -100;
  156. Mat A1 = A.reshape(A.w, 1);
  157. Mat B1 = B.reshape(B.w, B.h, batch_size);
  158. for (int p = 0; p < batch_size; p++)
  159. {
  160. Mat BT;
  161. if (transB == 0)
  162. {
  163. BT.create(B.h, B.w, elemsize, opt.workspace_allocator);
  164. if (BT.empty())
  165. return -100;
  166. transpose(B1.channel(p), BT, opt);
  167. }
  168. else
  169. {
  170. BT = B1.channel(p);
  171. }
  172. Mat top_blob1_p = top_blob1.channel(p);
  173. matmul_transb(A1, BT, top_blob1_p, opt);
  174. }
  175. if (Bdims == 3)
  176. top_blob = top_blob1.reshape(N, B.d * B.c);
  177. else
  178. top_blob = top_blob1.reshape(N, B.d, B.c);
  179. }
  180. else if (Adims > 2 && Bdims == 1)
  181. {
  182. // batched matrix multiply
  183. const int M = A.h;
  184. const int batch_size = A.d * A.c;
  185. Mat top_blob1(1, M, batch_size, elemsize, opt.blob_allocator);
  186. if (top_blob1.empty())
  187. return -100;
  188. Mat A1 = A.reshape(A.w, A.h, batch_size);
  189. Mat BT = B.reshape(B.w, 1);
  190. for (int p = 0; p < batch_size; p++)
  191. {
  192. Mat top_blob1_p = top_blob1.channel(p);
  193. matmul_transb(A1.channel(p), BT, top_blob1_p, opt);
  194. }
  195. if (Adims == 3)
  196. top_blob = top_blob1.reshape(M, A.d * A.c);
  197. else
  198. top_blob = top_blob1.reshape(M, A.d, A.c);
  199. }
  200. else if (max_ABdims == 3)
  201. {
  202. Mat A1 = Adims == 2 ? A.reshape(A.w, A.h, 1) : A;
  203. Mat B1 = Bdims == 2 ? B.reshape(B.w, B.h, 1) : B;
  204. const int M = A1.h;
  205. const int N = transB == 0 ? B1.w : B1.h;
  206. const int batch_size = std::max(A1.c, B1.c);
  207. top_blob.create(N, M, batch_size, elemsize, opt.blob_allocator);
  208. if (top_blob.empty())
  209. return -100;
  210. Mat BT0;
  211. if (B1.c == 1)
  212. {
  213. if (transB == 0)
  214. {
  215. BT0.create(B1.h, B1.w, elemsize, opt.workspace_allocator);
  216. if (BT0.empty())
  217. return -100;
  218. transpose(B1.channel(0), BT0, opt);
  219. }
  220. else
  221. {
  222. BT0 = B1.channel(0);
  223. }
  224. }
  225. for (int p = 0; p < batch_size; p++)
  226. {
  227. int Ap = A1.c == 1 ? 0 : p;
  228. int Bp = B1.c == 1 ? 0 : p;
  229. Mat BT;
  230. if (B1.c == 1)
  231. {
  232. BT = BT0;
  233. }
  234. else
  235. {
  236. if (transB == 0)
  237. {
  238. BT.create(B1.h, B1.w, elemsize, opt.workspace_allocator);
  239. if (BT.empty())
  240. return -100;
  241. transpose(B1.channel(Bp), BT, opt);
  242. }
  243. else
  244. {
  245. BT = B1.channel(Bp);
  246. }
  247. }
  248. Mat top_blob_p = top_blob.channel(p);
  249. matmul_transb(A1.channel(Ap), BT, top_blob_p, opt);
  250. }
  251. }
  252. else if (max_ABdims == 4)
  253. {
  254. Mat A1 = Adims == 3 ? A.reshape(A.w, A.h, A.c, 1) : A;
  255. Mat B1 = Bdims == 3 ? B.reshape(B.w, B.h, B.c, 1) : B;
  256. const int M = A1.h;
  257. const int N = transB == 0 ? B1.w : B1.h;
  258. const int batch_size_d = std::max(A1.d, B1.d);
  259. const int batch_size_c = std::max(A1.c, B1.c);
  260. top_blob.create(N, M, batch_size_d, batch_size_c, elemsize, opt.blob_allocator);
  261. if (top_blob.empty())
  262. return -100;
  263. Mat BT00;
  264. if (B1.d == 1 && B1.c == 1)
  265. {
  266. if (transB == 0)
  267. {
  268. BT00.create(B1.h, B1.w, elemsize, opt.workspace_allocator);
  269. if (BT00.empty())
  270. return -100;
  271. transpose(B1.channel(0).depth(0), BT00, opt);
  272. }
  273. else
  274. {
  275. BT00 = B1.channel(0).depth(0);
  276. }
  277. }
  278. for (int p = 0; p < batch_size_c; p++)
  279. {
  280. int Ap = A1.c == 1 ? 0 : p;
  281. int Bp = B1.c == 1 ? 0 : p;
  282. Mat BT0x;
  283. if (B1.d == 1 && B1.c != 1)
  284. {
  285. if (transB == 0)
  286. {
  287. BT0x.create(B1.h, B1.w, elemsize, opt.workspace_allocator);
  288. if (BT0x.empty())
  289. return -100;
  290. transpose(B1.channel(Bp).depth(0), BT0x, opt);
  291. }
  292. else
  293. {
  294. BT0x = B1.channel(Bp).depth(0);
  295. }
  296. }
  297. for (int q = 0; q < batch_size_d; q++)
  298. {
  299. int Ad = A1.d == 1 ? 0 : q;
  300. int Bd = B1.d == 1 ? 0 : q;
  301. Mat BT;
  302. if (B1.d == 1 && B1.c == 1)
  303. {
  304. BT = BT00;
  305. }
  306. else if (B1.d == 1 && B1.c != 1)
  307. {
  308. BT = BT0x;
  309. }
  310. else
  311. {
  312. if (transB == 0)
  313. {
  314. BT.create(B1.h, B1.w, elemsize, opt.workspace_allocator);
  315. if (BT.empty())
  316. return -100;
  317. transpose(B1.channel(Bp).depth(Bd), BT, opt);
  318. }
  319. else
  320. {
  321. BT = B1.channel(Bp).depth(Bd);
  322. }
  323. }
  324. Mat top_blob_p_q = top_blob.channel(p).depth(q);
  325. matmul_transb(A1.channel(Ap).depth(Ad), BT, top_blob_p_q, opt);
  326. }
  327. }
  328. }
  329. else
  330. {
  331. NCNN_LOGE("impossible matmul %d %d", Adims, Bdims);
  332. return -1;
  333. }
  334. return 0;
  335. }
  336. } // namespace ncnn