You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fuse_expression.cpp 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "fuse_expression.h"
  15. #include <algorithm>
  16. namespace pnnx {
  17. static bool operand_maybe_tensor(const Operand* operand)
  18. {
  19. const Operator* op = operand->producer;
  20. if (op->type == "prim::Constant")
  21. {
  22. const Parameter& param = op->params.at("value");
  23. if (param.type == 0 || param.type == 1 || param.type == 2 || param.type == 3 || param.type == 4)
  24. {
  25. return false;
  26. }
  27. else
  28. {
  29. return true;
  30. }
  31. }
  32. if (op->type == "prim::NumToTensor")
  33. {
  34. return operand_maybe_tensor(op->inputs[0]);
  35. }
  36. if (op->type == "prim::ListConstruct")
  37. {
  38. return false;
  39. }
  40. if (op->type == "aten::size")
  41. {
  42. return false;
  43. }
  44. if (op->type == "aten::Int")
  45. {
  46. return operand_maybe_tensor(op->inputs[0]);
  47. }
  48. if (op->type == "aten::to" || op->type == "aten::detach")
  49. {
  50. return operand_maybe_tensor(op->inputs[0]);
  51. }
  52. if (op->type == "aten::ScalarImplicit")
  53. {
  54. return false;
  55. }
  56. if (op->type == "aten::abs"
  57. || op->type == "aten::acos"
  58. || op->type == "aten::acosh"
  59. || op->type == "aten::asin"
  60. || op->type == "aten::asinh"
  61. || op->type == "aten::atan"
  62. || op->type == "aten::atanh"
  63. || op->type == "aten::ceil"
  64. || op->type == "aten::cos"
  65. || op->type == "aten::cosh"
  66. || op->type == "aten::exp"
  67. || op->type == "aten::floor"
  68. || op->type == "aten::log"
  69. || op->type == "aten::log10"
  70. || op->type == "aten::neg"
  71. || op->type == "aten::reciprocal"
  72. || op->type == "aten::rsqrt"
  73. || op->type == "aten::sign"
  74. || op->type == "aten::sin"
  75. || op->type == "aten::sinh"
  76. || op->type == "aten::sqrt"
  77. || op->type == "aten::square"
  78. || op->type == "aten::tan"
  79. || op->type == "aten::tanh"
  80. || op->type == "aten::trunc")
  81. {
  82. return operand_maybe_tensor(op->inputs[0]);
  83. }
  84. if (op->type == "aten::atan2"
  85. || op->type == "aten::div"
  86. || op->type == "aten::floor_divide"
  87. || op->type == "aten::mul"
  88. || op->type == "aten::pow")
  89. {
  90. return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]);
  91. }
  92. if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__")
  93. {
  94. return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]);
  95. }
  96. if (op->type == "aten::add" || op->type == "aten::sub" || op->type == "aten::rsub")
  97. {
  98. return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]) || operand_maybe_tensor(op->inputs[2]);
  99. }
  100. return true;
  101. }
  102. static bool operand_is_foldable(const Operand* operand, const std::set<std::string>& foldable_constants)
  103. {
  104. if (foldable_constants.find(operand->name) != foldable_constants.end())
  105. return true;
  106. const Operator* op = operand->producer;
  107. if (op->type == "pnnx.Input")
  108. return false;
  109. for (auto x : op->inputs)
  110. {
  111. if (!operand_is_foldable(x, foldable_constants))
  112. return false;
  113. }
  114. return true;
  115. }
  116. static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector<Operand*>& inputs, const std::set<std::string>& foldable_constants, bool checksubgraph = true)
  117. {
  118. // fprintf(stderr, "fuse_expression %s\n", operand->name.c_str());
  119. Operator* op = operand->producer;
  120. if (checksubgraph && operand_maybe_tensor(operand))
  121. {
  122. if (op->outputs.size() > 1 || op->outputs[0]->consumers.size() > 1)
  123. {
  124. auto it = std::find(inputs.begin(), inputs.end(), operand);
  125. if (it == inputs.end())
  126. {
  127. // tensor
  128. char tmp[32];
  129. sprintf(tmp, "@%d", (int)inputs.size());
  130. expr += tmp;
  131. inputs.push_back(operand);
  132. }
  133. else
  134. {
  135. // tensor
  136. char tmp[32];
  137. sprintf(tmp, "@%d", (int)(it - inputs.begin()));
  138. expr += tmp;
  139. }
  140. return;
  141. }
  142. }
  143. if (op->type == "prim::Constant")
  144. {
  145. const Parameter& param = op->params["value"];
  146. // fprintf(stderr, "fuse_expression prim::Constant %d\n", param.type);
  147. if (param.type == 0)
  148. {
  149. expr += "None";
  150. }
  151. else if (param.type == 1)
  152. {
  153. expr += param.b ? "True" : "False";
  154. }
  155. else if (param.type == 2)
  156. {
  157. char tmp[32];
  158. sprintf(tmp, "%d", param.i);
  159. expr += tmp;
  160. }
  161. else if (param.type == 3)
  162. {
  163. char tmp[32];
  164. sprintf(tmp, "%e", param.f);
  165. expr += tmp;
  166. }
  167. else if (param.type == 4)
  168. {
  169. expr += param.s;
  170. }
  171. else
  172. {
  173. auto it = std::find(inputs.begin(), inputs.end(), operand);
  174. if (it == inputs.end())
  175. {
  176. // tensor
  177. char tmp[32];
  178. sprintf(tmp, "@%d", (int)inputs.size());
  179. expr += tmp;
  180. inputs.push_back(operand);
  181. }
  182. else
  183. {
  184. // tensor
  185. char tmp[32];
  186. sprintf(tmp, "@%d", (int)(it - inputs.begin()));
  187. expr += tmp;
  188. }
  189. }
  190. }
  191. else if (checksubgraph && operand_maybe_tensor(operand) && operand_is_foldable(operand, foldable_constants))
  192. {
  193. // fprintf(stderr, "operand_is_foldable %s\n", operand->name.c_str());
  194. auto it = std::find(inputs.begin(), inputs.end(), operand);
  195. if (it == inputs.end())
  196. {
  197. // tensor
  198. char tmp[32];
  199. sprintf(tmp, "@%d", (int)inputs.size());
  200. expr += tmp;
  201. inputs.push_back(operand);
  202. }
  203. else
  204. {
  205. // tensor
  206. char tmp[32];
  207. sprintf(tmp, "@%d", (int)(it - inputs.begin()));
  208. expr += tmp;
  209. }
  210. }
  211. else if (op->type == "prim::NumToTensor")
  212. {
  213. fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants);
  214. }
  215. else if (op->type == "prim::ListConstruct")
  216. {
  217. expr += "[";
  218. for (int i = 0; i < (int)op->inputs.size() - 1; i++)
  219. {
  220. fuse_expression(graph, op->inputs[i], expr, inputs, foldable_constants);
  221. expr += ",";
  222. }
  223. if (op->inputs.size() > 0)
  224. {
  225. fuse_expression(graph, op->inputs[op->inputs.size() - 1], expr, inputs, foldable_constants);
  226. }
  227. expr += "]";
  228. }
  229. else if (op->type == "aten::size")
  230. {
  231. expr += "size(";
  232. fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants);
  233. expr += ",";
  234. fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants);
  235. expr += ")";
  236. }
  237. else if (op->type == "aten::Int")
  238. {
  239. expr += "int(";
  240. fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants);
  241. expr += ")";
  242. }
  243. else if (op->type == "aten::to" || op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
  244. {
  245. fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants);
  246. }
  247. else if (op->type == "aten::abs"
  248. || op->type == "aten::acos"
  249. || op->type == "aten::acosh"
  250. || op->type == "aten::asin"
  251. || op->type == "aten::asinh"
  252. || op->type == "aten::atan"
  253. || op->type == "aten::atanh"
  254. || op->type == "aten::ceil"
  255. || op->type == "aten::cos"
  256. || op->type == "aten::cosh"
  257. || op->type == "aten::exp"
  258. || op->type == "aten::floor"
  259. || op->type == "aten::log"
  260. || op->type == "aten::log10"
  261. || op->type == "aten::neg"
  262. || op->type == "aten::reciprocal"
  263. || op->type == "aten::rsqrt"
  264. || op->type == "aten::sign"
  265. || op->type == "aten::sin"
  266. || op->type == "aten::sinh"
  267. || op->type == "aten::sqrt"
  268. || op->type == "aten::square"
  269. || op->type == "aten::tan"
  270. || op->type == "aten::tanh"
  271. || op->type == "aten::trunc")
  272. {
  273. std::string mathop = op->type.substr(6);
  274. expr += mathop;
  275. expr += "(";
  276. fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants);
  277. expr += ")";
  278. }
  279. else if (op->type == "aten::atan2"
  280. || op->type == "aten::div"
  281. || op->type == "aten::floor_divide"
  282. || op->type == "aten::mul"
  283. || op->type == "aten::pow"
  284. || op->type == "aten::remainder")
  285. {
  286. std::string mathop = op->type.substr(6);
  287. expr += mathop;
  288. expr += "(";
  289. fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants);
  290. expr += ",";
  291. fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants);
  292. expr += ")";
  293. }
  294. else if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__")
  295. {
  296. std::string mathop = op->type.substr(8, op->type.size() - 10);
  297. expr += mathop;
  298. expr += "(";
  299. fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants);
  300. expr += ",";
  301. fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants);
  302. expr += ")";
  303. }
  304. else if (op->type == "aten::add" || op->type == "aten::sub")
  305. {
  306. std::string mathop = op->type.substr(6);
  307. expr += mathop;
  308. expr += "(";
  309. fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants);
  310. expr += ",";
  311. std::string expr1;
  312. std::string expr2;
  313. fuse_expression(graph, op->inputs[1], expr1, inputs, foldable_constants);
  314. fuse_expression(graph, op->inputs[2], expr2, inputs, foldable_constants);
  315. if (expr2 == "1")
  316. {
  317. expr += expr1;
  318. }
  319. else
  320. {
  321. expr += ",";
  322. expr += "mul(";
  323. expr += expr1;
  324. expr += ",";
  325. expr += expr2;
  326. expr += ")";
  327. }
  328. expr += ")";
  329. }
  330. else if (op->type == "aten::rsub")
  331. {
  332. expr += "sub(";
  333. std::string expr1;
  334. std::string expr2;
  335. fuse_expression(graph, op->inputs[1], expr1, inputs, foldable_constants);
  336. fuse_expression(graph, op->inputs[2], expr2, inputs, foldable_constants);
  337. if (expr2 == "1")
  338. {
  339. expr += expr1;
  340. }
  341. else
  342. {
  343. expr += ",";
  344. expr += "mul(";
  345. expr += expr1;
  346. expr += ",";
  347. expr += expr2;
  348. expr += ")";
  349. }
  350. expr += ",";
  351. fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants);
  352. expr += ")";
  353. }
  354. else
  355. {
  356. auto it = std::find(inputs.begin(), inputs.end(), operand);
  357. if (it == inputs.end())
  358. {
  359. // tensor
  360. char tmp[32];
  361. sprintf(tmp, "@%d", (int)inputs.size());
  362. expr += tmp;
  363. inputs.push_back(operand);
  364. }
  365. else
  366. {
  367. // tensor
  368. char tmp[32];
  369. sprintf(tmp, "@%d", (int)(it - inputs.begin()));
  370. expr += tmp;
  371. }
  372. }
  373. }
  374. void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constants)
  375. {
  376. int pnnx_expr_index = 0;
  377. for (;;)
  378. {
  379. bool need_fuse = false;
  380. // build expression via reverse order
  381. for (int i = (int)graph.ops.size() - 1; i >= 0; i--)
  382. {
  383. Operator* op = graph.ops[i];
  384. if (op->type == "prim::Constant")
  385. {
  386. need_fuse = true;
  387. }
  388. if (op->type == "prim::NumToTensor")
  389. {
  390. need_fuse = true;
  391. }
  392. if (op->type == "prim::ListConstruct")
  393. {
  394. need_fuse = true;
  395. }
  396. if (op->type == "aten::size")
  397. {
  398. need_fuse = true;
  399. }
  400. if (op->type == "aten::Int")
  401. {
  402. need_fuse = true;
  403. }
  404. if (op->type == "aten::to" || op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
  405. {
  406. need_fuse = true;
  407. }
  408. if (op->type == "aten::abs"
  409. || op->type == "aten::acos"
  410. || op->type == "aten::acosh"
  411. || op->type == "aten::add"
  412. || op->type == "aten::asin"
  413. || op->type == "aten::asinh"
  414. || op->type == "aten::atan"
  415. || op->type == "aten::atanh"
  416. || op->type == "aten::atan2"
  417. || op->type == "aten::ceil"
  418. || op->type == "aten::cos"
  419. || op->type == "aten::cosh"
  420. || op->type == "aten::div"
  421. || op->type == "aten::exp"
  422. || op->type == "aten::floor"
  423. || op->type == "aten::floor_divide"
  424. || op->type == "aten::log"
  425. || op->type == "aten::log10"
  426. || op->type == "aten::mul"
  427. || op->type == "aten::neg"
  428. || op->type == "aten::pow"
  429. || op->type == "aten::reciprocal"
  430. || op->type == "aten::remainder"
  431. || op->type == "aten::rsqrt"
  432. || op->type == "aten::rsub"
  433. || op->type == "aten::sign"
  434. || op->type == "aten::sin"
  435. || op->type == "aten::sinh"
  436. || op->type == "aten::sqrt"
  437. || op->type == "aten::square"
  438. || op->type == "aten::sub"
  439. || op->type == "aten::tan"
  440. || op->type == "aten::tanh"
  441. || op->type == "aten::trunc")
  442. {
  443. need_fuse = true;
  444. }
  445. if (op->type == "aten::__and__" || op->type == "aten::__or__" || op->type == "aten::__xor__" || op->type == "aten::__lshift__" || op->type == "aten::__rshift__")
  446. {
  447. need_fuse = true;
  448. }
  449. if (need_fuse)
  450. {
  451. std::string expr;
  452. std::vector<Operand*> inputs;
  453. fuse_expression(graph, op->outputs[0], expr, inputs, foldable_constants, false);
  454. // fprintf(stderr, "expr = %s\n", expr.c_str());
  455. // lets rewrite graph
  456. char name[32];
  457. sprintf(name, "pnnx_expr_%d", pnnx_expr_index++);
  458. op->type = "pnnx.Expression";
  459. op->name = name;
  460. op->params.clear();
  461. op->attrs.clear();
  462. op->params["expr"] = expr;
  463. // fix input output
  464. for (Operand* operand : op->inputs)
  465. {
  466. operand->consumers.erase(std::find(operand->consumers.begin(), operand->consumers.end(), op));
  467. }
  468. op->inputs = inputs;
  469. for (Operand* operand : op->inputs)
  470. {
  471. operand->consumers.push_back(op);
  472. }
  473. break;
  474. }
  475. }
  476. if (!need_fuse)
  477. break;
  478. }
  479. }
  480. } // namespace pnnx