You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

softmax_riscv.cpp 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836
  1. // Xavier Hsinyuan is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2021 Xavier Hsinyuan <me@lstlx.com>. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "softmax_riscv.h"
  15. #include <float.h>
  16. #if __riscv_vector
  17. #include <riscv_vector.h>
  18. #include "riscv_usability.h"
  19. #include "rvv_mathfun.h"
  20. #endif // __riscv_vector
  21. #include "cpu.h"
  22. namespace ncnn {
  23. Softmax_riscv::Softmax_riscv()
  24. {
  25. #if __riscv_vector
  26. support_packing = true;
  27. #endif // __riscv_vector
  28. }
  #if __riscv_vector
  #if __riscv_xtheadvector
  // FIXME inline causes illegal instruction :(
  __attribute__((noinline))
  #endif // __riscv_xtheadvector
  // Returns x with every lane whose index is >= vl replaced by the scalar v.
  // Lanes below vl are passed through unchanged. The whole m8 register group
  // (vlmax lanes) is processed regardless of vl.
  static vfloat32m8_t reset_tails(vfloat32m8_t x, size_t vl, float v)
  {
      const size_t vlmax = __riscv_vsetvlmax_e32m8();

      // lane index vector 0, 1, 2, ... compared against vl gives the "tail" mask
      vuint32m8_t _lane_index = __riscv_vid_v_u32m8(vlmax);
      vbool4_t _is_tail = __riscv_vmsgeu_vx_u32m8_b4(_lane_index, vl, vlmax);

      // merge: masked lanes take v, the others keep x
      return __riscv_vfmerge_vfm_f32m8(x, v, _is_tail, vlmax);
  }
  #endif // __riscv_vector
  // In-place softmax over one contiguous run of elemcount * elempack floats.
  //
  //   _ptr       first float of the run; overwritten with the normalised values
  //   elemcount  number of elements in the run
  //   elempack   floats per element
  //
  // RVV build: three passes (max, exp(x - max) with running sum, multiply by
  // 1 / sum) using m8 register groups. After each reduction pass the partial
  // results are combined according to elempack:
  //   elempack == packn  the eight m1 parts are folded into one, so lane k of
  //                      every element shares one max / one sum
  //   elempack == 1      a full horizontal reduction, one max / one sum for the
  //                      whole run
  // Any other elempack gets no cross-lane combination.
  //
  // Scalar build: the run is treated as one flat vector of `size` floats.
  //
  // A non-positive size returns immediately without touching memory.
  static void softmax(float* _ptr, int elemcount, int elempack)
  {
      const int size = elemcount * elempack;

      // Nothing to normalise. This also keeps negative counts out of the
      // size_t arithmetic used by the vector path below.
      if (size <= 0)
          return;

  #if __riscv_vector
      const int packn = csrr_vlenb() / 4;

      // maximum vector lengths, queried once instead of at every use
      const size_t vlm8 = __riscv_vsetvlmax_e32m8();
      const size_t vlm4 = __riscv_vsetvlmax_e32m4();
      const size_t vlm2 = __riscv_vsetvlmax_e32m2();
      const size_t vlm1 = __riscv_vsetvlmax_e32m1();

      // floats covered by whole m8 iterations, and the leftover handled separately
      const int nfull = size / vlm8 * vlm8;
      const int remain = size % vlm8;

      // reduce max
      vfloat32m8_t _max = __riscv_vfmv_v_f_f32m8(-FLT_MAX, vlm8);
      {
          const float* ptr = _ptr;

          int n = nfull;
          while (n > 0)
          {
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vlm8);
              _max = __riscv_vfmax_vv_f32m8(_max, _p, vlm8);
              ptr += vlm8;
              n -= vlm8;
          }

          if (remain > 0)
          {
              size_t vlr = __riscv_vsetvl_e32m8(remain);
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vlr);
  #if __riscv_xtheadvector
              // xtheadvector does not support tail undisturbed policy:
              // neutralise the lanes past vlr, then combine at full length
              _p = reset_tails(_p, vlr, -FLT_MAX);
              _max = __riscv_vfmax_vv_f32m8(_max, _p, vlm8);
  #else  // __riscv_xtheadvector
              // tail-undisturbed: lanes past vlr keep their previous _max
              _max = __riscv_vfmax_vv_f32m8_tu(_max, _max, _p, vlr);
  #endif // __riscv_xtheadvector
          }
      }

      if (elempack == packn)
      {
          // reduce max n,n,n,n,n,n,n,n to n
          // broadcast n to n,n,n,n,n,n,n,n
          vfloat32m4_t _max0 = __riscv_vfmax_vv_f32m4(__riscv_vget_v_f32m8_f32m4(_max, 0), __riscv_vget_v_f32m8_f32m4(_max, 1), vlm4);
          vfloat32m2_t _max2 = __riscv_vfmax_vv_f32m2(__riscv_vget_v_f32m4_f32m2(_max0, 0), __riscv_vget_v_f32m4_f32m2(_max0, 1), vlm2);
          vfloat32m1_t _max4 = __riscv_vfmax_vv_f32m1(__riscv_vget_v_f32m2_f32m1(_max2, 0), __riscv_vget_v_f32m2_f32m1(_max2, 1), vlm1);
          _max = __riscv_vcreate_v_f32m1_f32m8(_max4, _max4, _max4, _max4, _max4, _max4, _max4, _max4);
      }
      if (elempack == 1)
      {
          // reduce max n,n,n,n,n,n,n,n to 1
          // broadcast 1 to n,n,n,n,n,n,n,n
          vfloat32m1_t _max0 = __riscv_vfmv_s_f_f32m1(-FLT_MAX, vlm1);
          _max0 = __riscv_vfredmax_vs_f32m8_f32m1(_max, _max0, vlm8);
          _max = __riscv_vset_v_f32m1_f32m8(_max, 0, _max0);
          _max = __riscv_vrgather_vx_f32m8(_max, 0, vlm8);
      }

      // reduce exp(x - max)
      vfloat32m8_t _sum = __riscv_vfmv_v_f_f32m8(0.f, vlm8);
      {
          float* ptr = _ptr;

          int n = nfull;
          while (n > 0)
          {
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vlm8);
              _p = __riscv_vfsub_vv_f32m8(_p, _max, vlm8);
              _p = exp_ps(_p, vlm8);
              __riscv_vse32_v_f32m8(ptr, _p, vlm8);
              _sum = __riscv_vfadd_vv_f32m8(_sum, _p, vlm8);
              ptr += vlm8;
              n -= vlm8;
          }

          if (remain > 0)
          {
              size_t vlr = __riscv_vsetvl_e32m8(remain);
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vlr);
              _p = __riscv_vfsub_vv_f32m8(_p, _max, vlr);
              _p = exp_ps(_p, vlr);
              __riscv_vse32_v_f32m8(ptr, _p, vlr);
  #if __riscv_xtheadvector
              // xtheadvector does not support tail undisturbed policy:
              // zero the lanes past vlr so they add nothing to the sum
              _p = reset_tails(_p, vlr, 0.f);
              _sum = __riscv_vfadd_vv_f32m8(_sum, _p, vlm8);
  #else  // __riscv_xtheadvector
              _sum = __riscv_vfadd_vv_f32m8_tu(_sum, _sum, _p, vlr);
  #endif // __riscv_xtheadvector
          }
      }

      if (elempack == packn)
      {
          // reduce sum n,n,n,n,n,n,n,n to n
          // broadcast n to n,n,n,n,n,n,n,n
          vfloat32m4_t _sum0 = __riscv_vfadd_vv_f32m4(__riscv_vget_v_f32m8_f32m4(_sum, 0), __riscv_vget_v_f32m8_f32m4(_sum, 1), vlm4);
          vfloat32m2_t _sum2 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(_sum0, 0), __riscv_vget_v_f32m4_f32m2(_sum0, 1), vlm2);
          vfloat32m1_t _sum4 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(_sum2, 0), __riscv_vget_v_f32m2_f32m1(_sum2, 1), vlm1);
          _sum = __riscv_vcreate_v_f32m1_f32m8(_sum4, _sum4, _sum4, _sum4, _sum4, _sum4, _sum4, _sum4);
      }
      if (elempack == 1)
      {
          // reduce sum n,n,n,n,n,n,n,n to 1
          // broadcast 1 to n,n,n,n,n,n,n,n
          vfloat32m1_t _sum0 = __riscv_vfmv_s_f_f32m1(0.f, vlm1);
          _sum0 = __riscv_vfredusum_vs_f32m8_f32m1(_sum, _sum0, vlm8);
          _sum = __riscv_vset_v_f32m1_f32m8(_sum, 0, _sum0);
          _sum = __riscv_vrgather_vx_f32m8(_sum, 0, vlm8);
      }

      // turn the sum into its reciprocal once, so the last pass is a multiply
      _sum = __riscv_vfrdiv_vf_f32m8(_sum, 1.f, vlm8);

      // div sum
      {
          float* ptr = _ptr;
          int n = size;
          while (n > 0)
          {
              size_t vl = __riscv_vsetvl_e32m8(n);
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vl);
              _p = __riscv_vfmul_vv_f32m8(_p, _sum, vl);
              __riscv_vse32_v_f32m8(ptr, _p, vl);
              n -= vl;
              ptr += vl;
          }
      }
  #else  // __riscv_vector
      // reduce max
      float max = -FLT_MAX;
      {
          const float* ptr = _ptr;
          for (int i = 0; i < size; i++)
          {
              max = std::max(max, *ptr++);
          }
      }

      // reduce exp(x - max)
      float sum = 0.f;
      {
          float* ptr = _ptr;
          for (int i = 0; i < size; i++)
          {
              float v = expf(*ptr - max);
              *ptr = v;
              sum += v;
              ptr++;
          }
      }

      sum = 1.f / sum;

      // div sum
      {
          float* ptr = _ptr;
          for (int i = 0; i < size; i++)
          {
              *ptr++ *= sum;
          }
      }
  #endif // __riscv_vector
  }
  #if __riscv_vector
  // Softmax across `elemcount` rows for `size1` packed columns, in place.
  //
  //   _ptr       first float of row 0
  //   elemcount  number of rows that are reduced together
  //   stride     distance between consecutive rows, in floats
  //   size1      number of packed columns to process; each column occupies one
  //              m1 register worth of floats (vlm1) per row
  //   _maxptr    scratch, size1 floats, running maximum per column; read as the
  //              starting value, so the caller must seed it
  //   _sumptr    scratch, size1 floats, running sum per column; read as the
  //              starting value, so the caller must seed it. On return it holds
  //              the reciprocal of each column's sum.
  //
  // For every column the lanes of each packed element are reduced horizontally
  // and accumulated over all rows, giving one max and one sum per column.
  // Columns are handled 8 / 4 / 2 / 1 at a time using m8 / m4 / m2 / m1 loads.
  static void softmax_packn(float* _ptr, int elemcount, int stride, int size1, float* _maxptr, float* _sumptr)
  {
      const size_t vlm8 = __riscv_vsetvlmax_e32m8();
      const size_t vlm4 = __riscv_vsetvlmax_e32m4();
      const size_t vlm2 = __riscv_vsetvlmax_e32m2();
      const size_t vlm1 = __riscv_vsetvlmax_e32m1();

      // reduce max
      for (int i = 0; i < elemcount; i++)
      {
          // widen before multiplying so the row offset cannot overflow int
          const float* ptr = _ptr + (size_t)i * stride;
          float* maxptr = _maxptr;

          int j = 0;
          for (; j + 7 < size1; j += 8)
          {
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vlm8);
              // the running max of each column seeds its own reduction
              vfloat32m1_t _m0 = __riscv_vfmv_v_f_f32m1(maxptr[0], vlm1);
              vfloat32m1_t _m1 = __riscv_vfmv_v_f_f32m1(maxptr[1], vlm1);
              vfloat32m1_t _m2 = __riscv_vfmv_v_f_f32m1(maxptr[2], vlm1);
              vfloat32m1_t _m3 = __riscv_vfmv_v_f_f32m1(maxptr[3], vlm1);
              vfloat32m1_t _m4 = __riscv_vfmv_v_f_f32m1(maxptr[4], vlm1);
              vfloat32m1_t _m5 = __riscv_vfmv_v_f_f32m1(maxptr[5], vlm1);
              vfloat32m1_t _m6 = __riscv_vfmv_v_f_f32m1(maxptr[6], vlm1);
              vfloat32m1_t _m7 = __riscv_vfmv_v_f_f32m1(maxptr[7], vlm1);
              vfloat32m1_t _max0 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 0), _m0, vlm1);
              vfloat32m1_t _max1 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 1), _m1, vlm1);
              vfloat32m1_t _max2 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 2), _m2, vlm1);
              vfloat32m1_t _max3 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 3), _m3, vlm1);
              vfloat32m1_t _max4 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 4), _m4, vlm1);
              vfloat32m1_t _max5 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 5), _m5, vlm1);
              vfloat32m1_t _max6 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 6), _m6, vlm1);
              vfloat32m1_t _max7 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 7), _m7, vlm1);
              maxptr[0] = __riscv_vfmv_f_s_f32m1_f32(_max0);
              maxptr[1] = __riscv_vfmv_f_s_f32m1_f32(_max1);
              maxptr[2] = __riscv_vfmv_f_s_f32m1_f32(_max2);
              maxptr[3] = __riscv_vfmv_f_s_f32m1_f32(_max3);
              maxptr[4] = __riscv_vfmv_f_s_f32m1_f32(_max4);
              maxptr[5] = __riscv_vfmv_f_s_f32m1_f32(_max5);
              maxptr[6] = __riscv_vfmv_f_s_f32m1_f32(_max6);
              maxptr[7] = __riscv_vfmv_f_s_f32m1_f32(_max7);
              ptr += vlm8;
              maxptr += 8;
          }
          for (; j + 3 < size1; j += 4)
          {
              vfloat32m4_t _p = __riscv_vle32_v_f32m4(ptr, vlm4);
              vfloat32m1_t _m0 = __riscv_vfmv_v_f_f32m1(maxptr[0], vlm1);
              vfloat32m1_t _m1 = __riscv_vfmv_v_f_f32m1(maxptr[1], vlm1);
              vfloat32m1_t _m2 = __riscv_vfmv_v_f_f32m1(maxptr[2], vlm1);
              vfloat32m1_t _m3 = __riscv_vfmv_v_f_f32m1(maxptr[3], vlm1);
              vfloat32m1_t _max0 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m4_f32m1(_p, 0), _m0, vlm1);
              vfloat32m1_t _max1 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m4_f32m1(_p, 1), _m1, vlm1);
              vfloat32m1_t _max2 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m4_f32m1(_p, 2), _m2, vlm1);
              vfloat32m1_t _max3 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m4_f32m1(_p, 3), _m3, vlm1);
              maxptr[0] = __riscv_vfmv_f_s_f32m1_f32(_max0);
              maxptr[1] = __riscv_vfmv_f_s_f32m1_f32(_max1);
              maxptr[2] = __riscv_vfmv_f_s_f32m1_f32(_max2);
              maxptr[3] = __riscv_vfmv_f_s_f32m1_f32(_max3);
              ptr += vlm4;
              maxptr += 4;
          }
          for (; j + 1 < size1; j += 2)
          {
              vfloat32m2_t _p = __riscv_vle32_v_f32m2(ptr, vlm2);
              vfloat32m1_t _m0 = __riscv_vfmv_v_f_f32m1(maxptr[0], vlm1);
              vfloat32m1_t _m1 = __riscv_vfmv_v_f_f32m1(maxptr[1], vlm1);
              vfloat32m1_t _max0 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m2_f32m1(_p, 0), _m0, vlm1);
              vfloat32m1_t _max1 = __riscv_vfredmax_vs_f32m1_f32m1(__riscv_vget_v_f32m2_f32m1(_p, 1), _m1, vlm1);
              maxptr[0] = __riscv_vfmv_f_s_f32m1_f32(_max0);
              maxptr[1] = __riscv_vfmv_f_s_f32m1_f32(_max1);
              ptr += vlm2;
              maxptr += 2;
          }
          for (; j < size1; j++)
          {
              vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr, vlm1);
              vfloat32m1_t _m0 = __riscv_vfmv_v_f_f32m1(*maxptr, vlm1);
              _p = __riscv_vfredmax_vs_f32m1_f32m1(_p, _m0, vlm1);
              *maxptr = __riscv_vfmv_f_s_f32m1_f32(_p);
              ptr += vlm1;
              maxptr++;
          }
      }

      // reduce exp(x - max)
      for (int i = 0; i < elemcount; i++)
      {
          float* ptr = _ptr + (size_t)i * stride;
          const float* maxptr = _maxptr;
          float* sumptr = _sumptr;

          int j = 0;
          for (; j + 7 < size1; j += 8)
          {
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vlm8);
              // broadcast each column's max over its own m1 part
              vfloat32m1_t _m0 = __riscv_vfmv_v_f_f32m1(maxptr[0], vlm1);
              vfloat32m1_t _m1 = __riscv_vfmv_v_f_f32m1(maxptr[1], vlm1);
              vfloat32m1_t _m2 = __riscv_vfmv_v_f_f32m1(maxptr[2], vlm1);
              vfloat32m1_t _m3 = __riscv_vfmv_v_f_f32m1(maxptr[3], vlm1);
              vfloat32m1_t _m4 = __riscv_vfmv_v_f_f32m1(maxptr[4], vlm1);
              vfloat32m1_t _m5 = __riscv_vfmv_v_f_f32m1(maxptr[5], vlm1);
              vfloat32m1_t _m6 = __riscv_vfmv_v_f_f32m1(maxptr[6], vlm1);
              vfloat32m1_t _m7 = __riscv_vfmv_v_f_f32m1(maxptr[7], vlm1);
              vfloat32m8_t _max = __riscv_vcreate_v_f32m1_f32m8(_m0, _m1, _m2, _m3, _m4, _m5, _m6, _m7);
              _p = exp_ps(__riscv_vfsub_vv_f32m8(_p, _max, vlm8), vlm8);
              __riscv_vse32_v_f32m8(ptr, _p, vlm8);
              // the running sum of each column seeds its own reduction
              vfloat32m1_t _s0 = __riscv_vfmv_v_f_f32m1(sumptr[0], vlm1);
              vfloat32m1_t _s1 = __riscv_vfmv_v_f_f32m1(sumptr[1], vlm1);
              vfloat32m1_t _s2 = __riscv_vfmv_v_f_f32m1(sumptr[2], vlm1);
              vfloat32m1_t _s3 = __riscv_vfmv_v_f_f32m1(sumptr[3], vlm1);
              vfloat32m1_t _s4 = __riscv_vfmv_v_f_f32m1(sumptr[4], vlm1);
              vfloat32m1_t _s5 = __riscv_vfmv_v_f_f32m1(sumptr[5], vlm1);
              vfloat32m1_t _s6 = __riscv_vfmv_v_f_f32m1(sumptr[6], vlm1);
              vfloat32m1_t _s7 = __riscv_vfmv_v_f_f32m1(sumptr[7], vlm1);
              vfloat32m1_t _sum0 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 0), _s0, vlm1);
              vfloat32m1_t _sum1 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 1), _s1, vlm1);
              vfloat32m1_t _sum2 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 2), _s2, vlm1);
              vfloat32m1_t _sum3 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 3), _s3, vlm1);
              vfloat32m1_t _sum4 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 4), _s4, vlm1);
              vfloat32m1_t _sum5 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 5), _s5, vlm1);
              vfloat32m1_t _sum6 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 6), _s6, vlm1);
              vfloat32m1_t _sum7 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m8_f32m1(_p, 7), _s7, vlm1);
              sumptr[0] = __riscv_vfmv_f_s_f32m1_f32(_sum0);
              sumptr[1] = __riscv_vfmv_f_s_f32m1_f32(_sum1);
              sumptr[2] = __riscv_vfmv_f_s_f32m1_f32(_sum2);
              sumptr[3] = __riscv_vfmv_f_s_f32m1_f32(_sum3);
              sumptr[4] = __riscv_vfmv_f_s_f32m1_f32(_sum4);
              sumptr[5] = __riscv_vfmv_f_s_f32m1_f32(_sum5);
              sumptr[6] = __riscv_vfmv_f_s_f32m1_f32(_sum6);
              sumptr[7] = __riscv_vfmv_f_s_f32m1_f32(_sum7);
              ptr += vlm8;
              maxptr += 8;
              sumptr += 8;
          }
          for (; j + 3 < size1; j += 4)
          {
              vfloat32m4_t _p = __riscv_vle32_v_f32m4(ptr, vlm4);
              vfloat32m1_t _m0 = __riscv_vfmv_v_f_f32m1(maxptr[0], vlm1);
              vfloat32m1_t _m1 = __riscv_vfmv_v_f_f32m1(maxptr[1], vlm1);
              vfloat32m1_t _m2 = __riscv_vfmv_v_f_f32m1(maxptr[2], vlm1);
              vfloat32m1_t _m3 = __riscv_vfmv_v_f_f32m1(maxptr[3], vlm1);
              vfloat32m4_t _max = __riscv_vcreate_v_f32m1_f32m4(_m0, _m1, _m2, _m3);
              _p = exp_ps(__riscv_vfsub_vv_f32m4(_p, _max, vlm4), vlm4);
              __riscv_vse32_v_f32m4(ptr, _p, vlm4);
              vfloat32m1_t _s0 = __riscv_vfmv_v_f_f32m1(sumptr[0], vlm1);
              vfloat32m1_t _s1 = __riscv_vfmv_v_f_f32m1(sumptr[1], vlm1);
              vfloat32m1_t _s2 = __riscv_vfmv_v_f_f32m1(sumptr[2], vlm1);
              vfloat32m1_t _s3 = __riscv_vfmv_v_f_f32m1(sumptr[3], vlm1);
              vfloat32m1_t _sum0 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m4_f32m1(_p, 0), _s0, vlm1);
              vfloat32m1_t _sum1 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m4_f32m1(_p, 1), _s1, vlm1);
              vfloat32m1_t _sum2 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m4_f32m1(_p, 2), _s2, vlm1);
              vfloat32m1_t _sum3 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m4_f32m1(_p, 3), _s3, vlm1);
              sumptr[0] = __riscv_vfmv_f_s_f32m1_f32(_sum0);
              sumptr[1] = __riscv_vfmv_f_s_f32m1_f32(_sum1);
              sumptr[2] = __riscv_vfmv_f_s_f32m1_f32(_sum2);
              sumptr[3] = __riscv_vfmv_f_s_f32m1_f32(_sum3);
              ptr += vlm4;
              maxptr += 4;
              sumptr += 4;
          }
          for (; j + 1 < size1; j += 2)
          {
              vfloat32m2_t _p = __riscv_vle32_v_f32m2(ptr, vlm2);
              vfloat32m1_t _m0 = __riscv_vfmv_v_f_f32m1(maxptr[0], vlm1);
              vfloat32m1_t _m1 = __riscv_vfmv_v_f_f32m1(maxptr[1], vlm1);
              vfloat32m2_t _max = __riscv_vcreate_v_f32m1_f32m2(_m0, _m1);
              _p = exp_ps(__riscv_vfsub_vv_f32m2(_p, _max, vlm2), vlm2);
              __riscv_vse32_v_f32m2(ptr, _p, vlm2);
              vfloat32m1_t _s0 = __riscv_vfmv_v_f_f32m1(sumptr[0], vlm1);
              vfloat32m1_t _s1 = __riscv_vfmv_v_f_f32m1(sumptr[1], vlm1);
              vfloat32m1_t _sum0 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m2_f32m1(_p, 0), _s0, vlm1);
              vfloat32m1_t _sum1 = __riscv_vfredusum_vs_f32m1_f32m1(__riscv_vget_v_f32m2_f32m1(_p, 1), _s1, vlm1);
              sumptr[0] = __riscv_vfmv_f_s_f32m1_f32(_sum0);
              sumptr[1] = __riscv_vfmv_f_s_f32m1_f32(_sum1);
              ptr += vlm2;
              maxptr += 2;
              sumptr += 2;
          }
          for (; j < size1; j++)
          {
              vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr, vlm1);
              _p = exp_ps(__riscv_vfsub_vf_f32m1(_p, *maxptr, vlm1), vlm1);
              __riscv_vse32_v_f32m1(ptr, _p, vlm1);
              vfloat32m1_t _s0 = __riscv_vfmv_v_f_f32m1(*sumptr, vlm1);
              vfloat32m1_t _sum = __riscv_vfredusum_vs_f32m1_f32m1(_p, _s0, vlm1);
              *sumptr = __riscv_vfmv_f_s_f32m1_f32(_sum);
              ptr += vlm1;
              maxptr++;
              sumptr++;
          }
      }

      // replace every column sum by its reciprocal so the last pass multiplies
      {
          float* sumptr = _sumptr;
          int n = size1;
          while (n > 0)
          {
              size_t vl = __riscv_vsetvl_e32m8(n);
              vfloat32m8_t _sum = __riscv_vle32_v_f32m8(sumptr, vl);
              _sum = __riscv_vfrdiv_vf_f32m8(_sum, 1.f, vl);
              __riscv_vse32_v_f32m8(sumptr, _sum, vl);
              n -= vl;
              sumptr += vl;
          }
      }

      // div sum
      for (int i = 0; i < elemcount; i++)
      {
          float* ptr = _ptr + (size_t)i * stride;
          const float* sumptr = _sumptr;

          int j = 0;
          for (; j + 7 < size1; j += 8)
          {
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vlm8);
              vfloat32m1_t _s0 = __riscv_vfmv_v_f_f32m1(sumptr[0], vlm1);
              vfloat32m1_t _s1 = __riscv_vfmv_v_f_f32m1(sumptr[1], vlm1);
              vfloat32m1_t _s2 = __riscv_vfmv_v_f_f32m1(sumptr[2], vlm1);
              vfloat32m1_t _s3 = __riscv_vfmv_v_f_f32m1(sumptr[3], vlm1);
              vfloat32m1_t _s4 = __riscv_vfmv_v_f_f32m1(sumptr[4], vlm1);
              vfloat32m1_t _s5 = __riscv_vfmv_v_f_f32m1(sumptr[5], vlm1);
              vfloat32m1_t _s6 = __riscv_vfmv_v_f_f32m1(sumptr[6], vlm1);
              vfloat32m1_t _s7 = __riscv_vfmv_v_f_f32m1(sumptr[7], vlm1);
              vfloat32m8_t _sum = __riscv_vcreate_v_f32m1_f32m8(_s0, _s1, _s2, _s3, _s4, _s5, _s6, _s7);
              _p = __riscv_vfmul_vv_f32m8(_p, _sum, vlm8);
              __riscv_vse32_v_f32m8(ptr, _p, vlm8);
              ptr += vlm8;
              sumptr += 8;
          }
          for (; j + 3 < size1; j += 4)
          {
              vfloat32m4_t _p = __riscv_vle32_v_f32m4(ptr, vlm4);
              vfloat32m1_t _s0 = __riscv_vfmv_v_f_f32m1(sumptr[0], vlm1);
              vfloat32m1_t _s1 = __riscv_vfmv_v_f_f32m1(sumptr[1], vlm1);
              vfloat32m1_t _s2 = __riscv_vfmv_v_f_f32m1(sumptr[2], vlm1);
              vfloat32m1_t _s3 = __riscv_vfmv_v_f_f32m1(sumptr[3], vlm1);
              vfloat32m4_t _sum = __riscv_vcreate_v_f32m1_f32m4(_s0, _s1, _s2, _s3);
              _p = __riscv_vfmul_vv_f32m4(_p, _sum, vlm4);
              __riscv_vse32_v_f32m4(ptr, _p, vlm4);
              ptr += vlm4;
              sumptr += 4;
          }
          for (; j + 1 < size1; j += 2)
          {
              vfloat32m2_t _p = __riscv_vle32_v_f32m2(ptr, vlm2);
              vfloat32m1_t _s0 = __riscv_vfmv_v_f_f32m1(sumptr[0], vlm1);
              vfloat32m1_t _s1 = __riscv_vfmv_v_f_f32m1(sumptr[1], vlm1);
              vfloat32m2_t _sum = __riscv_vcreate_v_f32m1_f32m2(_s0, _s1);
              _p = __riscv_vfmul_vv_f32m2(_p, _sum, vlm2);
              __riscv_vse32_v_f32m2(ptr, _p, vlm2);
              ptr += vlm2;
              sumptr += 2;
          }
          for (; j < size1; j++)
          {
              vfloat32m1_t _p = __riscv_vle32_v_f32m1(ptr, vlm1);
              _p = __riscv_vfmul_vf_f32m1(_p, *sumptr, vlm1);
              __riscv_vse32_v_f32m1(ptr, _p, vlm1);
              ptr += vlm1;
              sumptr++;
          }
      }
  }
  #endif // __riscv_vector
  // Softmax across `elemcount` rows for `size1` independent float columns, in place.
  //
  //   _ptr       first float of row 0
  //   elemcount  number of rows that are reduced together
  //   stride     distance between consecutive rows, in floats
  //   size1      number of columns (floats per row) to process
  //   _maxptr    scratch, size1 floats, running maximum per column; read as the
  //              starting value, so the caller must seed it
  //   _sumptr    scratch, size1 floats, running sum per column; read as the
  //              starting value, so the caller must seed it. On return it holds
  //              the reciprocal of each column's sum.
  //
  // Column j of the result is the softmax of _ptr[j], _ptr[j + stride], ...
  // taken over all rows.
  static void softmax_pack1(float* _ptr, int elemcount, int stride, int size1, float* _maxptr, float* _sumptr)
  {
      // reduce max
      for (int i = 0; i < elemcount; i++)
      {
          // widen before multiplying so the row offset cannot overflow int
          const float* ptr = _ptr + (size_t)i * stride;
          float* maxptr = _maxptr;
  #if __riscv_vector
          int n = size1;
          while (n > 0)
          {
              size_t vl = __riscv_vsetvl_e32m8(n);
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vl);
              vfloat32m8_t _max = __riscv_vle32_v_f32m8(maxptr, vl);
              _max = __riscv_vfmax_vv_f32m8(_max, _p, vl);
              __riscv_vse32_v_f32m8(maxptr, _max, vl);
              n -= vl;
              ptr += vl;
              maxptr += vl;
          }
  #else  // __riscv_vector
          for (int j = 0; j < size1; j++)
          {
              *maxptr = std::max(*maxptr, *ptr);
              ptr++;
              maxptr++;
          }
  #endif // __riscv_vector
      }

      // reduce exp(x - max)
      for (int i = 0; i < elemcount; i++)
      {
          float* ptr = _ptr + (size_t)i * stride;
          const float* maxptr = _maxptr;
          float* sumptr = _sumptr;
  #if __riscv_vector
          int n = size1;
          while (n > 0)
          {
              size_t vl = __riscv_vsetvl_e32m8(n);
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vl);
              vfloat32m8_t _max = __riscv_vle32_v_f32m8(maxptr, vl);
              vfloat32m8_t _sum = __riscv_vle32_v_f32m8(sumptr, vl);
              _p = __riscv_vfsub_vv_f32m8(_p, _max, vl);
              _p = exp_ps(_p, vl);
              __riscv_vse32_v_f32m8(ptr, _p, vl);
              _sum = __riscv_vfadd_vv_f32m8(_sum, _p, vl);
              __riscv_vse32_v_f32m8(sumptr, _sum, vl);
              n -= vl;
              ptr += vl;
              maxptr += vl;
              sumptr += vl;
          }
  #else  // __riscv_vector
          for (int j = 0; j < size1; j++)
          {
              float v = expf(*ptr - *maxptr);
              *ptr = v;
              *sumptr += v;
              ptr++;
              maxptr++;
              sumptr++;
          }
  #endif // __riscv_vector
      }

      // replace every column sum by its reciprocal so the last pass multiplies
      {
          float* sumptr = _sumptr;
  #if __riscv_vector
          int n = size1;
          while (n > 0)
          {
              size_t vl = __riscv_vsetvl_e32m8(n);
              vfloat32m8_t _sum = __riscv_vle32_v_f32m8(sumptr, vl);
              _sum = __riscv_vfrdiv_vf_f32m8(_sum, 1.f, vl);
              __riscv_vse32_v_f32m8(sumptr, _sum, vl);
              n -= vl;
              sumptr += vl;
          }
  #else  // __riscv_vector
          for (int j = 0; j < size1; j++)
          {
              *sumptr = 1.f / *sumptr;
              sumptr++;
          }
  #endif // __riscv_vector
      }

      // div sum
      for (int i = 0; i < elemcount; i++)
      {
          float* ptr = _ptr + (size_t)i * stride;
          const float* sumptr = _sumptr;
  #if __riscv_vector
          int n = size1;
          while (n > 0)
          {
              size_t vl = __riscv_vsetvl_e32m8(n);
              vfloat32m8_t _p = __riscv_vle32_v_f32m8(ptr, vl);
              vfloat32m8_t _sum = __riscv_vle32_v_f32m8(sumptr, vl);
              _p = __riscv_vfmul_vv_f32m8(_p, _sum, vl);
              __riscv_vse32_v_f32m8(ptr, _p, vl);
              n -= vl;
              ptr += vl;
              sumptr += vl;
          }
  #else  // __riscv_vector
          for (int j = 0; j < size1; j++)
          {
              *ptr *= *sumptr;
              ptr++;
              sumptr++;
          }
  #endif // __riscv_vector
      }
  }
  566. static void softmax(float* _ptr, int elemcount, int elempack, int stride, int size1, float* _maxptr, float* _sumptr)
  567. {
  568. // reduce max
  569. {
  570. float* maxptr = _maxptr;
  571. #if __riscv_vector
  572. vfloat32m8_t _negmax = __riscv_vfmv_v_f_f32m8(-FLT_MAX, __riscv_vsetvlmax_e32m8());
  573. int n = size1;
  574. while (n > 0)
  575. {
  576. size_t vl = __riscv_vsetvl_e32m8(n);
  577. __riscv_vse32_v_f32m8(maxptr, _negmax, vl);
  578. n -= vl;
  579. maxptr += vl;
  580. }
  581. #else // __riscv_vector
  582. for (int j = 0; j < size1; j++)
  583. {
  584. *maxptr++ = -FLT_MAX;
  585. }
  586. #endif // __riscv_vector
  587. }
  588. // reduce exp(x - max)
  589. {
  590. float* sumptr = _sumptr;
  591. #if __riscv_vector
  592. vfloat32m8_t _zero = __riscv_vfmv_v_f_f32m8(0.f, __riscv_vsetvlmax_e32m8());
  593. int n = size1;
  594. while (n > 0)
  595. {
  596. size_t vl = __riscv_vsetvl_e32m8(n);
  597. __riscv_vse32_v_f32m8(sumptr, _zero, vl);
  598. n -= vl;
  599. sumptr += vl;
  600. }
  601. #else // __riscv_vector
  602. for (int j = 0; j < size1; j++)
  603. {
  604. *sumptr++ = 0.f;
  605. }
  606. #endif // __riscv_vector
  607. }
  608. #if __riscv_vector
  609. const int packn = csrr_vlenb() / 4;
  610. if (elempack == packn)
  611. {
  612. softmax_packn(_ptr, elemcount, stride, size1, _maxptr, _sumptr);
  613. }
  614. #endif // __riscv_vector
  615. if (elempack == 1)
  616. {
  617. softmax_pack1(_ptr, elemcount, stride, size1, _maxptr, _sumptr);
  618. }
  619. }
  620. int Softmax_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
  621. {
  622. const int dims = bottom_top_blob.dims;
  623. const int w = bottom_top_blob.w;
  624. const int h = bottom_top_blob.h;
  625. const int d = bottom_top_blob.d;
  626. const int channels = bottom_top_blob.c;
  627. const int elempack = bottom_top_blob.elempack;
  628. const int positive_axis = axis < 0 ? dims + axis : axis;
  629. if (dims == 1) // positive_axis == 0
  630. {
  631. float* ptr = bottom_top_blob;
  632. const int size = w * elempack;
  633. softmax(ptr, size, 1);
  634. }
  635. if (dims == 2 && positive_axis == 0)
  636. {
  637. const int size = w;
  638. const int sizen = (size + (opt.num_threads - 1)) / opt.num_threads;
  639. const int stride = w * elempack;
  640. Mat maxsum(sizen, 2, opt.num_threads, 4u, opt.workspace_allocator);
  641. if (maxsum.empty())
  642. return -100;
  643. const int nn_size = size / sizen;
  644. #pragma omp parallel for num_threads(opt.num_threads)
  645. for (int ii = 0; ii < nn_size; ii++)
  646. {
  647. const int i = ii * sizen;
  648. const int size1 = std::min(sizen, size - i);
  649. float* maxsumptr = maxsum.channel(get_omp_thread_num());
  650. float* maxptr = maxsumptr;
  651. float* sumptr = maxptr + sizen;
  652. float* ptr = (float*)bottom_top_blob + i * elempack;
  653. softmax(ptr, h, elempack, stride, size1, maxptr, sumptr);
  654. }
  655. }
  656. if (dims == 2 && positive_axis == 1)
  657. {
  658. #pragma omp parallel for num_threads(opt.num_threads)
  659. for (int i = 0; i < h; i++)
  660. {
  661. float* ptr = bottom_top_blob.row(i);
  662. softmax(ptr, w, elempack);
  663. }
  664. }
  665. if ((dims == 3 || dims == 4) && positive_axis == 0)
  666. {
  667. const int size = w * h * d;
  668. const int sizen = (size + (opt.num_threads - 1)) / opt.num_threads;
  669. const int stride = bottom_top_blob.cstep * elempack;
  670. Mat maxsum(sizen, 2, opt.num_threads, 4u, opt.workspace_allocator);
  671. if (maxsum.empty())
  672. return -100;
  673. const int nn_size = size / sizen;
  674. #pragma omp parallel for num_threads(opt.num_threads)
  675. for (int ii = 0; ii < nn_size; ii++)
  676. {
  677. const int i = ii * sizen;
  678. const int size1 = std::min(sizen, size - i);
  679. float* maxsumptr = maxsum.channel(get_omp_thread_num());
  680. float* maxptr = maxsumptr;
  681. float* sumptr = maxptr + sizen;
  682. float* ptr = (float*)bottom_top_blob + i * elempack;
  683. softmax(ptr, channels, elempack, stride, size1, maxptr, sumptr);
  684. }
  685. }
  686. if ((dims == 3 && positive_axis == 1) || (dims == 4 && positive_axis == 2))
  687. {
  688. const int size = w * elempack;
  689. Mat maxsum(size, 2, opt.num_threads, 4u, opt.workspace_allocator);
  690. if (maxsum.empty())
  691. return -100;
  692. #pragma omp parallel for num_threads(opt.num_threads)
  693. for (int q = 0; q < channels; q++)
  694. {
  695. for (int i = 0; i < d; i++)
  696. {
  697. float* ptr = bottom_top_blob.channel(q).depth(i);
  698. float* maxsumptr = maxsum.channel(get_omp_thread_num());
  699. float* maxptr = maxsumptr;
  700. float* sumptr = maxptr + size;
  701. softmax(ptr, h, 1, size, size, maxptr, sumptr);
  702. }
  703. }
  704. }
  705. if (dims == 3 && positive_axis == 2)
  706. {
  707. #pragma omp parallel for num_threads(opt.num_threads)
  708. for (int q = 0; q < channels; q++)
  709. {
  710. float* ptr = bottom_top_blob.channel(q);
  711. for (int i = 0; i < h; i++)
  712. {
  713. softmax(ptr, w, elempack);
  714. ptr += w * elempack;
  715. }
  716. }
  717. }
  718. if (dims == 4 && positive_axis == 1)
  719. {
  720. const int size = w * h * elempack;
  721. Mat maxsum(size, 2, opt.num_threads, 4u, opt.workspace_allocator);
  722. if (maxsum.empty())
  723. return -100;
  724. #pragma omp parallel for num_threads(opt.num_threads)
  725. for (int q = 0; q < channels; q++)
  726. {
  727. float* ptr = bottom_top_blob.channel(q);
  728. float* maxsumptr = maxsum.channel(get_omp_thread_num());
  729. float* maxptr = maxsumptr;
  730. float* sumptr = maxptr + size;
  731. softmax(ptr, d, 1, size, size, maxptr, sumptr);
  732. }
  733. }
  734. if (dims == 4 && positive_axis == 3)
  735. {
  736. #pragma omp parallel for num_threads(opt.num_threads)
  737. for (int q = 0; q < channels; q++)
  738. {
  739. float* ptr = bottom_top_blob.channel(q);
  740. for (int i = 0; i < d; i++)
  741. {
  742. for (int j = 0; j < h; j++)
  743. {
  744. softmax(ptr, w, elempack);
  745. ptr += w * elempack;
  746. }
  747. }
  748. }
  749. }
  750. return 0;
  751. }
  752. } // namespace ncnn