You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

expression.cpp 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. // Copyright 2025 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "expression.h"
  4. namespace ncnn {
  // Returns how many input blobs an expression refers to: one past the highest
  // blob index found in any two-character `<digit><w|h|d|c>` token, or 0 when
  // there is no such token. Tokens are the runs of characters between the
  // delimiters '(' ')' and ','.
  int count_expression_blobs(const std::string& expr)
  {
      int count = 0;
      const size_t len = expr.size();
      size_t token_begin = 0;

      // walk one position past the end so the trailing token is handled by the
      // same code path as tokens that end at a delimiter
      for (size_t pos = 0; pos <= len; pos++)
      {
          if (pos < len)
          {
              const char ch = expr[pos];
              if (ch != '(' && ch != ')' && ch != ',')
                  continue; // still inside the current token
          }

          // [token_begin, pos) is one complete token; blob references are exactly two chars
          if (pos - token_begin == 2)
          {
              const char digit = expr[token_begin];
              const char dim = expr[token_begin + 1];
              const bool is_blob_ref = digit >= '0' && digit <= '9'
                                       && (dim == 'w' || dim == 'h' || dim == 'd' || dim == 'c');
              if (is_blob_ref && digit - '0' + 1 > count)
                  count = digit - '0' + 1;
          }

          token_begin = pos + 1;
      }

      return count;
  }
  // Tagged scalar used as an expression-stack entry: holds either an int or a float.
  // `type` selects the active union member (0 = i, 1 = f).
  // The int/float constructors are deliberately non-explicit so plain int and
  // float results can be pushed onto a std::stack<typed_value> directly.
  struct typed_value
  {
      int type; // 0=i 1=f
      union
      {
          int i;
          float f;
      };

      typed_value()
          : type(0), i(0)
      {
      }
      typed_value(int _i)
          : type(0), i(_i)
      {
      }
      typed_value(float _f)
          : type(1), f(_f)
      {
      }

      // Returns the value as int; a float value is truncated toward zero.
      // const: reads the active member only, so it works on const values too.
      // NOTE(review): a NaN or out-of-int-range float makes the cast undefined.
      int to_int() const
      {
          if (type == 0)
              return i;

          // trunc by default
          return (int)f;
      }
  };
  72. int eval_list_expression(const std::string& expr, const std::vector<Mat>& blobs, std::vector<int>& outlist)
  73. {
  74. // /(0w,2),*(0h,2),0c
  75. // split by , ( )
  76. //
  77. // /
  78. // 0w
  79. // 2
  80. // -------------------
  81. // *
  82. // 0h
  83. // 2
  84. // -------------------
  85. // 0c
  86. // -------------------
  87. // split by , ( )
  88. // split into tokens
  89. std::vector<std::string> tokens;
  90. {
  91. std::string t;
  92. for (size_t i = 0; i < expr.size(); i++)
  93. {
  94. char ch = expr[i];
  95. if (ch == '(' || ch == ')' || ch == ',')
  96. {
  97. if (!t.empty())
  98. {
  99. tokens.push_back(t);
  100. t.clear();
  101. }
  102. }
  103. else
  104. {
  105. #if NCNN_SIMPLESTL
  106. t.resize(t.size() + 1);
  107. t[t.size() - 1] = ch;
  108. #else
  109. t += ch;
  110. #endif
  111. }
  112. }
  113. if (!t.empty())
  114. {
  115. tokens.push_back(t);
  116. }
  117. }
  118. // / 0w 2 * 0h 2 0c
  119. // scan and stack
  120. std::stack<typed_value> exprstack;
  121. for (int i = (int)tokens.size() - 1; i >= 0; i--)
  122. {
  123. const std::string& t = tokens[i];
  124. // NCNN_LOGE("t = %s", t.c_str());
  125. // + - * / 0w 0h 0d 0c 12345
  126. if (t.size() == 2 && (t[0] >= '0' && t[0] <= '9') && (t[1] == 'w' || t[1] == 'h' || t[1] == 'd' || t[1] == 'c'))
  127. {
  128. size_t blob_index = t[0] - '0';
  129. if (blob_index >= blobs.size())
  130. {
  131. NCNN_LOGE("shape expression blob index %d out of bound!", (int)blob_index);
  132. return -1;
  133. }
  134. const Mat& blob = blobs[blob_index].shape();
  135. int size;
  136. if (t[1] == 'w')
  137. size = blob.w;
  138. else if (t[1] == 'h')
  139. size = blob.h;
  140. else if (t[1] == 'd')
  141. size = blob.d;
  142. else // if (t[1] == 'c')
  143. size = blob.c;
  144. // NCNN_LOGE("t = %s => %d", t.c_str(), size);
  145. exprstack.push(size);
  146. }
  147. else if (t == "+" || t == "-" || t == "*" || t == "//" || t == "max" || t == "min")
  148. {
  149. typed_value ta = exprstack.top();
  150. exprstack.pop();
  151. typed_value tb = exprstack.top();
  152. exprstack.pop();
  153. if (ta.type == 0 && tb.type == 0)
  154. {
  155. const int a = ta.i;
  156. const int b = tb.i;
  157. int r = 0;
  158. if (t == "+")
  159. {
  160. r = a + b;
  161. }
  162. else if (t == "-")
  163. {
  164. r = a - b;
  165. }
  166. else if (t == "*")
  167. {
  168. r = a * b;
  169. }
  170. else if (t == "//")
  171. {
  172. if (b == 0)
  173. {
  174. NCNN_LOGE("expr divide by zero");
  175. return -1;
  176. }
  177. else
  178. {
  179. r = a / b;
  180. }
  181. }
  182. else if (t == "max")
  183. {
  184. r = std::max(a, b);
  185. }
  186. else // if (t == "min")
  187. {
  188. r = std::min(a, b);
  189. }
  190. exprstack.push(r);
  191. }
  192. else
  193. {
  194. const float a = ta.type == 0 ? ta.i : ta.f;
  195. const float b = tb.type == 0 ? tb.i : tb.f;
  196. float r = 0.f;
  197. if (t == "+")
  198. {
  199. r = a + b;
  200. }
  201. else if (t == "-")
  202. {
  203. r = a - b;
  204. }
  205. else if (t == "*")
  206. {
  207. r = a * b;
  208. }
  209. else if (t == "//")
  210. {
  211. r = floorf(a / b);
  212. }
  213. else if (t == "max")
  214. {
  215. r = std::max(a, b);
  216. }
  217. else // if (t == "min")
  218. {
  219. r = std::min(a, b);
  220. }
  221. exprstack.push(r);
  222. }
  223. }
  224. else if (t == "abs" || t == "neg" || t == "sign" || t == "square")
  225. {
  226. typed_value ta = exprstack.top();
  227. exprstack.pop();
  228. if (ta.type == 0)
  229. {
  230. const int a = ta.i;
  231. int r = 0;
  232. if (t == "abs")
  233. {
  234. r = a > 0 ? a : -a;
  235. }
  236. else if (t == "neg")
  237. {
  238. r = -a;
  239. }
  240. else if (t == "sign")
  241. {
  242. r = a > 0 ? 1 : (a == 0 ? 0 : -1);
  243. }
  244. else // if (t == "square")
  245. {
  246. r = a * a;
  247. }
  248. exprstack.push(r);
  249. }
  250. else
  251. {
  252. const float a = ta.f;
  253. float r = 0;
  254. if (t == "abs")
  255. {
  256. r = fabsf(a);
  257. }
  258. else if (t == "neg")
  259. {
  260. r = -a;
  261. }
  262. else if (t == "sign")
  263. {
  264. r = a > 0.f ? 1 : (a == 0.f ? 0 : -1);
  265. }
  266. else // if (t == "square")
  267. {
  268. r = a * a;
  269. }
  270. exprstack.push(r);
  271. }
  272. }
  273. else if (t == "trunc" || t == "ceil" || t == "floor" || t == "round")
  274. {
  275. typed_value ta = exprstack.top();
  276. exprstack.pop();
  277. if (ta.type == 0)
  278. {
  279. const int a = ta.i;
  280. exprstack.push(a);
  281. }
  282. else
  283. {
  284. const float a = ta.f;
  285. int r = 0;
  286. if (t == "trunc")
  287. {
  288. r = (int)a;
  289. }
  290. else if (t == "ceil")
  291. {
  292. r = (int)ceil(a);
  293. }
  294. else if (t == "floor")
  295. {
  296. r = (int)floor(a);
  297. }
  298. else // if (t == "round")
  299. {
  300. r = (int)round(a);
  301. }
  302. exprstack.push(r);
  303. }
  304. }
  305. else if (t == "acos"
  306. || t == "acosh"
  307. || t == "asin"
  308. || t == "asinh"
  309. || t == "atan"
  310. || t == "atanh"
  311. || t == "cos"
  312. || t == "cosh"
  313. || t == "erf"
  314. || t == "exp"
  315. || t == "log"
  316. || t == "log10"
  317. || t == "reciprocal"
  318. || t == "rsqrt"
  319. || t == "sin"
  320. || t == "sinh"
  321. || t == "sqrt"
  322. || t == "tan"
  323. || t == "tanh")
  324. {
  325. typed_value ta = exprstack.top();
  326. exprstack.pop();
  327. const float a = ta.type == 0 ? ta.i : ta.f;
  328. float r = 0;
  329. if (t == "acos")
  330. {
  331. r = acosf(a);
  332. }
  333. else if (t == "acosh")
  334. {
  335. r = acoshf(a);
  336. }
  337. else if (t == "asin")
  338. {
  339. r = asinf(a);
  340. }
  341. else if (t == "asinh")
  342. {
  343. r = asinhf(a);
  344. }
  345. else if (t == "atan")
  346. {
  347. r = atanf(a);
  348. }
  349. else if (t == "atanh")
  350. {
  351. r = atanhf(a);
  352. }
  353. else if (t == "cos")
  354. {
  355. r = cosf(a);
  356. }
  357. else if (t == "cosh")
  358. {
  359. r = coshf(a);
  360. }
  361. else if (t == "erf")
  362. {
  363. r = erff(a);
  364. }
  365. else if (t == "exp")
  366. {
  367. r = expf(a);
  368. }
  369. else if (t == "log")
  370. {
  371. r = logf(a);
  372. }
  373. else if (t == "log10")
  374. {
  375. r = log10f(a);
  376. }
  377. else if (t == "reciprocal")
  378. {
  379. r = 1.f / a;
  380. }
  381. else if (t == "rsqrt")
  382. {
  383. r = 1.f / sqrtf(a);
  384. }
  385. else if (t == "sin")
  386. {
  387. r = sinf(a);
  388. }
  389. else if (t == "sinh")
  390. {
  391. r = sinhf(a);
  392. }
  393. else if (t == "sqrt")
  394. {
  395. r = sqrtf(a);
  396. }
  397. else if (t == "tan")
  398. {
  399. r = tanf(a);
  400. }
  401. else // if (t == "tanh")
  402. {
  403. r = tanhf(a);
  404. }
  405. exprstack.push(r);
  406. }
  407. else if (t == "/"
  408. || t == "atan2"
  409. || t == "fmod"
  410. || t == "pow"
  411. || t == "remainder"
  412. || t == "logaddexp")
  413. {
  414. typed_value ta = exprstack.top();
  415. exprstack.pop();
  416. typed_value tb = exprstack.top();
  417. exprstack.pop();
  418. const float a = ta.type == 0 ? ta.i : ta.f;
  419. const float b = tb.type == 0 ? tb.i : tb.f;
  420. float r = 0.f;
  421. if (t == "/")
  422. {
  423. r = a / b;
  424. }
  425. else if (t == "atan2")
  426. {
  427. r = atan2f(a, b);
  428. }
  429. else if (t == "fmod")
  430. {
  431. r = fmodf(a, b);
  432. }
  433. else if (t == "pow")
  434. {
  435. r = powf(a, b);
  436. }
  437. else if (t == "remainder")
  438. {
  439. r = fmodf(a, b);
  440. if (a * b < 0)
  441. r += b;
  442. }
  443. else // if (t == "logaddexp")
  444. {
  445. r = logf(expf(a) + expf(b));
  446. }
  447. exprstack.push(r);
  448. }
  449. else if (t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
  450. {
  451. typed_value ta = exprstack.top();
  452. exprstack.pop();
  453. typed_value tb = exprstack.top();
  454. exprstack.pop();
  455. // assert ta.type == 0 && tb.type == 0
  456. const int a = ta.i;
  457. const int b = tb.i;
  458. int r = 0;
  459. if (t == "and")
  460. {
  461. r = a & b;
  462. }
  463. else if (t == "or")
  464. {
  465. r = a | b;
  466. }
  467. else if (t == "xor")
  468. {
  469. r = a ^ b;
  470. }
  471. else if (t == "lshift")
  472. {
  473. r = a << b;
  474. }
  475. else // if (t == "rshift")
  476. {
  477. r = a >> b;
  478. }
  479. exprstack.push(r);
  480. }
  481. else
  482. {
  483. // literal
  484. int vi;
  485. float vf;
  486. int nscani = sscanf(t.c_str(), "%d", &vi);
  487. int nscanf = sscanf(t.c_str(), "%f", &vf);
  488. if (nscani == 1 && nscanf == 1 && vi == vf)
  489. {
  490. exprstack.push(vi);
  491. }
  492. else if (nscanf == 1)
  493. {
  494. exprstack.push(vf);
  495. }
  496. else
  497. {
  498. NCNN_LOGE("malformed literal token %s", t.c_str());
  499. return -1;
  500. }
  501. }
  502. }
  503. int size = exprstack.top().to_int();
  504. exprstack.pop();
  505. outlist.push_back(size);
  506. while (!exprstack.empty())
  507. {
  508. size = exprstack.top().to_int();
  509. exprstack.pop();
  510. outlist.push_back(size);
  511. }
  512. // NCNN_LOGE("shape %s = %d %d", expr.c_str(), list[0], list[1]);
  513. return 0;
  514. }
  515. } // namespace ncnn