You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

reduction.cpp 25 kB

8 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "reduction.h"
  15. #include <float.h>
  16. #include <limits.h>
  17. #include <math.h>
  18. namespace ncnn {
  19. Reduction::Reduction()
  20. {
  21. one_blob_only = true;
  22. support_inplace = false;
  23. }
  24. int Reduction::load_param(const ParamDict& pd)
  25. {
  26. operation = pd.get(0, 0);
  27. reduce_all = pd.get(1, 1);
  28. coeff = pd.get(2, 1.f);
  29. axes = pd.get(3, Mat());
  30. keepdims = pd.get(4, 0);
  31. return 0;
  32. }
  33. template<typename Op, typename Op2>
  34. static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool reduce_h, bool reduce_c, const Option& opt)
  35. {
  36. Op op;
  37. Op2 op2;
  38. size_t elemsize = a.elemsize;
  39. int dims = a.dims;
  40. if (dims == 1)
  41. {
  42. int w = a.w;
  43. b.create(1, elemsize, opt.blob_allocator);
  44. const float* ptr = a;
  45. float sum = v0;
  46. for (int i = 0; i < w; i++)
  47. {
  48. sum = op(sum, ptr[i]);
  49. }
  50. b[0] = sum;
  51. return 0;
  52. }
  53. if (dims == 2)
  54. {
  55. int w = a.w;
  56. int h = a.h;
  57. if (reduce_w && reduce_h)
  58. {
  59. // w h -> X X
  60. b.create(1, elemsize, opt.blob_allocator);
  61. Mat sums(h, elemsize, opt.workspace_allocator);
  62. if (sums.empty())
  63. return -100;
  64. #pragma omp parallel for num_threads(opt.num_threads)
  65. for (int i = 0; i < h; i++)
  66. {
  67. const float* ptr = a.row(i);
  68. float sum = v0;
  69. for (int j = 0; j < w; j++)
  70. {
  71. sum = op(sum, ptr[j]);
  72. }
  73. sums[i] = sum;
  74. }
  75. float sum = v0;
  76. for (int i = 0; i < h; i++)
  77. {
  78. sum = op2(sum, sums[i]);
  79. }
  80. b[0] = sum;
  81. return 0;
  82. }
  83. if (reduce_w && !reduce_h)
  84. {
  85. // w h -> X h
  86. b.create(h, elemsize, opt.blob_allocator);
  87. #pragma omp parallel for num_threads(opt.num_threads)
  88. for (int i = 0; i < h; i++)
  89. {
  90. const float* ptr = a.row(i);
  91. float sum = v0;
  92. for (int j = 0; j < w; j++)
  93. {
  94. sum = op(sum, ptr[j]);
  95. }
  96. b[i] = sum;
  97. }
  98. return 0;
  99. }
  100. if (!reduce_w && reduce_h)
  101. {
  102. // w h -> w X
  103. b.create(w, elemsize, opt.blob_allocator);
  104. b.fill(v0);
  105. for (int i = 0; i < h; i++)
  106. {
  107. const float* ptr = a.row(i);
  108. for (int j = 0; j < w; j++)
  109. {
  110. b[j] = op(b[j], ptr[j]);
  111. }
  112. }
  113. return 0;
  114. }
  115. }
  116. if (dims == 3)
  117. {
  118. int w = a.w;
  119. int h = a.h;
  120. int channels = a.c;
  121. int size = w * h;
  122. if (reduce_w && reduce_h && reduce_c)
  123. {
  124. // w h c -> X X X
  125. b.create(1, elemsize, opt.blob_allocator);
  126. Mat sums(channels, elemsize, opt.workspace_allocator);
  127. if (sums.empty())
  128. return -100;
  129. #pragma omp parallel for num_threads(opt.num_threads)
  130. for (int q = 0; q < channels; q++)
  131. {
  132. const float* ptr = a.channel(q);
  133. float sum = v0;
  134. for (int i = 0; i < size; i++)
  135. {
  136. sum = op(sum, ptr[i]);
  137. }
  138. sums[q] = sum;
  139. }
  140. float sum = v0;
  141. for (int i = 0; i < channels; i++)
  142. {
  143. sum = op2(sum, sums[i]);
  144. }
  145. b[0] = sum;
  146. return 0;
  147. }
  148. if (reduce_w && reduce_h && !reduce_c)
  149. {
  150. // w h c -> X X c
  151. b.create(channels, elemsize, opt.blob_allocator);
  152. #pragma omp parallel for num_threads(opt.num_threads)
  153. for (int q = 0; q < channels; q++)
  154. {
  155. const float* ptr = a.channel(q);
  156. float sum = v0;
  157. for (int i = 0; i < size; i++)
  158. {
  159. sum = op(sum, ptr[i]);
  160. }
  161. b[q] = sum;
  162. }
  163. return 0;
  164. }
  165. if (reduce_w && !reduce_h && !reduce_c)
  166. {
  167. // w h c -> X h c
  168. b.create(h, channels, elemsize, opt.blob_allocator);
  169. #pragma omp parallel for num_threads(opt.num_threads)
  170. for (int q = 0; q < channels; q++)
  171. {
  172. const float* ptr = a.channel(q);
  173. float* outptr = b.row(q);
  174. for (int i = 0; i < h; i++)
  175. {
  176. float sum = v0;
  177. for (int j = 0; j < w; j++)
  178. {
  179. sum = op(sum, ptr[j]);
  180. }
  181. outptr[i] = sum;
  182. ptr += w;
  183. }
  184. }
  185. return 0;
  186. }
  187. if (reduce_w && !reduce_h && reduce_c)
  188. {
  189. // w h c -> X h X
  190. b.create(h, elemsize, opt.blob_allocator);
  191. Mat mins(1, h, channels, elemsize, opt.workspace_allocator);
  192. if (mins.empty())
  193. return -100;
  194. mins.fill(v0);
  195. #pragma omp parallel for num_threads(opt.num_threads)
  196. for (int q = 0; q < channels; q++)
  197. {
  198. const float* ptr = a.channel(q);
  199. float* mins_ptr = mins.channel(q);
  200. for (int i = 0; i < h; i++)
  201. {
  202. float sum = v0;
  203. for (int j = 0; j < w; j++)
  204. {
  205. sum = op(sum, ptr[j]);
  206. }
  207. mins_ptr[i] = sum;
  208. ptr += w;
  209. }
  210. }
  211. b.fill(v0);
  212. for (int q = 0; q < channels; q++)
  213. {
  214. const float* mins_ptr = mins.channel(q);
  215. for (int i = 0; i < h; i++)
  216. {
  217. b[i] = op2(b[i], mins_ptr[i]);
  218. }
  219. }
  220. return 0;
  221. }
  222. if (!reduce_w && reduce_h && reduce_c)
  223. {
  224. // w h c -> w X X
  225. b.create(w, elemsize, opt.blob_allocator);
  226. Mat mins(w, 1, channels, elemsize, opt.workspace_allocator);
  227. if (mins.empty())
  228. return -100;
  229. mins.fill(v0);
  230. #pragma omp parallel for num_threads(opt.num_threads)
  231. for (int q = 0; q < channels; q++)
  232. {
  233. const float* ptr = a.channel(q);
  234. float* mins_ptr = mins.channel(q);
  235. for (int i = 0; i < h; i++)
  236. {
  237. for (int j = 0; j < w; j++)
  238. {
  239. mins_ptr[j] = op(mins_ptr[j], ptr[j]);
  240. }
  241. ptr += w;
  242. }
  243. }
  244. b.fill(v0);
  245. for (int q = 0; q < channels; q++)
  246. {
  247. const float* mins_ptr = mins.channel(q);
  248. for (int j = 0; j < w; j++)
  249. {
  250. b[j] = op2(b[j], mins_ptr[j]);
  251. }
  252. }
  253. return 0;
  254. }
  255. if (!reduce_w && !reduce_h && reduce_c)
  256. {
  257. // w h c -> w h X
  258. b.create(w, h, elemsize, opt.blob_allocator);
  259. b.fill(v0);
  260. for (int q = 0; q < channels; q++)
  261. {
  262. const float* ptr = a.channel(q);
  263. for (int i = 0; i < size; i++)
  264. {
  265. b[i] = op(b[i], ptr[i]);
  266. }
  267. }
  268. return 0;
  269. }
  270. if (!reduce_w && reduce_h && !reduce_c)
  271. {
  272. // w h c -> w X c
  273. b.create(w, channels, elemsize, opt.blob_allocator);
  274. b.fill(v0);
  275. #pragma omp parallel for num_threads(opt.num_threads)
  276. for (int q = 0; q < channels; q++)
  277. {
  278. const float* ptr = a.channel(q);
  279. float* outptr = b.row(q);
  280. for (int i = 0; i < h; i++)
  281. {
  282. for (int j = 0; j < w; j++)
  283. {
  284. outptr[j] = op(outptr[j], ptr[j]);
  285. }
  286. ptr += w;
  287. }
  288. }
  289. return 0;
  290. }
  291. }
  292. return 0;
  293. }
  294. template<typename Op, typename Op2>
  295. static int reduction_op_keepdims(const Mat& a, Mat& b, float v0, bool reduce_w, bool reduce_h, bool reduce_c, const Option& opt)
  296. {
  297. Op op;
  298. Op2 op2;
  299. size_t elemsize = a.elemsize;
  300. int dims = a.dims;
  301. if (dims == 1)
  302. {
  303. int w = a.w;
  304. b.create(1, elemsize, opt.blob_allocator);
  305. const float* ptr = a;
  306. float sum = v0;
  307. for (int i = 0; i < w; i++)
  308. {
  309. sum = op(sum, ptr[i]);
  310. }
  311. b[0] = sum;
  312. return 0;
  313. }
  314. if (dims == 2)
  315. {
  316. int w = a.w;
  317. int h = a.h;
  318. if (reduce_w && reduce_h)
  319. {
  320. // w h -> 1 1
  321. b.create(1, 1, elemsize, opt.blob_allocator);
  322. Mat sums(h, elemsize, opt.workspace_allocator);
  323. if (sums.empty())
  324. return -100;
  325. #pragma omp parallel for num_threads(opt.num_threads)
  326. for (int i = 0; i < h; i++)
  327. {
  328. const float* ptr = a.row(i);
  329. float sum = v0;
  330. for (int j = 0; j < w; j++)
  331. {
  332. sum = op(sum, ptr[j]);
  333. }
  334. sums[i] = sum;
  335. }
  336. float sum = v0;
  337. for (int i = 0; i < h; i++)
  338. {
  339. sum = op2(sum, sums[i]);
  340. }
  341. b[0] = sum;
  342. return 0;
  343. }
  344. if (reduce_w && !reduce_h)
  345. {
  346. // w h -> 1 h
  347. b.create(1, h, elemsize, opt.blob_allocator);
  348. #pragma omp parallel for num_threads(opt.num_threads)
  349. for (int i = 0; i < h; i++)
  350. {
  351. const float* ptr = a.row(i);
  352. float sum = v0;
  353. for (int j = 0; j < w; j++)
  354. {
  355. sum = op(sum, ptr[j]);
  356. }
  357. b[i] = sum;
  358. }
  359. return 0;
  360. }
  361. if (!reduce_w && reduce_h)
  362. {
  363. // w h -> w 1
  364. b.create(w, 1, elemsize, opt.blob_allocator);
  365. b.fill(v0);
  366. for (int i = 0; i < h; i++)
  367. {
  368. const float* ptr = a.row(i);
  369. for (int j = 0; j < w; j++)
  370. {
  371. b[j] = op(b[j], ptr[j]);
  372. }
  373. }
  374. return 0;
  375. }
  376. }
  377. if (dims == 3)
  378. {
  379. int w = a.w;
  380. int h = a.h;
  381. int channels = a.c;
  382. int size = w * h;
  383. if (reduce_w && reduce_h && reduce_c)
  384. {
  385. // w h c -> 1 1 1
  386. b.create(1, 1, 1, elemsize, opt.blob_allocator);
  387. Mat sums(channels, elemsize, opt.workspace_allocator);
  388. if (sums.empty())
  389. return -100;
  390. #pragma omp parallel for num_threads(opt.num_threads)
  391. for (int q = 0; q < channels; q++)
  392. {
  393. const float* ptr = a.channel(q);
  394. float sum = v0;
  395. for (int i = 0; i < size; i++)
  396. {
  397. sum = op(sum, ptr[i]);
  398. }
  399. sums[q] = sum;
  400. }
  401. float sum = v0;
  402. for (int i = 0; i < channels; i++)
  403. {
  404. sum = op2(sum, sums[i]);
  405. }
  406. b[0] = sum;
  407. return 0;
  408. }
  409. if (reduce_w && reduce_h && !reduce_c)
  410. {
  411. // w h c -> 1 1 c
  412. b.create(1, 1, channels, elemsize, opt.blob_allocator);
  413. #pragma omp parallel for num_threads(opt.num_threads)
  414. for (int q = 0; q < channels; q++)
  415. {
  416. const float* ptr = a.channel(q);
  417. float* outptr = b.channel(q);
  418. float sum = v0;
  419. for (int i = 0; i < size; i++)
  420. {
  421. sum = op(sum, ptr[i]);
  422. }
  423. outptr[0] = sum;
  424. }
  425. return 0;
  426. }
  427. if (reduce_w && !reduce_h && !reduce_c)
  428. {
  429. // w h c -> 1 h c
  430. b.create(1, h, channels, elemsize, opt.blob_allocator);
  431. #pragma omp parallel for num_threads(opt.num_threads)
  432. for (int q = 0; q < channels; q++)
  433. {
  434. const float* ptr = a.channel(q);
  435. float* outptr = b.channel(q);
  436. for (int i = 0; i < h; i++)
  437. {
  438. float sum = v0;
  439. for (int j = 0; j < w; j++)
  440. {
  441. sum = op(sum, ptr[j]);
  442. }
  443. outptr[i] = sum;
  444. ptr += w;
  445. }
  446. }
  447. return 0;
  448. }
  449. if (reduce_w && !reduce_h && reduce_c)
  450. {
  451. // w h c -> 1 h 1
  452. b.create(1, h, 1, elemsize, opt.blob_allocator);
  453. Mat mins(1, h, channels, elemsize, opt.workspace_allocator);
  454. if (mins.empty())
  455. return -100;
  456. mins.fill(v0);
  457. #pragma omp parallel for num_threads(opt.num_threads)
  458. for (int q = 0; q < channels; q++)
  459. {
  460. const float* ptr = a.channel(q);
  461. float* mins_ptr = mins.channel(q);
  462. for (int i = 0; i < h; i++)
  463. {
  464. float sum = v0;
  465. for (int j = 0; j < w; j++)
  466. {
  467. sum = op(sum, ptr[j]);
  468. }
  469. mins_ptr[i] = sum;
  470. ptr += w;
  471. }
  472. }
  473. b.fill(v0);
  474. for (int q = 0; q < channels; q++)
  475. {
  476. const float* mins_ptr = mins.channel(q);
  477. for (int i = 0; i < h; i++)
  478. {
  479. b[i] = op2(b[i], mins_ptr[i]);
  480. }
  481. }
  482. return 0;
  483. }
  484. if (!reduce_w && reduce_h && reduce_c)
  485. {
  486. // w h c -> w 1 1
  487. b.create(w, 1, 1, elemsize, opt.blob_allocator);
  488. Mat mins(w, 1, channels, elemsize, opt.workspace_allocator);
  489. if (mins.empty())
  490. return -100;
  491. mins.fill(v0);
  492. #pragma omp parallel for num_threads(opt.num_threads)
  493. for (int q = 0; q < channels; q++)
  494. {
  495. const float* ptr = a.channel(q);
  496. float* mins_ptr = mins.channel(q);
  497. for (int i = 0; i < h; i++)
  498. {
  499. for (int j = 0; j < w; j++)
  500. {
  501. mins_ptr[j] = op(mins_ptr[j], ptr[j]);
  502. }
  503. ptr += w;
  504. }
  505. }
  506. b.fill(v0);
  507. for (int q = 0; q < channels; q++)
  508. {
  509. const float* mins_ptr = mins.channel(q);
  510. for (int j = 0; j < w; j++)
  511. {
  512. b[j] = op2(b[j], mins_ptr[j]);
  513. }
  514. }
  515. return 0;
  516. }
  517. if (!reduce_w && !reduce_h && reduce_c)
  518. {
  519. // w h c -> w h 1
  520. b.create(w, h, 1, elemsize, opt.blob_allocator);
  521. b.fill(v0);
  522. for (int q = 0; q < channels; q++)
  523. {
  524. const float* ptr = a.channel(q);
  525. for (int i = 0; i < size; i++)
  526. {
  527. b[i] = op(b[i], ptr[i]);
  528. }
  529. }
  530. return 0;
  531. }
  532. if (!reduce_w && reduce_h && !reduce_c)
  533. {
  534. // w h c -> w 1 c
  535. b.create(w, 1, channels, elemsize, opt.blob_allocator);
  536. b.fill(v0);
  537. #pragma omp parallel for num_threads(opt.num_threads)
  538. for (int q = 0; q < channels; q++)
  539. {
  540. const float* ptr = a.channel(q);
  541. float* outptr = b.channel(q);
  542. for (int i = 0; i < h; i++)
  543. {
  544. for (int j = 0; j < w; j++)
  545. {
  546. outptr[j] = op(outptr[j], ptr[j]);
  547. }
  548. ptr += w;
  549. }
  550. }
  551. return 0;
  552. }
  553. }
  554. return 0;
  555. }
  556. template<typename MathOp>
  557. static int reduction_post_process(Mat& a, float coeff, const Option& opt)
  558. {
  559. MathOp mathop;
  560. int dims = a.dims;
  561. if (dims == 1)
  562. {
  563. int w = a.w;
  564. #pragma omp parallel for num_threads(opt.num_threads)
  565. for (int i = 0; i < w; i++)
  566. a[i] = mathop(a[i]) * coeff;
  567. }
  568. else if (dims == 2)
  569. {
  570. int size = a.w * a.h;
  571. #pragma omp parallel for num_threads(opt.num_threads)
  572. for (int i = 0; i < size; i++)
  573. a[i] = mathop(a[i]) * coeff;
  574. }
  575. else if (dims == 3)
  576. {
  577. int c = a.c;
  578. int size = a.w * a.h;
  579. if (c == 1)
  580. {
  581. #pragma omp parallel for num_threads(opt.num_threads)
  582. for (int i = 0; i < size; i++)
  583. a[i] = mathop(a[i]) * coeff;
  584. }
  585. else
  586. {
  587. #pragma omp parallel for num_threads(opt.num_threads)
  588. for (int q = 0; q < c; q++)
  589. {
  590. float* outptr = a.channel(q);
  591. for (int i = 0; i < size; i++)
  592. outptr[i] = mathop(outptr[i]) * coeff;
  593. }
  594. }
  595. }
  596. return 0;
  597. }
  598. template<typename Op, typename Op2, typename Op3>
  599. static int reduction(const Mat& a, Mat& b, float v0, bool reduce_w, bool reduce_h, bool reduce_c, bool post_process, float coeff, int keepdims, const Option& opt)
  600. {
  601. int ret;
  602. if (keepdims)
  603. ret = reduction_op_keepdims<Op, Op2>(a, b, v0, reduce_w, reduce_h, reduce_c, opt);
  604. else
  605. ret = reduction_op<Op, Op2>(a, b, v0, reduce_w, reduce_h, reduce_c, opt);
  606. if (ret != 0)
  607. return -100;
  608. if (post_process || fabs(coeff - 1.f) > FLT_EPSILON)
  609. {
  610. ret = reduction_post_process<Op3>(b, coeff, opt);
  611. if (ret != 0)
  612. return -100;
  613. }
  614. return ret;
  615. }
  // Unary functor that returns a copy of its argument unchanged.
  template<typename T>
  struct post_process_identity
  {
      T operator()(const T& value) const
      {
          const T unchanged(value);
          return unchanged;
      }
  };
  // Unary functor that returns sqrt(value), converted back to T.
  template<typename T>
  struct post_process_sqrt
  {
      T operator()(const T& value) const
      {
          const T root = static_cast<T>(sqrt(value));
          return root;
      }
  };
  // Unary functor that returns log(value), converted back to T.
  template<typename T>
  struct post_process_log
  {
      T operator()(const T& value) const
      {
          const T natural_log = static_cast<T>(log(value));
          return natural_log;
      }
  };
  // Binary functor: accumulator + value.
  template<typename T>
  struct reduction_op_add
  {
      T operator()(const T& acc, const T& value) const
      {
          const T total = acc + value;
          return total;
      }
  };
  // Binary functor: accumulator * value.
  template<typename T>
  struct reduction_op_mul
  {
      T operator()(const T& acc, const T& value) const
      {
          const T product = acc * value;
          return product;
      }
  };
  // Binary functor: accumulator + fabs(value), converted back to T.
  template<typename T>
  struct reduction_op_asum
  {
      T operator()(const T& acc, const T& value) const
      {
          const T total = static_cast<T>(acc + fabs(value));
          return total;
      }
  };
  // Binary functor: accumulator + value * value.
  template<typename T>
  struct reduction_op_sumsq
  {
      T operator()(const T& acc, const T& value) const
      {
          const T total = acc + value * value;
          return total;
      }
  };
  // Binary functor: accumulator + exp(value), converted back to T.
  template<typename T>
  struct reduction_op_sumsexp
  {
      T operator()(const T& acc, const T& value) const
      {
          const T total = static_cast<T>(acc + exp(value));
          return total;
      }
  };
  // Binary functor: the larger of the two arguments. When neither compares
  // less than the other, the first argument is returned.
  template<typename T>
  struct reduction_op_max
  {
      T operator()(const T& acc, const T& value) const
      {
          return (acc < value) ? value : acc;
      }
  };
  // Binary functor: the smaller of the two arguments. When neither compares
  // less than the other, the first argument is returned.
  template<typename T>
  struct reduction_op_min
  {
      T operator()(const T& acc, const T& value) const
      {
          return (value < acc) ? value : acc;
      }
  };
  696. int Reduction::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
  697. {
  698. int dims = bottom_blob.dims;
  699. int axes_flag[3] = {0};
  700. bool reduce_w = false;
  701. bool reduce_h = false;
  702. bool reduce_c = false;
  703. if (reduce_all)
  704. {
  705. reduce_w = true;
  706. reduce_h = true;
  707. reduce_c = true;
  708. }
  709. else
  710. {
  711. const int* axes_ptr = axes;
  712. int reduced_axes_num = axes.w;
  713. for (int i = 0; i < reduced_axes_num; i++)
  714. {
  715. int axis = axes_ptr[i];
  716. // handle negative axis
  717. if (axis < 0)
  718. axis += dims;
  719. axes_flag[axis] = 1;
  720. }
  721. if (dims == 1)
  722. {
  723. reduce_w = true;
  724. }
  725. else if (dims == 2)
  726. {
  727. if (axes_flag[0] == 1) reduce_h = true;
  728. if (axes_flag[1] == 1) reduce_w = true;
  729. }
  730. else if (dims == 3)
  731. {
  732. if (axes_flag[0] == 1) reduce_c = true;
  733. if (axes_flag[1] == 1) reduce_h = true;
  734. if (axes_flag[2] == 1) reduce_w = true;
  735. }
  736. }
  737. if (operation == ReductionOp_SUM)
  738. return reduction<reduction_op_add<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);
  739. if (operation == ReductionOp_ASUM)
  740. return reduction<reduction_op_asum<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);
  741. if (operation == ReductionOp_SUMSQ)
  742. return reduction<reduction_op_sumsq<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);
  743. if (operation == ReductionOp_MEAN)
  744. {
  745. int scale = 1;
  746. int dims = bottom_blob.dims;
  747. if (dims == 1)
  748. {
  749. scale = bottom_blob.w;
  750. }
  751. else if (dims == 2)
  752. {
  753. if (reduce_w) scale *= bottom_blob.w;
  754. if (reduce_h) scale *= bottom_blob.h;
  755. }
  756. else if (dims == 3)
  757. {
  758. if (reduce_w) scale *= bottom_blob.w;
  759. if (reduce_h) scale *= bottom_blob.h;
  760. if (reduce_c) scale *= bottom_blob.c;
  761. }
  762. float coeff_mean = coeff / scale;
  763. return reduction<reduction_op_add<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, true, coeff_mean, keepdims, opt);
  764. }
  765. if (operation == ReductionOp_MAX)
  766. return reduction<reduction_op_max<float>, reduction_op_max<float>, post_process_identity<float> >(bottom_blob, top_blob, -FLT_MAX, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);
  767. if (operation == ReductionOp_MIN)
  768. return reduction<reduction_op_min<float>, reduction_op_min<float>, post_process_identity<float> >(bottom_blob, top_blob, FLT_MAX, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);
  769. if (operation == ReductionOp_PROD)
  770. return reduction<reduction_op_mul<float>, reduction_op_mul<float>, post_process_identity<float> >(bottom_blob, top_blob, 1.f, reduce_w, reduce_h, reduce_c, false, coeff, keepdims, opt);
  771. if (operation == ReductionOp_L1)
  772. return reduction<reduction_op_asum<float>, reduction_op_add<float>, post_process_identity<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, false, 1.f, keepdims, opt);
  773. if (operation == ReductionOp_L2)
  774. return reduction<reduction_op_sumsq<float>, reduction_op_add<float>, post_process_sqrt<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, true, 1.f, keepdims, opt);
  775. if (operation == ReductionOp_LogSum)
  776. return reduction<reduction_op_add<float>, reduction_op_add<float>, post_process_log<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, true, 1.f, keepdims, opt);
  777. if (operation == ReductionOp_LogSumExp)
  778. return reduction<reduction_op_sumsexp<float>, reduction_op_add<float>, post_process_log<float> >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_c, true, 1.f, keepdims, opt);
  779. return 0;
  780. }
  781. } // namespace ncnn