You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pass_level2.cpp 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "pass_level2.h"
  15. #include <algorithm>
  16. #include <map>
  17. #include <unordered_map>
  18. namespace pnnx {
  19. GraphRewriterPass::~GraphRewriterPass()
  20. {
  21. }
  22. const char* GraphRewriterPass::name_str() const
  23. {
  24. return type_str();
  25. }
  26. bool GraphRewriterPass::match(const std::map<std::string, Parameter>& /*captured_params*/) const
  27. {
  28. return true;
  29. }
  30. bool GraphRewriterPass::match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
  31. {
  32. return match(captured_params);
  33. }
  34. bool GraphRewriterPass::match(const std::map<std::string, const Operator*>& /*matched_operators*/) const
  35. {
  36. return true;
  37. }
  38. void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
  39. {
  40. for (auto x : captured_params)
  41. {
  42. op->params[x.first] = x.second;
  43. }
  44. }
  45. void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
  46. {
  47. write(op, captured_params);
  48. }
  49. static std::map<int, std::vector<const GraphRewriterPass*> > g_global_pnnx_graph_rewriter_passes;
  50. GraphRewriterPassRegister::GraphRewriterPassRegister(const GraphRewriterPass* _pass, int priority)
  51. : pass(_pass)
  52. {
  53. if (g_global_pnnx_graph_rewriter_passes.find(priority) == g_global_pnnx_graph_rewriter_passes.end())
  54. {
  55. g_global_pnnx_graph_rewriter_passes[priority] = std::vector<const GraphRewriterPass*>();
  56. }
  57. g_global_pnnx_graph_rewriter_passes[priority].push_back(pass);
  58. }
  59. GraphRewriterPassRegister::~GraphRewriterPassRegister()
  60. {
  61. delete pass;
  62. }
  63. static bool match_parameter(const Parameter& a, const Parameter& b, std::map<std::string, Parameter>& captured_params)
  64. {
  65. if (b.type == 4 && b.s[0] == '%')
  66. {
  67. // captured parameter
  68. captured_params[b.s.substr(1)] = a;
  69. return true;
  70. }
  71. if (b.type == 4 && b.s == "*")
  72. {
  73. // ignored parameter
  74. return true;
  75. }
  76. if (a.type != b.type)
  77. {
  78. if (a.type == 2 && b.type == 3)
  79. return a.i == b.f;
  80. if (a.type == 3 && b.type == 2)
  81. return a.f == b.i;
  82. return false;
  83. }
  84. const int type = a.type;
  85. if (type == 0)
  86. {
  87. return true;
  88. }
  89. if (type == 1)
  90. {
  91. return a.b == b.b;
  92. }
  93. if (type == 2)
  94. {
  95. return a.i == b.i;
  96. }
  97. if (type == 3)
  98. {
  99. return a.f == b.f;
  100. }
  101. if (type == 4)
  102. {
  103. return a.s == b.s;
  104. }
  105. if (type == 5)
  106. {
  107. if (a.ai.size() != b.ai.size())
  108. return false;
  109. for (size_t i = 0; i < a.ai.size(); i++)
  110. {
  111. if (a.ai[i] != b.ai[i])
  112. return false;
  113. }
  114. return true;
  115. }
  116. if (type == 6)
  117. {
  118. if (a.af.size() != b.af.size())
  119. return false;
  120. for (size_t i = 0; i < a.af.size(); i++)
  121. {
  122. if (a.af[i] != b.af[i])
  123. return false;
  124. }
  125. return true;
  126. }
  127. if (type == 7)
  128. {
  129. if (a.as.size() != b.as.size())
  130. return false;
  131. for (size_t i = 0; i < a.as.size(); i++)
  132. {
  133. if (a.as[i] != b.as[i])
  134. return false;
  135. }
  136. return true;
  137. }
  138. // unknown
  139. return false;
  140. }
  141. static bool match_operator(const Operator* a, const Operator* b, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
  142. {
  143. if (a->type != b->type)
  144. return false;
  145. if (a->inputs.size() != b->inputs.size())
  146. return false;
  147. if (a->outputs.size() != b->outputs.size())
  148. return false;
  149. // match params
  150. if (b->params.size() == 1 && b->params.find("%*") != b->params.end() && b->params.at("%*").type == 4 && b->params.at("%*").s == "%*")
  151. {
  152. for (const auto& p : a->params)
  153. {
  154. const std::string& pkey = p.first;
  155. const Parameter& pp = p.second;
  156. // capture all parameters
  157. captured_params[b->name + '.' + pkey] = pp;
  158. }
  159. }
  160. else
  161. {
  162. if (a->params.size() != b->params.size())
  163. return false;
  164. for (const auto& p : a->params)
  165. {
  166. const std::string& akey = p.first;
  167. const Parameter& ap = p.second;
  168. if (b->params.find(akey) == b->params.end())
  169. return false;
  170. if (!match_parameter(ap, b->params.at(akey), captured_params))
  171. return false;
  172. }
  173. }
  174. for (const auto& p : a->attrs)
  175. {
  176. const std::string& akey = p.first;
  177. const Attribute& aa = p.second;
  178. // capture all attributes
  179. captured_attrs[b->name + '.' + akey] = aa;
  180. }
  181. return true;
  182. }
  183. static bool match(const Operator* anchor, const Operator* pattern, std::map<std::string, const Operator*>& matched_operators, std::map<std::string, const Operand*>& matched_inputs, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
  184. {
  185. if (!match_operator(anchor, pattern, captured_params, captured_attrs))
  186. return false;
  187. for (size_t i = 0; i < pattern->outputs.size(); i++)
  188. {
  189. if (pattern->outputs[i]->consumers.size() == 1 && pattern->outputs[i]->consumers[0]->type == "pnnx.Output")
  190. continue;
  191. if (anchor->outputs[i]->consumers.size() != pattern->outputs[i]->consumers.size())
  192. return false;
  193. }
  194. matched_operators[pattern->name] = anchor;
  195. // lets match
  196. for (size_t i = 0; i < pattern->inputs.size(); i++)
  197. {
  198. const Operator* anchor2 = anchor->inputs[i]->producer;
  199. const Operator* pattern2 = pattern->inputs[i]->producer;
  200. if (pattern2->type == "pnnx.Input")
  201. {
  202. if (matched_inputs.find(pattern->inputs[i]->name) == matched_inputs.end())
  203. {
  204. matched_inputs[pattern->inputs[i]->name] = anchor->inputs[i];
  205. }
  206. else if (matched_inputs[pattern->inputs[i]->name] != anchor->inputs[i])
  207. {
  208. return false;
  209. }
  210. continue;
  211. }
  212. if (!match(anchor2, pattern2, matched_operators, matched_inputs, captured_params, captured_attrs))
  213. return false;
  214. }
  215. return true;
  216. }
  217. void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opindex)
  218. {
  219. Graph pattern_graph;
  220. pattern_graph.parse(pass->match_pattern_graph());
  221. // collect pattern inputs and outputs order
  222. std::vector<std::string> pattern_graph_inputs;
  223. std::vector<std::string> pattern_graph_outputs;
  224. std::vector<const Operator*> pattern_graph_output_operators;
  225. for (const auto& x : pattern_graph.ops)
  226. {
  227. if (x->type == "pnnx.Input")
  228. {
  229. for (const auto& y : x->outputs)
  230. pattern_graph_inputs.push_back(y->name);
  231. }
  232. if (x->type == "pnnx.Output")
  233. {
  234. pattern_graph_output_operators.push_back(x);
  235. for (const auto& y : x->inputs)
  236. pattern_graph_outputs.push_back(y->name);
  237. }
  238. }
  239. std::vector<Operator*> new_ops;
  240. while (1)
  241. {
  242. const int graph_op_count = (int)graph.ops.size();
  243. bool matched = true;
  244. // lets match from output
  245. std::map<std::string, const Operator*> matched_operators;
  246. std::map<std::string, const Operand*> matched_inputs;
  247. std::map<std::string, const Operand*> matched_outputs;
  248. std::map<std::string, Parameter> captured_params;
  249. std::map<std::string, Attribute> captured_attrs;
  250. // pattern match from end to beginning
  251. int q = graph_op_count - 1;
  252. for (; q >= 1; q--)
  253. {
  254. for (const Operator* pattern : pattern_graph_output_operators)
  255. {
  256. for (size_t i = 0; i < pattern->inputs.size(); i++)
  257. {
  258. const Operator* pattern2 = pattern->inputs[i]->producer;
  259. int j = q;
  260. for (; j >= 0; j--)
  261. {
  262. const Operator* anchor = graph.ops[j];
  263. std::map<std::string, const Operator*> matched_operators2;
  264. std::map<std::string, const Operand*> matched_inputs2;
  265. std::map<std::string, Parameter> captured_params2;
  266. std::map<std::string, Attribute> captured_attrs2;
  267. if (!match(anchor, pattern2, matched_operators2, matched_inputs2, captured_params2, captured_attrs2))
  268. continue;
  269. bool submatch_matched = true;
  270. for (auto x : matched_operators2)
  271. {
  272. // check these matched operators are same with previous matched ones
  273. if (matched_operators.find(x.first) != matched_operators.end())
  274. {
  275. if (matched_operators[x.first] != x.second)
  276. {
  277. // unmatched two sub-matches
  278. submatch_matched = false;
  279. break;
  280. }
  281. }
  282. else
  283. {
  284. matched_operators[x.first] = x.second;
  285. }
  286. }
  287. if (!submatch_matched)
  288. continue;
  289. for (auto x : matched_inputs2)
  290. {
  291. if (matched_inputs.find(x.first) == matched_inputs.end())
  292. {
  293. matched_inputs[x.first] = x.second;
  294. }
  295. }
  296. for (auto x : captured_params2)
  297. {
  298. captured_params[x.first] = x.second;
  299. }
  300. for (auto x : captured_attrs2)
  301. {
  302. captured_attrs[x.first] = x.second;
  303. }
  304. // match !
  305. matched_outputs[pattern->inputs[i]->name] = anchor->outputs[i];
  306. break;
  307. }
  308. if (j == -1)
  309. {
  310. matched = false;
  311. break;
  312. }
  313. }
  314. if (!matched)
  315. break;
  316. }
  317. if (matched && (!pass->match(captured_params, captured_attrs) || !pass->match(matched_operators)))
  318. {
  319. matched_operators.clear();
  320. matched_inputs.clear();
  321. matched_outputs.clear();
  322. captured_params.clear();
  323. captured_attrs.clear();
  324. continue;
  325. }
  326. break;
  327. }
  328. if (!matched)
  329. break;
  330. // fprintf(stderr, "matched !\n");
  331. // lets replace
  332. // remove all operands inside matched graph
  333. std::map<std::string, Operand*> operands_to_remove;
  334. for (auto& _x : matched_operators)
  335. {
  336. Operator* x = (Operator*)_x.second;
  337. for (auto& r : x->inputs)
  338. {
  339. r->remove_consumer(x);
  340. bool is_input = false;
  341. for (auto& r2 : matched_inputs)
  342. {
  343. if (r2.second == r)
  344. {
  345. is_input = true;
  346. break;
  347. }
  348. }
  349. if (!is_input)
  350. operands_to_remove[r->name] = r;
  351. }
  352. x->inputs.clear();
  353. for (auto& r : x->outputs)
  354. {
  355. r->producer = 0;
  356. bool is_output = false;
  357. for (auto& r2 : matched_outputs)
  358. {
  359. if (r2.second == r)
  360. {
  361. is_output = true;
  362. break;
  363. }
  364. }
  365. if (!is_output)
  366. operands_to_remove[r->name] = r;
  367. }
  368. x->outputs.clear();
  369. }
  370. for (auto& _x : operands_to_remove)
  371. {
  372. Operand* r = _x.second;
  373. graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), r));
  374. delete r;
  375. }
  376. // remove all matched_operators
  377. for (auto& _x : matched_operators)
  378. {
  379. // fprintf(stderr, "remove %s\n", _x.second->name.c_str());
  380. Operator* x = (Operator*)_x.second;
  381. graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), x));
  382. delete _x.second;
  383. }
  384. // insert new operator before all output consumers
  385. const Operator* cur = 0;
  386. {
  387. int cur_index = graph.ops.size() - 1;
  388. for (auto& o : matched_outputs)
  389. {
  390. for (auto& c : o.second->consumers)
  391. {
  392. int c_index = std::find(graph.ops.begin(), graph.ops.end(), c) - graph.ops.begin();
  393. cur_index = std::min(cur_index, c_index);
  394. }
  395. }
  396. cur = graph.ops[cur_index];
  397. }
  398. Operator* op = graph.new_operator_before(pass->type_str(), std::string(pass->name_str()), cur);
  399. for (const auto& k : pattern_graph_inputs)
  400. {
  401. Operand* r = (Operand*)matched_inputs.at(k);
  402. r->consumers.push_back(op);
  403. op->inputs.push_back(r);
  404. op->inputnames.push_back(k);
  405. }
  406. for (const auto& k : pattern_graph_outputs)
  407. {
  408. Operand* r = (Operand*)matched_outputs.at(k);
  409. r->producer = op;
  410. op->outputs.push_back(r);
  411. }
  412. pass->write(op, captured_params, captured_attrs);
  413. new_ops.push_back(op);
  414. }
  415. // assign new op name number
  416. for (int i = (int)new_ops.size() - 1; i >= 0; i--)
  417. {
  418. new_ops[i]->name = new_ops[i]->name + "_" + std::to_string(opindex++);
  419. }
  420. }
  421. static void fix_inplace_copy_output(Graph& graph)
  422. {
  423. while (1)
  424. {
  425. bool matched = false;
  426. for (size_t i = 0; i < graph.ops.size(); i++)
  427. {
  428. Operator* op = graph.ops[i];
  429. bool is_inplace_op = op->type.size() > 2 && op->type[op->type.size() - 2] != '_' && op->type[op->type.size() - 1] == '_';
  430. if (!is_inplace_op)
  431. continue;
  432. // replace inplace op with non-inplace version
  433. op->type = op->type.substr(0, op->type.size() - 1);
  434. if (op->type == "aten::copy")
  435. continue;
  436. if (op->outputs[0]->consumers.size() != 0)
  437. continue;
  438. matched = true;
  439. // find in0 from slice / select chain
  440. Operand* in0 = op->inputs[0];
  441. while (in0->producer->type == "aten::slice" || in0->producer->type == "aten::select")
  442. {
  443. in0 = in0->producer->inputs[0];
  444. }
  445. // append copy for inplace op
  446. Operator* op_copy = graph.new_operator_after("aten::copy", op->name + "_copy", op);
  447. Operand* copy_out = graph.new_operand(op->name + "_copy_out");
  448. copy_out->shape = in0->shape;
  449. op_copy->inputs.push_back(op->inputs[0]);
  450. op_copy->inputs.push_back(op->outputs[0]);
  451. op->inputs[0]->consumers.push_back(op_copy);
  452. op->outputs[0]->consumers.push_back(op_copy);
  453. op_copy->outputs.push_back(copy_out);
  454. copy_out->producer = op_copy;
  455. break;
  456. }
  457. if (!matched)
  458. break;
  459. }
  460. for (size_t i = 0; i < graph.ops.size(); i++)
  461. {
  462. Operator* op = graph.ops[i];
  463. if (op->type != "aten::copy")
  464. continue;
  465. if (op->outputs[0]->consumers.size() != 0)
  466. continue;
  467. // aten::slice 5 1 in0 .... a
  468. // aten::slice 5 1 a .... b
  469. // aten::copy 2 1 b in1 out
  470. // aten::select 3 1 in0 .... a
  471. // aten::copy 2 1 a in1 out
  472. // find in0 from slice / select chain
  473. Operand* in0 = op->inputs[0];
  474. while (in0->producer->type == "aten::slice" || in0->producer->type == "aten::select")
  475. {
  476. in0 = in0->producer->inputs[0];
  477. }
  478. // replace all the following uses of in0 with out
  479. Operand* out0 = op->outputs[0];
  480. out0->shape = in0->shape;
  481. for (size_t j = i; j < graph.ops.size(); j++)
  482. {
  483. Operator* op2 = graph.ops[j];
  484. bool use_in0 = false;
  485. for (size_t k = 0; k < op2->inputs.size(); k++)
  486. {
  487. if (op2->inputs[k] == in0)
  488. {
  489. op2->inputs[k] = out0;
  490. use_in0 = true;
  491. }
  492. }
  493. if (use_in0)
  494. {
  495. in0->remove_consumer(op2);
  496. out0->consumers.push_back(op2);
  497. }
  498. }
  499. }
  500. }
  501. void pass_level2(Graph& g)
  502. {
  503. fix_inplace_copy_output(g);
  504. int opindex = 0;
  505. for (auto x : g_global_pnnx_graph_rewriter_passes)
  506. {
  507. for (auto rewriter : x.second)
  508. {
  509. pnnx_graph_rewrite(g, rewriter, opindex);
  510. }
  511. }
  512. }
  513. } // namespace pnnx