You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

unaryop_mips.cpp 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "unaryop_mips.h"
  15. #include <math.h>
  16. #if __mips_msa
  17. #include <msa.h>
  18. #include "msa_mathfun.h"
  19. #endif // __mips_msa
  20. namespace ncnn {
  21. UnaryOp_mips::UnaryOp_mips()
  22. {
  23. #if __mips_msa
  24. support_packing = true;
  25. #endif // __mips_msa
  26. }
  27. #if __mips_msa
  28. template<typename Op>
  29. static int unary_op_inplace_pack4(Mat& a, const Option& opt)
  30. {
  31. Op op;
  32. int w = a.w;
  33. int h = a.h;
  34. int d = a.d;
  35. int channels = a.c;
  36. int size = w * h * d;
  37. #pragma omp parallel for num_threads(opt.num_threads)
  38. for (int q = 0; q < channels; q++)
  39. {
  40. float* ptr = a.channel(q);
  41. for (int i = 0; i < size; i++)
  42. {
  43. __builtin_prefetch(ptr + 32);
  44. v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
  45. _p = op(_p);
  46. __msa_st_w((v4i32)_p, ptr, 0);
  47. ptr += 4;
  48. }
  49. }
  50. return 0;
  51. }
  52. struct unary_op_abs_pack4
  53. {
  54. v4f32 operator()(const v4f32& x) const
  55. {
  56. return (v4f32)__msa_bclri_w((v4u32)x, 31);
  57. }
  58. };
  59. struct unary_op_neg_pack4
  60. {
  61. v4f32 operator()(const v4f32& x) const
  62. {
  63. return (v4f32)__msa_bnegi_w((v4u32)x, 31);
  64. }
  65. };
  66. struct unary_op_floor_pack4
  67. {
  68. v4f32 operator()(const v4f32& x) const
  69. {
  70. // TODO msa optimize
  71. float tmp[4];
  72. __msa_st_w((v4i32)x, tmp, 0);
  73. tmp[0] = floor(tmp[0]);
  74. tmp[1] = floor(tmp[1]);
  75. tmp[2] = floor(tmp[2]);
  76. tmp[3] = floor(tmp[3]);
  77. return (v4f32)__msa_ld_w(tmp, 0);
  78. // int old_msacsr = __msa_cfcmsa_msacsr();
  79. // __msa_ctcmsa_msacsr(old_msacsr | 3); // round towards -inf
  80. // v4f32 y = __msa_frint_w(x);
  81. // __msa_ctcmsa_msacsr(old_msacsr);
  82. // return y;
  83. }
  84. };
  85. struct unary_op_ceil_pack4
  86. {
  87. v4f32 operator()(const v4f32& x) const
  88. {
  89. // TODO msa optimize
  90. float tmp[4];
  91. __msa_st_w((v4i32)x, tmp, 0);
  92. tmp[0] = ceil(tmp[0]);
  93. tmp[1] = ceil(tmp[1]);
  94. tmp[2] = ceil(tmp[2]);
  95. tmp[3] = ceil(tmp[3]);
  96. return (v4f32)__msa_ld_w(tmp, 0);
  97. // int old_msacsr = __msa_cfcmsa_msacsr();
  98. // __msa_ctcmsa_msacsr((old_msacsr | 3) ^ 1); // round towards +inf
  99. // v4f32 y = __msa_frint_w(x);
  100. // __msa_ctcmsa_msacsr(old_msacsr);
  101. // return y;
  102. }
  103. };
  104. struct unary_op_square_pack4
  105. {
  106. v4f32 operator()(const v4f32& x) const
  107. {
  108. return __msa_fmul_w(x, x);
  109. }
  110. };
  111. struct unary_op_sqrt_pack4
  112. {
  113. v4f32 operator()(const v4f32& x) const
  114. {
  115. return __msa_fsqrt_w(x);
  116. }
  117. };
  118. struct unary_op_rsqrt_pack4
  119. {
  120. v4f32 operator()(const v4f32& x) const
  121. {
  122. return __msa_frsqrt_w(x);
  123. }
  124. };
  125. struct unary_op_exp_pack4
  126. {
  127. v4f32 operator()(const v4f32& x) const
  128. {
  129. return exp_ps(x);
  130. }
  131. };
  132. struct unary_op_log_pack4
  133. {
  134. v4f32 operator()(const v4f32& x) const
  135. {
  136. return log_ps(x);
  137. }
  138. };
  139. struct unary_op_sin_pack4
  140. {
  141. v4f32 operator()(const v4f32& x) const
  142. {
  143. // TODO msa optimize
  144. float tmp[4];
  145. __msa_st_w((v4i32)x, tmp, 0);
  146. tmp[0] = sin(tmp[0]);
  147. tmp[1] = sin(tmp[1]);
  148. tmp[2] = sin(tmp[2]);
  149. tmp[3] = sin(tmp[3]);
  150. return (v4f32)__msa_ld_w(tmp, 0);
  151. }
  152. };
  153. struct unary_op_cos_pack4
  154. {
  155. v4f32 operator()(const v4f32& x) const
  156. {
  157. // TODO msa optimize
  158. float tmp[4];
  159. __msa_st_w((v4i32)x, tmp, 0);
  160. tmp[0] = cos(tmp[0]);
  161. tmp[1] = cos(tmp[1]);
  162. tmp[2] = cos(tmp[2]);
  163. tmp[3] = cos(tmp[3]);
  164. return (v4f32)__msa_ld_w(tmp, 0);
  165. }
  166. };
  167. struct unary_op_tan_pack4
  168. {
  169. v4f32 operator()(const v4f32& x) const
  170. {
  171. // TODO msa optimize
  172. float tmp[4];
  173. __msa_st_w((v4i32)x, tmp, 0);
  174. tmp[0] = tan(tmp[0]);
  175. tmp[1] = tan(tmp[1]);
  176. tmp[2] = tan(tmp[2]);
  177. tmp[3] = tan(tmp[3]);
  178. return (v4f32)__msa_ld_w(tmp, 0);
  179. }
  180. };
  181. struct unary_op_asin_pack4
  182. {
  183. v4f32 operator()(const v4f32& x) const
  184. {
  185. // TODO msa optimize
  186. float tmp[4];
  187. __msa_st_w((v4i32)x, tmp, 0);
  188. tmp[0] = asin(tmp[0]);
  189. tmp[1] = asin(tmp[1]);
  190. tmp[2] = asin(tmp[2]);
  191. tmp[3] = asin(tmp[3]);
  192. return (v4f32)__msa_ld_w(tmp, 0);
  193. }
  194. };
  195. struct unary_op_acos_pack4
  196. {
  197. v4f32 operator()(const v4f32& x) const
  198. {
  199. // TODO msa optimize
  200. float tmp[4];
  201. __msa_st_w((v4i32)x, tmp, 0);
  202. tmp[0] = acos(tmp[0]);
  203. tmp[1] = acos(tmp[1]);
  204. tmp[2] = acos(tmp[2]);
  205. tmp[3] = acos(tmp[3]);
  206. return (v4f32)__msa_ld_w(tmp, 0);
  207. }
  208. };
  209. struct unary_op_atan_pack4
  210. {
  211. v4f32 operator()(const v4f32& x) const
  212. {
  213. // TODO msa optimize
  214. float tmp[4];
  215. __msa_st_w((v4i32)x, tmp, 0);
  216. tmp[0] = atan(tmp[0]);
  217. tmp[1] = atan(tmp[1]);
  218. tmp[2] = atan(tmp[2]);
  219. tmp[3] = atan(tmp[3]);
  220. return (v4f32)__msa_ld_w(tmp, 0);
  221. }
  222. };
  223. struct unary_op_reciprocal_pack4
  224. {
  225. v4f32 operator()(const v4f32& x) const
  226. {
  227. return __msa_frcp_w(x);
  228. }
  229. };
  230. struct unary_op_tanh_pack4
  231. {
  232. v4f32 operator()(const v4f32& x) const
  233. {
  234. return tanh_ps(x);
  235. }
  236. };
  237. #endif // __mips_msa
  238. int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
  239. {
  240. int elempack = bottom_top_blob.elempack;
  241. #if __mips_msa
  242. if (elempack == 4)
  243. {
  244. if (op_type == Operation_ABS)
  245. return unary_op_inplace_pack4<unary_op_abs_pack4>(bottom_top_blob, opt);
  246. if (op_type == Operation_NEG)
  247. return unary_op_inplace_pack4<unary_op_neg_pack4>(bottom_top_blob, opt);
  248. if (op_type == Operation_FLOOR)
  249. return unary_op_inplace_pack4<unary_op_floor_pack4>(bottom_top_blob, opt);
  250. if (op_type == Operation_CEIL)
  251. return unary_op_inplace_pack4<unary_op_ceil_pack4>(bottom_top_blob, opt);
  252. if (op_type == Operation_SQUARE)
  253. return unary_op_inplace_pack4<unary_op_square_pack4>(bottom_top_blob, opt);
  254. if (op_type == Operation_SQRT)
  255. return unary_op_inplace_pack4<unary_op_sqrt_pack4>(bottom_top_blob, opt);
  256. if (op_type == Operation_RSQRT)
  257. return unary_op_inplace_pack4<unary_op_rsqrt_pack4>(bottom_top_blob, opt);
  258. if (op_type == Operation_EXP)
  259. return unary_op_inplace_pack4<unary_op_exp_pack4>(bottom_top_blob, opt);
  260. if (op_type == Operation_LOG)
  261. return unary_op_inplace_pack4<unary_op_log_pack4>(bottom_top_blob, opt);
  262. if (op_type == Operation_SIN)
  263. return unary_op_inplace_pack4<unary_op_sin_pack4>(bottom_top_blob, opt);
  264. if (op_type == Operation_COS)
  265. return unary_op_inplace_pack4<unary_op_cos_pack4>(bottom_top_blob, opt);
  266. if (op_type == Operation_TAN)
  267. return unary_op_inplace_pack4<unary_op_tan_pack4>(bottom_top_blob, opt);
  268. if (op_type == Operation_ASIN)
  269. return unary_op_inplace_pack4<unary_op_asin_pack4>(bottom_top_blob, opt);
  270. if (op_type == Operation_ACOS)
  271. return unary_op_inplace_pack4<unary_op_acos_pack4>(bottom_top_blob, opt);
  272. if (op_type == Operation_ATAN)
  273. return unary_op_inplace_pack4<unary_op_atan_pack4>(bottom_top_blob, opt);
  274. if (op_type == Operation_RECIPROCAL)
  275. return unary_op_inplace_pack4<unary_op_reciprocal_pack4>(bottom_top_blob, opt);
  276. if (op_type == Operation_TANH)
  277. return unary_op_inplace_pack4<unary_op_tanh_pack4>(bottom_top_blob, opt);
  278. }
  279. #endif // __mips_msa
  280. return UnaryOp::forward_inplace(bottom_top_blob, opt);
  281. }
  282. } // namespace ncnn