You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

simpleomp.cpp 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "platform.h"
  15. #if NCNN_SIMPLEOMP
  16. #include "simpleomp.h"
  17. #include "cpu.h" // ncnn::get_cpu_count()
  18. #include <stdio.h>
  19. #include <stdlib.h>
  20. #include <string.h>
  21. #include <stdint.h>
  22. #include <stdarg.h>
  // Function-pointer types for the outlined parallel-region body ("microtask").
  // kmpc_micro is the variadic form received by __kmpc_fork_call();
  // kmpc_micro_N is the same entry point viewed as taking exactly N extra
  // void* arguments, which is what kmp_invoke_microtask() casts to before calling.
  extern "C" {
  typedef void (*kmpc_micro)(int32_t* gtid, int32_t* tid, ...);
  typedef void (*kmpc_micro_0)(int32_t* gtid, int32_t* tid);
  typedef void (*kmpc_micro_1)(int32_t* gtid, int32_t* tid, void*);
  typedef void (*kmpc_micro_2)(int32_t* gtid, int32_t* tid, void*, void*);
  typedef void (*kmpc_micro_3)(int32_t* gtid, int32_t* tid, void*, void*, void*);
  typedef void (*kmpc_micro_4)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*);
  typedef void (*kmpc_micro_5)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_6)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_7)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_8)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_9)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_10)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_11)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_12)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_13)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_14)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
  typedef void (*kmpc_micro_15)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*);
  } // extern "C"
  // Forward declarations of two file-local helpers that ncnn::KMPGlobal refers
  // to (pthread_once callback and worker-thread entry point) before their
  // definitions appear further down in this file.
  extern "C" {
  static void init_g_kmp_global();
  static void* kmp_threadfunc(void* args);
  } // extern "C"
  48. namespace ncnn {
  49. class KMPTask
  50. {
  51. public:
  52. // per-team
  53. kmpc_micro fn;
  54. int argc;
  55. void** argv;
  56. int num_threads;
  57. // per-task
  58. int thread_num;
  59. // finish status
  60. int* num_threads_to_wait;
  61. Mutex* finish_lock;
  62. ConditionVariable* finish_condition;
  63. };
  64. class KMPTaskQueue
  65. {
  66. public:
  67. KMPTaskQueue(int _max_size)
  68. {
  69. max_size = _max_size;
  70. tasks = new KMPTask*[max_size];
  71. size = 0;
  72. front = 0;
  73. back = 0;
  74. }
  75. ~KMPTaskQueue()
  76. {
  77. delete[] tasks;
  78. }
  79. void dispatch(KMPTask* v, int n)
  80. {
  81. lock.lock();
  82. if (size + n > max_size)
  83. {
  84. lock.unlock();
  85. for (int i = 0; i < n; i++)
  86. {
  87. put(&v[i]);
  88. }
  89. return;
  90. }
  91. for (int i = 0; i < n; i++)
  92. {
  93. tasks[back] = &v[i];
  94. back++;
  95. if (back == max_size)
  96. back = 0;
  97. }
  98. size += n;
  99. lock.unlock();
  100. condition.signal();
  101. }
  102. void put(KMPTask* v)
  103. {
  104. lock.lock();
  105. while (size >= max_size)
  106. {
  107. condition.wait(lock);
  108. }
  109. tasks[back] = v;
  110. back++;
  111. if (back == max_size)
  112. back = 0;
  113. size++;
  114. lock.unlock();
  115. condition.signal();
  116. }
  117. void get(KMPTask*& v)
  118. {
  119. lock.lock();
  120. while (size == 0)
  121. {
  122. condition.wait(lock);
  123. }
  124. v = tasks[front];
  125. front++;
  126. if (front == max_size)
  127. front = 0;
  128. size--;
  129. lock.unlock();
  130. condition.signal();
  131. }
  132. private:
  133. Mutex lock;
  134. ConditionVariable condition;
  135. // ring buffer queue
  136. int max_size;
  137. KMPTask** tasks;
  138. int size;
  139. int front;
  140. int back;
  141. };
  142. class KMPGlobal
  143. {
  144. public:
  145. KMPGlobal()
  146. {
  147. kmp_max_threads = 0;
  148. kmp_threads = 0;
  149. kmp_threads_tid = 0;
  150. kmp_task_queue = 0;
  151. }
  152. ~KMPGlobal()
  153. {
  154. deinit();
  155. }
  156. void try_init()
  157. {
  158. pthread_once(&is_initialized, init_g_kmp_global);
  159. }
  160. public:
  161. static pthread_once_t is_initialized;
  162. void init()
  163. {
  164. // NCNN_LOGE("KMPGlobal init");
  165. kmp_max_threads = ncnn::get_cpu_count();
  166. kmp_task_queue = new ncnn::KMPTaskQueue(std::max(kmp_max_threads * 4, 16));
  167. if (kmp_max_threads > 1)
  168. {
  169. kmp_threads = new ncnn::Thread*[kmp_max_threads - 1];
  170. kmp_threads_tid = new int[kmp_max_threads - 1];
  171. for (int i = 0; i < kmp_max_threads - 1; i++)
  172. {
  173. kmp_threads_tid[i] = i + 1;
  174. kmp_threads[i] = new ncnn::Thread(kmp_threadfunc, (void*)&kmp_threads_tid[i]);
  175. }
  176. }
  177. }
  178. void deinit()
  179. {
  180. // NCNN_LOGE("KMPGlobal deinit");
  181. if (kmp_max_threads > 1)
  182. {
  183. // TODO portable stack allocation
  184. ncnn::KMPTask* tasks = (ncnn::KMPTask*)alloca((kmp_max_threads - 1) * sizeof(ncnn::KMPTask));
  185. for (int i = 0; i < kmp_max_threads - 1; i++)
  186. {
  187. tasks[i].fn = 0;
  188. tasks[i].argc = 0;
  189. tasks[i].argv = (void**)0;
  190. tasks[i].num_threads = kmp_max_threads;
  191. tasks[i].thread_num = i + 1;
  192. tasks[i].num_threads_to_wait = 0;
  193. tasks[i].finish_lock = 0;
  194. tasks[i].finish_condition = 0;
  195. }
  196. // dispatch 1 ~ kmp_max_threads
  197. kmp_task_queue->dispatch(tasks, kmp_max_threads - 1);
  198. for (int i = 0; i < kmp_max_threads - 1; i++)
  199. {
  200. #ifndef __EMSCRIPTEN__
  201. // FIXME emscripten complains
  202. // pthread_join attempted on thread 12345678,
  203. // which does not point to a valid thread, or does not exist anymore!
  204. kmp_threads[i]->join();
  205. #endif
  206. delete kmp_threads[i];
  207. }
  208. delete[] kmp_threads;
  209. delete[] kmp_threads_tid;
  210. }
  211. delete kmp_task_queue;
  212. }
  213. public:
  214. int kmp_max_threads;
  215. ncnn::Thread** kmp_threads;
  216. int* kmp_threads_tid;
  217. ncnn::KMPTaskQueue* kmp_task_queue;
  218. };
  219. } // namespace ncnn
  220. pthread_once_t ncnn::KMPGlobal::is_initialized = PTHREAD_ONCE_INIT;
  221. static ncnn::KMPGlobal g_kmp_global;
  222. static ncnn::ThreadLocalStorage tls_num_threads;
  223. static ncnn::ThreadLocalStorage tls_thread_num;
  224. static void init_g_kmp_global()
  225. {
  226. g_kmp_global.init();
  227. }
  228. #ifdef __cplusplus
  229. extern "C" {
  230. #endif
  231. int omp_get_max_threads()
  232. {
  233. return ncnn::get_cpu_count();
  234. }
  // This runtime always reports dynamic adjustment as enabled.
  int omp_get_dynamic()
  {
      const int always_dynamic = 1;
      return always_dynamic;
  }
  // Accepted for API compatibility only: the request is dropped because
  // omp_get_dynamic() reports "dynamic" unconditionally.
  void omp_set_dynamic(int dynamic)
  {
      (void)dynamic;
  }
  243. void omp_set_num_threads(int num_threads)
  244. {
  245. tls_num_threads.set(reinterpret_cast<void*>((size_t)std::max(num_threads, 1)));
  246. }
  247. int omp_get_num_threads()
  248. {
  249. return std::max((int)reinterpret_cast<size_t>(tls_num_threads.get()), 1);
  250. }
  251. int omp_get_thread_num()
  252. {
  253. return (int)reinterpret_cast<size_t>(tls_thread_num.get());
  254. }
  // The block time reported by this runtime is fixed at 0.
  int kmp_get_blocktime()
  {
      const int reported_blocktime = 0;
      return reported_blocktime;
  }
  // Accepted for API compatibility only: the value is dropped and
  // kmp_get_blocktime() keeps returning 0.
  void kmp_set_blocktime(int blocktime)
  {
      (void)blocktime;
  }
  263. static int kmp_invoke_microtask(kmpc_micro fn, int gtid, int tid, int argc, void** argv)
  264. {
  265. // fprintf(stderr, "__kmp_invoke_microtask #%lu %d %d %d\n", gettid(), gtid, tid, argc);
  266. switch (argc)
  267. {
  268. case 0:
  269. (*(kmpc_micro_0)fn)(&gtid, &tid);
  270. break;
  271. case 1:
  272. (*(kmpc_micro_1)fn)(&gtid, &tid, argv[0]);
  273. break;
  274. case 2:
  275. (*(kmpc_micro_2)fn)(&gtid, &tid, argv[0], argv[1]);
  276. break;
  277. case 3:
  278. (*(kmpc_micro_3)fn)(&gtid, &tid, argv[0], argv[1], argv[2]);
  279. break;
  280. case 4:
  281. (*(kmpc_micro_4)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3]);
  282. break;
  283. case 5:
  284. (*(kmpc_micro_5)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4]);
  285. break;
  286. case 6:
  287. (*(kmpc_micro_6)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5]);
  288. break;
  289. case 7:
  290. (*(kmpc_micro_7)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]);
  291. break;
  292. case 8:
  293. (*(kmpc_micro_8)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7]);
  294. break;
  295. case 9:
  296. (*(kmpc_micro_9)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8]);
  297. break;
  298. case 10:
  299. (*(kmpc_micro_10)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9]);
  300. break;
  301. case 11:
  302. (*(kmpc_micro_11)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10]);
  303. break;
  304. case 12:
  305. (*(kmpc_micro_12)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11]);
  306. break;
  307. case 13:
  308. (*(kmpc_micro_13)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11], argv[12]);
  309. break;
  310. case 14:
  311. (*(kmpc_micro_14)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11], argv[12], argv[13]);
  312. break;
  313. case 15:
  314. (*(kmpc_micro_15)fn)(&gtid, &tid, argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6], argv[7], argv[8], argv[9], argv[10], argv[11], argv[12], argv[13], argv[14]);
  315. break;
  316. default:
  317. // assert never reach here
  318. break;
  319. }
  320. return 0;
  321. }
  322. static void* kmp_threadfunc(void* args)
  323. {
  324. int tid = *(int*)args;
  325. for (;;)
  326. {
  327. ncnn::KMPTask* task;
  328. g_kmp_global.kmp_task_queue->get(task);
  329. // fprintf(stderr, "get %d\n", tid);
  330. if (!task->fn)
  331. break;
  332. tls_num_threads.set(reinterpret_cast<void*>((size_t)task->num_threads));
  333. tls_thread_num.set(reinterpret_cast<void*>((size_t)task->thread_num));
  334. kmp_invoke_microtask(task->fn, task->thread_num, tid, task->argc, task->argv);
  335. // update finished
  336. {
  337. task->finish_lock->lock();
  338. *task->num_threads_to_wait = *task->num_threads_to_wait - 1;
  339. if (*task->num_threads_to_wait == 0)
  340. {
  341. task->finish_condition->signal();
  342. }
  343. task->finish_lock->unlock();
  344. }
  345. }
  346. // fprintf(stderr, "exit\n");
  347. return 0;
  348. }
  // Always reports global thread id 0, whatever location is passed in.
  int32_t __kmpc_global_thread_num(void* /*loc*/)
  {
      // NCNN_LOGE("__kmpc_global_thread_num");
      const int32_t fixed_gtid = 0;
      return fixed_gtid;
  }
  354. void __kmpc_push_num_threads(void* /*loc*/, int32_t /*gtid*/, int32_t num_threads)
  355. {
  356. // NCNN_LOGE("__kmpc_push_num_threads %d", num_threads);
  357. omp_set_num_threads(num_threads);
  358. }
  359. void __kmpc_fork_call(void* /*loc*/, int32_t argc, kmpc_micro fn, ...)
  360. {
  361. g_kmp_global.try_init();
  362. // NCNN_LOGE("__kmpc_fork_call %d", argc);
  363. int num_threads = omp_get_num_threads();
  364. // build argv
  365. void* argv[16];
  366. {
  367. va_list ap;
  368. va_start(ap, fn);
  369. for (int i = 0; i < argc; i++)
  370. argv[i] = va_arg(ap, void*);
  371. va_end(ap);
  372. }
  373. if (g_kmp_global.kmp_max_threads == 1 || num_threads == 1)
  374. {
  375. for (int i = 0; i < num_threads; i++)
  376. {
  377. tls_thread_num.set(reinterpret_cast<void*>((size_t)i));
  378. kmp_invoke_microtask(fn, 0, 0, argc, argv);
  379. }
  380. return;
  381. }
  382. int num_threads_to_wait = num_threads - 1;
  383. ncnn::Mutex finish_lock;
  384. ncnn::ConditionVariable finish_condition;
  385. // TODO portable stack allocation
  386. ncnn::KMPTask* tasks = (ncnn::KMPTask*)alloca((num_threads - 1) * sizeof(ncnn::KMPTask));
  387. for (int i = 0; i < num_threads - 1; i++)
  388. {
  389. tasks[i].fn = fn;
  390. tasks[i].argc = argc;
  391. tasks[i].argv = (void**)argv;
  392. tasks[i].num_threads = num_threads;
  393. tasks[i].thread_num = i + 1;
  394. tasks[i].num_threads_to_wait = &num_threads_to_wait;
  395. tasks[i].finish_lock = &finish_lock;
  396. tasks[i].finish_condition = &finish_condition;
  397. }
  398. // dispatch 1 ~ num_threads
  399. g_kmp_global.kmp_task_queue->dispatch(tasks, num_threads - 1);
  400. // dispatch 0
  401. {
  402. tls_thread_num.set(reinterpret_cast<void*>((size_t)0));
  403. kmp_invoke_microtask(fn, 0, 0, argc, argv);
  404. }
  405. // wait for finished
  406. {
  407. finish_lock.lock();
  408. if (num_threads_to_wait != 0)
  409. {
  410. finish_condition.wait(finish_lock);
  411. }
  412. finish_lock.unlock();
  413. }
  414. }
  415. void __kmpc_for_static_init_4(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, int32_t* lower, int32_t* upper, int32_t* /*stride*/, int32_t /*incr*/, int32_t /*chunk*/)
  416. {
  417. // NCNN_LOGE("__kmpc_for_static_init_4");
  418. int num_threads = omp_get_num_threads();
  419. // TODO only support i++
  420. int32_t count = *upper - *lower + 1;
  421. int32_t threads = std::min(count, (int32_t)num_threads);
  422. int32_t count_per_thread = count / threads;
  423. int32_t remain = count % threads;
  424. *last = gtid == (int32_t)(threads - 1);
  425. *lower = gtid * count_per_thread + std::min(remain, gtid);
  426. *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, gtid + 1) - 1, *upper);
  427. }
  428. void __kmpc_for_static_init_4u(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, uint32_t* lower, uint32_t* upper, int32_t* /*stride*/, int32_t /*incr*/, int32_t /*chunk*/)
  429. {
  430. // NCNN_LOGE("__kmpc_for_static_init_4u");
  431. int num_threads = omp_get_num_threads();
  432. // TODO only support i++
  433. uint32_t count = *upper - *lower + 1;
  434. uint32_t threads = std::min(count, (uint32_t)num_threads);
  435. uint32_t count_per_thread = count / threads;
  436. uint32_t remain = count % threads;
  437. *last = gtid == (int32_t)(threads - 1);
  438. *lower = gtid * count_per_thread + std::min(remain, (uint32_t)gtid);
  439. *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, (uint32_t)gtid + 1) - 1, *upper);
  440. }
  441. void __kmpc_for_static_init_8(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, int64_t* lower, int64_t* upper, int64_t* /*stride*/, int64_t /*incr*/, int64_t /*chunk*/)
  442. {
  443. // NCNN_LOGE("__kmpc_for_static_init_8");
  444. int num_threads = omp_get_num_threads();
  445. // TODO only support i++
  446. int64_t count = *upper - *lower + 1;
  447. int64_t threads = std::min(count, (int64_t)num_threads);
  448. int64_t count_per_thread = count / threads;
  449. int64_t remain = count % threads;
  450. *last = gtid == (int64_t)(threads - 1);
  451. *lower = gtid * count_per_thread + std::min(remain, (int64_t)gtid);
  452. *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, (int64_t)gtid + 1) - 1, *upper);
  453. }
  454. void __kmpc_for_static_init_8u(void* /*loc*/, int32_t gtid, int32_t /*sched*/, int32_t* last, uint64_t* lower, uint64_t* upper, int64_t* /*stride*/, int64_t /*incr*/, int64_t /*chunk*/)
  455. {
  456. // NCNN_LOGE("__kmpc_for_static_init_8u");
  457. int num_threads = omp_get_num_threads();
  458. // TODO only support i++
  459. uint64_t count = *upper - *lower + 1;
  460. uint64_t threads = std::min(count, (uint64_t)num_threads);
  461. uint64_t count_per_thread = count / threads;
  462. uint64_t remain = count % threads;
  463. *last = gtid == (int64_t)(threads - 1);
  464. *lower = gtid * count_per_thread + std::min(remain, (uint64_t)gtid);
  465. *upper = std::min((gtid + 1) * count_per_thread + std::min(remain, (uint64_t)gtid + 1) - 1, *upper);
  466. }
  // Counterpart of the __kmpc_for_static_init_* calls; there is nothing to
  // release in this runtime, so both arguments are ignored.
  void __kmpc_for_static_fini(void* /*loc*/, int32_t /*gtid*/)
  {
      // NCNN_LOGE("__kmpc_for_static_fini");
  }
  472. #ifdef __cplusplus
  473. } // extern "C"
  474. #endif
  475. #endif // NCNN_SIMPLEOMP