You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mlir2ncnn.cpp 53 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <stdio.h>
  15. #include <map>
  16. #include <set>
  17. #include <mlir/Dialect/StandardOps/IR/Ops.h>
  18. #include <mlir/IR/Module.h>
  19. #include <mlir/IR/PatternMatch.h>
  20. #include <mlir/Parser.h>
  21. #include <mlir/Pass/PassManager.h>
  22. #include <mlir/Transforms/Passes.h>
  23. #include "tf_dialect.h"
  24. #include "ncnn_dialect.h"
  25. static std::string get_mlir_value_uniq_id(const mlir::Value& value)
  26. {
  27. if (value.getLoc().isa<mlir::FileLineColLoc>())
  28. {
  29. mlir::FileLineColLoc floc = value.getLoc().cast<mlir::FileLineColLoc>();
  30. return floc.getFilename().str() + ":" + std::to_string(floc.getLine()) + ":" + std::to_string(floc.getColumn());
  31. }
  32. if (value.getLoc().isa<mlir::FusedLoc>())
  33. {
  34. mlir::FileLineColLoc floc = value.getLoc().cast<mlir::FusedLoc>().getLocations().front().cast<mlir::FileLineColLoc>();
  35. return floc.getFilename().str() + ":" + std::to_string(floc.getLine()) + ":" + std::to_string(floc.getColumn());
  36. }
  37. fprintf(stderr, "unhandled get_mlir_value_uniq_id\n");
  38. return std::string();
  39. }
  40. static std::string get_attr_s(const mlir::Attribute& attr)
  41. {
  42. std::string s;
  43. if (attr.isa<mlir::StringAttr>())
  44. {
  45. mlir::StringAttr a = attr.cast<mlir::StringAttr>();
  46. s = a.getValue().str();
  47. }
  48. return s;
  49. }
  50. static int get_attr_b(const mlir::Attribute& attr)
  51. {
  52. int i;
  53. if (attr.isa<mlir::BoolAttr>())
  54. {
  55. mlir::BoolAttr a = attr.cast<mlir::BoolAttr>();
  56. i = a.getValue() ? 1 : 0;
  57. }
  58. else
  59. {
  60. fprintf(stderr, "not BoolAttr\n");
  61. }
  62. return i;
  63. }
  64. static int get_attr_i(const mlir::Attribute& attr)
  65. {
  66. int i;
  67. if (attr.isa<mlir::IntegerAttr>())
  68. {
  69. mlir::IntegerAttr a = attr.cast<mlir::IntegerAttr>();
  70. i = (int)a.getInt();
  71. }
  72. else
  73. {
  74. fprintf(stderr, "not IntegerAttr\n");
  75. }
  76. return i;
  77. }
  78. static float get_attr_f(const mlir::Attribute& attr)
  79. {
  80. float f;
  81. if (attr.isa<mlir::FloatAttr>())
  82. {
  83. mlir::FloatAttr a = attr.cast<mlir::FloatAttr>();
  84. f = (float)a.getValueAsDouble();
  85. }
  86. else
  87. {
  88. fprintf(stderr, "not FloatAttr\n");
  89. }
  90. return f;
  91. }
  92. static std::vector<int> get_attr_ai(const mlir::Attribute& attr)
  93. {
  94. std::vector<int> v;
  95. if (attr.isa<mlir::ArrayAttr>())
  96. {
  97. mlir::ArrayAttr a = attr.cast<mlir::ArrayAttr>();
  98. const int array_size = a.getValue().size();
  99. v.resize(array_size);
  100. for (int j = 0; j < array_size; j++)
  101. {
  102. if (a[j].isa<mlir::IntegerAttr>())
  103. {
  104. int64_t ii = a[j].cast<mlir::IntegerAttr>().getInt();
  105. v[j] = std::max(std::min(ii, (int64_t)INT_MAX), (int64_t)INT_MIN);
  106. }
  107. }
  108. }
  109. else if (attr.isa<mlir::DenseIntElementsAttr>())
  110. {
  111. mlir::DenseIntElementsAttr ai = attr.cast<mlir::DenseIntElementsAttr>();
  112. for (auto ii : ai.getIntValues())
  113. {
  114. v.push_back(ii.getSExtValue());
  115. }
  116. }
  117. else
  118. {
  119. fprintf(stderr, "not ArrayAttr or DenseIntElementsAttr\n");
  120. }
  121. return v;
  122. }
  123. static std::vector<float> get_attr_af(const mlir::Attribute& attr)
  124. {
  125. std::vector<float> v;
  126. if (attr.isa<mlir::ArrayAttr>())
  127. {
  128. mlir::ArrayAttr a = attr.cast<mlir::ArrayAttr>();
  129. const int array_size = a.getValue().size();
  130. v.resize(array_size);
  131. for (int j = 0; j < array_size; j++)
  132. {
  133. if (a[j].isa<mlir::FloatAttr>())
  134. {
  135. double ff = a[j].cast<mlir::FloatAttr>().getValueAsDouble();
  136. v[j] = ff;
  137. }
  138. }
  139. }
  140. else if (attr.isa<mlir::DenseFPElementsAttr>())
  141. {
  142. mlir::DenseFPElementsAttr af = attr.cast<mlir::DenseFPElementsAttr>();
  143. for (auto ff : af.getFloatValues())
  144. {
  145. v.push_back(ff.convertToFloat());
  146. }
  147. }
  148. else
  149. {
  150. fprintf(stderr, "not ArrayAttr or DenseFPElementsAttr\n");
  151. }
  152. return v;
  153. }
  154. static std::string get_operation_attr_s(const mlir::Operation& _operation, const char* key)
  155. {
  156. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  157. mlir::Attribute attr = operation.getAttr(key);
  158. return get_attr_s(attr);
  159. }
  160. static int get_operation_attr_b(const mlir::Operation& _operation, const char* key)
  161. {
  162. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  163. mlir::Attribute attr = operation.getAttr(key);
  164. return get_attr_b(attr);
  165. }
  166. static int get_operation_attr_i(const mlir::Operation& _operation, const char* key)
  167. {
  168. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  169. mlir::Attribute attr = operation.getAttr(key);
  170. return get_attr_i(attr);
  171. }
  172. static float get_operation_attr_f(const mlir::Operation& _operation, const char* key)
  173. {
  174. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  175. mlir::Attribute attr = operation.getAttr(key);
  176. return get_attr_f(attr);
  177. }
  178. static std::vector<int> get_operation_attr_ai(const mlir::Operation& _operation, const char* key)
  179. {
  180. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  181. mlir::Attribute attr = operation.getAttr(key);
  182. return get_attr_ai(attr);
  183. }
  184. static std::vector<float> get_operation_attr_af(const mlir::Operation& _operation, const char* key)
  185. {
  186. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  187. mlir::Attribute attr = operation.getAttr(key);
  188. return get_attr_af(attr);
  189. }
  190. int main(int argc, char** argv)
  191. {
  192. const char* mlirpath = argv[1];
  193. const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.param";
  194. const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
  195. mlir::registerDialect<mlir::StandardOpsDialect>();
  196. mlir::registerDialect<mlir::TF::TensorFlowDialect>();
  197. mlir::registerDialect<mlir::ncnn::NCNNDialect>();
  198. mlir::MLIRContext context;
  199. mlir::OwningModuleRef m = mlir::parseSourceFile(mlirpath, &context);
  200. mlir::PassManager pm(&context);
  201. // Apply any generic pass manager command line options and run the pipeline.
  202. applyPassManagerCLOptions(pm);
  203. // Add a run of the canonicalizer to optimize the mlir module.
  204. pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  205. pm.run(*m);
  206. // m->dump();
  207. mlir::FuncOp main_fn = m->lookupSymbol<mlir::FuncOp>("main");
  208. auto& bb = main_fn.getBlocks().front();
  209. // bb.dump();
  210. FILE* pp = fopen(ncnn_prototxt, "wb");
  211. FILE* bp = fopen(ncnn_modelbin, "wb");
  212. // node reference
  213. std::map<std::string, int> node_reference;
  214. // weight node and weight reshape node
  215. std::map<std::string, mlir::Attribute> weights;
  216. // weight node before BinaryOp
  217. std::map<std::string, mlir::Attribute> binaryop_weights;
  218. fprintf(pp, "7767517\n");
  219. const mlir::Block::OpListType& operations = bb.getOperations();
  220. int node_count = operations.size();
  221. // global definition line
  222. // [layer count] [blob count]
  223. std::set<std::string> blob_names;
  224. for (const mlir::Operation& _operation : operations)
  225. {
  226. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  227. std::string op = operation.getName().getStringRef().str();
  228. int num_input = (int)operation.getNumOperands();
  229. int num_output = (int)operation.getNumResults();
  230. if (op == "tf.Const")
  231. {
  232. // weight
  233. std::string output_name = get_mlir_value_uniq_id(operation.getResult(0));
  234. weights[output_name] = operation.getAttr("value");
  235. continue;
  236. }
  237. else
  238. {
  239. bool isBinaryOp = false;
  240. // TODO add more binaryop
  241. if (op == "tf.BiasAdd" || op == "tf.AddV2" || op == "tf.Sub" || op == "tf.Maximum" || op == "tf.Minimum" || op == "tf.Mul")
  242. {
  243. isBinaryOp = true;
  244. }
  245. if (isBinaryOp)
  246. {
  247. // check weights
  248. for (int j = 0; j < num_input; j++)
  249. {
  250. std::string input_name = get_mlir_value_uniq_id(operation.getOperand(j));
  251. std::map<std::string, mlir::Attribute>::iterator it = weights.find(input_name);
  252. if (it != weights.end())
  253. {
  254. // binary op with weight, insert MemoryData layer and const blob
  255. binaryop_weights[input_name] = it->second;
  256. weights.erase(it);
  257. }
  258. }
  259. }
  260. }
  261. for (int j = 0; j < num_input; j++)
  262. {
  263. std::string input_name = get_mlir_value_uniq_id(operation.getOperand(j));
  264. // check weight
  265. if (weights.find(input_name) != weights.end())
  266. {
  267. continue;
  268. }
  269. blob_names.insert(input_name);
  270. if (node_reference.find(input_name) == node_reference.end())
  271. {
  272. node_reference[input_name] = 1;
  273. }
  274. else
  275. {
  276. node_reference[input_name] = node_reference[input_name] + 1;
  277. }
  278. }
  279. for (int j = 0; j < num_output; j++)
  280. {
  281. std::string output_name = get_mlir_value_uniq_id(operation.getResult(j));
  282. blob_names.insert(output_name);
  283. }
  284. }
  285. // remove node_reference entry with reference equals to one
  286. int splitncnn_blob_count = 0;
  287. std::map<std::string, int>::iterator it = node_reference.begin();
  288. while (it != node_reference.end())
  289. {
  290. if (it->second == 1)
  291. {
  292. node_reference.erase(it++);
  293. }
  294. else
  295. {
  296. splitncnn_blob_count += it->second;
  297. // fprintf(stderr, "%s %d\n", it->first.c_str(), it->second);
  298. ++it;
  299. }
  300. }
  301. fprintf(pp, "%lu %lu\n", node_count + node_reference.size() - weights.size(), blob_names.size() + splitncnn_blob_count);
  302. int internal_split = 0;
  303. // model op
  304. int g_opid = 0;
  305. for (const mlir::Operation& _operation : operations)
  306. {
  307. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  308. std::string op = operation.getName().getStringRef().str();
  309. int opid = g_opid++;
  310. int num_input = (int)operation.getNumOperands();
  311. int num_output = (int)operation.getNumResults();
  312. for (int i = 0; i < (int)operation.getNumOperands(); i++)
  313. {
  314. std::string input_name = get_mlir_value_uniq_id(operation.getOperand(i));
  315. // check weight
  316. if (weights.find(input_name) != weights.end())
  317. {
  318. num_input--;
  319. }
  320. }
  321. if (op == "std.return")
  322. {
  323. fprintf(pp, "%-16s", "Noop");
  324. }
  325. else if (op == "ncnn.BinaryOp")
  326. {
  327. fprintf(pp, "%-16s", "BinaryOp");
  328. }
  329. else if (op == "ncnn.KerasConv2D")
  330. {
  331. fprintf(pp, "%-16s", "Convolution");
  332. }
  333. else if (op == "ncnn.KerasDense")
  334. {
  335. fprintf(pp, "%-16s", "InnerProduct");
  336. }
  337. else if (op == "ncnn.InstanceNorm")
  338. {
  339. fprintf(pp, "%-16s", "InstanceNorm");
  340. }
  341. else if (op == "ncnn.InstanceNormAffine")
  342. {
  343. fprintf(pp, "%-16s", "InstanceNorm");
  344. }
  345. else if (op == "tf.AddN")
  346. {
  347. fprintf(pp, "%-16s", "Eltwise");
  348. }
  349. else if (op == "tf.AddV2")
  350. {
  351. fprintf(pp, "%-16s", "BinaryOp");
  352. }
  353. else if (op == "tf.AvgPool")
  354. {
  355. fprintf(pp, "%-16s", "Pooling");
  356. }
  357. else if (op == "tf.BiasAdd")
  358. {
  359. fprintf(pp, "%-16s", "BinaryOp");
  360. }
  361. else if (op == "tf.ConcatV2")
  362. {
  363. fprintf(pp, "%-16s", "Concat");
  364. }
  365. else if (op == "tf.Const")
  366. {
  367. // check weight before BinaryOp
  368. std::string output_name = get_mlir_value_uniq_id(operation.getResult(0));
  369. if (binaryop_weights.find(output_name) != binaryop_weights.end())
  370. {
  371. fprintf(pp, "%-16s", "MemoryData");
  372. }
  373. else
  374. {
  375. continue;
  376. }
  377. }
  378. else if (op == "tf.Conv2D")
  379. {
  380. fprintf(pp, "%-16s", "Convolution");
  381. }
  382. else if (op == "tf.Conv2DBackpropInput")
  383. {
  384. fprintf(pp, "%-16s", "Deconvolution");
  385. }
  386. else if (op == "tf.DepthwiseConv2dNative")
  387. {
  388. fprintf(pp, "%-16s", "ConvolutionDepthWise");
  389. }
  390. else if (op == "tf.Identity")
  391. {
  392. fprintf(pp, "%-16s", "Noop");
  393. }
  394. else if (op == "tf.LeakyRelu")
  395. {
  396. fprintf(pp, "%-16s", "ReLU");
  397. }
  398. else if (op == "tf.MatMul")
  399. {
  400. int transpose_a = get_operation_attr_b(operation, "transpose_a");
  401. int transpose_b = get_operation_attr_b(operation, "transpose_b");
  402. if (transpose_a == 0 && transpose_b == 1)
  403. {
  404. // InnerProduct-like A * B + C
  405. fprintf(pp, "%-16s", "InnerProduct");
  406. }
  407. else
  408. {
  409. fprintf(pp, "%-16s", "Gemm");
  410. }
  411. }
  412. else if (op == "tf.Maximum")
  413. {
  414. fprintf(pp, "%-16s", "BinaryOp");
  415. }
  416. else if (op == "tf.MaxPool")
  417. {
  418. fprintf(pp, "%-16s", "Pooling");
  419. }
  420. else if (op == "tf.Mean")
  421. {
  422. std::string reduction_indices_name = get_mlir_value_uniq_id(operation.getOperand(1));
  423. const mlir::Attribute& R = weights[reduction_indices_name];
  424. std::vector<int> v = get_attr_ai(R);
  425. int keep_dims = get_operation_attr_b(operation, "keep_dims");
  426. if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2)
  427. {
  428. // global avg pooling style nhwc -> nc
  429. fprintf(pp, "%-16s", "Pooling");
  430. }
  431. else
  432. {
  433. fprintf(stderr, "tf.Mean is not global avg pooling\n");
  434. fprintf(pp, "%-16s", "Reduction");
  435. }
  436. }
  437. else if (op == "tf.Minimum")
  438. {
  439. fprintf(pp, "%-16s", "BinaryOp");
  440. }
  441. else if (op == "tf.Mul")
  442. {
  443. fprintf(pp, "%-16s", "BinaryOp");
  444. }
  445. else if (op == "tf.Pad")
  446. {
  447. fprintf(pp, "%-16s", "Padding");
  448. }
  449. else if (op == "tf.Placeholder")
  450. {
  451. fprintf(pp, "%-16s", "Input");
  452. }
  453. else if (op == "tf.Relu")
  454. {
  455. fprintf(pp, "%-16s", "ReLU");
  456. }
  457. else if (op == "tf.Relu6")
  458. {
  459. fprintf(pp, "%-16s", "Clip");
  460. }
  461. else if (op == "tf.Reshape")
  462. {
  463. fprintf(pp, "%-16s", "Reshape");
  464. }
  465. else if (op == "tf.ResizeBilinear")
  466. {
  467. fprintf(pp, "%-16s", "Interp");
  468. }
  469. else if (op == "tf.ResizeNearestNeighbor")
  470. {
  471. fprintf(pp, "%-16s", "Interp");
  472. }
  473. else if (op == "tf.Sigmoid")
  474. {
  475. fprintf(pp, "%-16s", "Sigmoid");
  476. }
  477. else if (op == "tf.Softmax")
  478. {
  479. fprintf(pp, "%-16s", "Softmax");
  480. }
  481. else if (op == "tf.StridedSlice")
  482. {
  483. fprintf(pp, "%-16s", "Crop");
  484. }
  485. else if (op == "tf.Sub")
  486. {
  487. fprintf(pp, "%-16s", "BinaryOp");
  488. }
  489. else if (op == "tf.Tanh")
  490. {
  491. fprintf(pp, "%-16s", "TanH");
  492. }
  493. else
  494. {
  495. // TODO
  496. fprintf(stderr, "%s not supported yet!\n", op.c_str());
  497. fprintf(pp, "%-16s", op.c_str());
  498. }
  499. fprintf(pp, " op_%d %d %d", opid, num_input, num_output);
  500. for (int i = 0; i < (int)operation.getNumOperands(); i++)
  501. {
  502. std::string input_name = get_mlir_value_uniq_id(operation.getOperand(i));
  503. // check weight
  504. if (weights.find(input_name) != weights.end())
  505. {
  506. continue;
  507. }
  508. if (node_reference.find(input_name) != node_reference.end())
  509. {
  510. int refidx = node_reference[input_name] - 1;
  511. node_reference[input_name] = refidx;
  512. char splitsuffix[256];
  513. sprintf(splitsuffix, "_splitncnn_%d", refidx);
  514. input_name = input_name + splitsuffix;
  515. }
  516. fprintf(pp, " %s", input_name.c_str());
  517. }
  518. for (int i = 0; i < num_output; i++)
  519. {
  520. std::string output_name = get_mlir_value_uniq_id(operation.getResult(i));
  521. fprintf(pp, " %s", output_name.c_str());
  522. }
  523. if (op == "std.return")
  524. {
  525. }
  526. else if (op == "ncnn.BinaryOp")
  527. {
  528. int op_type = get_operation_attr_i(operation, "op_type");
  529. int with_scalar = get_operation_attr_i(operation, "with_scalar");
  530. float b = get_operation_attr_f(operation, "b");
  531. fprintf(pp, " 0=%d", op_type);
  532. fprintf(pp, " 1=%d", with_scalar);
  533. fprintf(pp, " 2=%e", b);
  534. }
  535. else if (op == "ncnn.KerasConv2D")
  536. {
  537. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  538. std::string bias_name = get_mlir_value_uniq_id(operation.getOperand(2));
  539. const mlir::Attribute& W = weights[weight_name];
  540. const mlir::Attribute& B = weights[bias_name];
  541. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  542. // assert(shape.size() == 4)
  543. // kh-kw-inch-outch
  544. int kernel_size_h = shape[0];
  545. int kernel_size_w = shape[1];
  546. int num_input = shape[2];
  547. int num_output = shape[3];
  548. int weight_data_size = kernel_size_h * kernel_size_w * num_input * num_output;
  549. fprintf(pp, " 0=%d", num_output);
  550. fprintf(pp, " 1=%d", kernel_size_w);
  551. fprintf(pp, " 11=%d", kernel_size_h);
  552. fprintf(pp, " 6=%d", weight_data_size);
  553. std::vector<int> dilations = get_operation_attr_ai(operation, "dilations");
  554. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  555. std::string padding = get_operation_attr_s(operation, "padding");
  556. if (dilations.size() == 4)
  557. {
  558. fprintf(pp, " 2=%d", dilations[2]);
  559. fprintf(pp, " 12=%d", dilations[1]);
  560. }
  561. if (strides.size() == 4)
  562. {
  563. fprintf(pp, " 3=%d", strides[2]);
  564. fprintf(pp, " 13=%d", strides[1]);
  565. }
  566. if (padding == "EXPLICIT")
  567. {
  568. // nhwc = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
  569. std::vector<int> explicit_paddings = get_operation_attr_ai(operation, "explicit_paddings");
  570. fprintf(pp, " 4=%d", explicit_paddings[4]);
  571. fprintf(pp, " 15=%d", explicit_paddings[5]);
  572. fprintf(pp, " 14=%d", explicit_paddings[2]);
  573. fprintf(pp, " 16=%d", explicit_paddings[3]);
  574. }
  575. else if (padding == "VALID")
  576. {
  577. fprintf(pp, " 4=%d", 0);
  578. }
  579. else if (padding == "SAME")
  580. {
  581. fprintf(pp, " 4=%d", -233);
  582. }
  583. fprintf(pp, " 5=1"); // bias_term
  584. std::vector<float> v = get_attr_af(W);
  585. std::vector<float> bv = get_attr_af(B);
  586. // reorder h-w-i-o to o-i-h-w
  587. {
  588. int quantize_tag = 0;
  589. fwrite(&quantize_tag, sizeof(int), 1, bp);
  590. float tmp;
  591. for (int p = 0; p < num_output; p++)
  592. {
  593. for (int q = 0; q < num_input; q++)
  594. {
  595. for (int i = 0; i < kernel_size_h; i++)
  596. {
  597. for (int j = 0; j < kernel_size_w; j++)
  598. {
  599. tmp = v[i * kernel_size_w * num_input * num_output + j * num_input * num_output + q * num_output + p];
  600. fwrite(&tmp, sizeof(float), 1, bp);
  601. }
  602. }
  603. }
  604. }
  605. }
  606. fwrite(bv.data(), sizeof(float), bv.size(), bp);
  607. }
  608. else if (op == "ncnn.KerasDense")
  609. {
  610. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  611. std::string bias_name = get_mlir_value_uniq_id(operation.getOperand(2));
  612. const mlir::Attribute& W = weights[weight_name];
  613. const mlir::Attribute& B = weights[bias_name];
  614. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  615. // assert(shape.size() == 2)
  616. // inch-outch
  617. int num_input = shape[0];
  618. int num_output = shape[1];
  619. int weight_data_size = shape[0] * shape[1];
  620. fprintf(pp, " 0=%d", num_output);
  621. fprintf(pp, " 1=1"); // bias_term
  622. fprintf(pp, " 2=%d", weight_data_size);
  623. std::vector<float> v = get_attr_af(W);
  624. std::vector<float> bv = get_attr_af(B);
  625. // reorder i-o to o-i
  626. {
  627. int quantize_tag = 0;
  628. fwrite(&quantize_tag, sizeof(int), 1, bp);
  629. float tmp;
  630. for (int p = 0; p < num_output; p++)
  631. {
  632. for (int q = 0; q < num_input; q++)
  633. {
  634. tmp = v[q * num_output + p];
  635. fwrite(&tmp, sizeof(float), 1, bp);
  636. }
  637. }
  638. }
  639. fwrite(bv.data(), sizeof(float), bv.size(), bp);
  640. }
  641. else if (op == "ncnn.InstanceNorm")
  642. {
  643. float eps = get_operation_attr_f(operation, "epsilon");
  644. fprintf(pp, " 0=0"); // channels
  645. fprintf(pp, " 1=%e", eps);
  646. fprintf(pp, " 2=0"); // affine
  647. }
  648. else if (op == "ncnn.InstanceNormAffine")
  649. {
  650. float eps = get_operation_attr_f(operation, "epsilon");
  651. std::string gamma_name = get_mlir_value_uniq_id(operation.getOperand(1));
  652. std::string beta_name = get_mlir_value_uniq_id(operation.getOperand(2));
  653. const mlir::Attribute& G = weights[gamma_name];
  654. const mlir::Attribute& B = weights[beta_name];
  655. std::vector<float> gv = get_attr_af(G);
  656. std::vector<float> bv = get_attr_af(B);
  657. int channels = gv.size();
  658. fprintf(pp, " 0=%d", channels);
  659. fprintf(pp, " 1=%e", eps);
  660. fprintf(pp, " 2=1"); // affine
  661. fwrite(gv.data(), sizeof(float), gv.size(), bp);
  662. fwrite(bv.data(), sizeof(float), bv.size(), bp);
  663. }
  664. else if (op == "tf.AddN")
  665. {
  666. int op_type = 1;
  667. fprintf(pp, " 0=%d", op_type);
  668. }
  669. else if (op == "tf.AddV2")
  670. {
  671. int op_type = 0;
  672. fprintf(pp, " 0=%d", op_type);
  673. }
  674. else if (op == "tf.AvgPool")
  675. {
  676. std::vector<int> ksize = get_operation_attr_ai(operation, "ksize");
  677. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  678. std::string padding = get_operation_attr_s(operation, "padding");
  679. fprintf(pp, " 0=1"); // avg pool
  680. if (ksize.size() == 4)
  681. {
  682. fprintf(pp, " 1=%d", ksize[2]);
  683. fprintf(pp, " 11=%d", ksize[1]);
  684. }
  685. if (strides.size() == 4)
  686. {
  687. fprintf(pp, " 2=%d", strides[2]);
  688. fprintf(pp, " 12=%d", strides[1]);
  689. }
  690. int pad_mode = 1;
  691. if (padding == "VALID")
  692. {
  693. pad_mode = 1;
  694. }
  695. else if (padding == "SAME")
  696. {
  697. pad_mode = 2;
  698. }
  699. fprintf(pp, " 5=%d", pad_mode);
  700. }
  701. else if (op == "tf.ConcatV2")
  702. {
  703. std::string axis_name = get_mlir_value_uniq_id(operation.getOperand(operation.getNumOperands() - 1));
  704. const mlir::Attribute& A = weights[axis_name];
  705. int axis = get_attr_ai(A)[0];
  706. // axis nhc to nhw
  707. // axis nhwc to nchw
  708. int dims = operation.getOperand(0).getType().cast<mlir::RankedTensorType>().getShape().size();
  709. if (dims == 2 && axis == 1)
  710. {
  711. axis = 0;
  712. }
  713. if (dims == 3 && axis == 1)
  714. {
  715. axis = 1;
  716. }
  717. if (dims == 3 && axis == 2)
  718. {
  719. axis = 0;
  720. }
  721. if (dims == 4 && axis == 1)
  722. {
  723. axis = 1;
  724. }
  725. if (dims == 4 && axis == 2)
  726. {
  727. axis = 2;
  728. }
  729. if (dims == 4 && axis == 3)
  730. {
  731. axis = 0;
  732. }
  733. fprintf(pp, " 0=%d", axis);
  734. }
  735. else if (op == "tf.Const")
  736. {
  737. // check weight before BinaryOp
  738. std::string output_name = get_mlir_value_uniq_id(operation.getResult(0));
  739. if (binaryop_weights.find(output_name) != binaryop_weights.end())
  740. {
  741. const mlir::Attribute& M = binaryop_weights[output_name];
  742. llvm::ArrayRef<int64_t> shape = M.getType().cast<mlir::RankedTensorType>().getShape();
  743. // c wc hwc
  744. if (shape.size() == 0)
  745. {
  746. // scalar
  747. fprintf(pp, " 0=1");
  748. }
  749. else if (shape.size() == 1)
  750. {
  751. fprintf(pp, " 0=%d", (int)shape[0]);
  752. }
  753. else if (shape.size() == 2)
  754. {
  755. fprintf(pp, " 0=%d", (int)shape[1]);
  756. fprintf(pp, " 1=%d", (int)shape[0]);
  757. }
  758. else if (shape.size() == 3)
  759. {
  760. fprintf(pp, " 0=%d", (int)shape[1]);
  761. fprintf(pp, " 1=%d", (int)shape[0]);
  762. fprintf(pp, " 2=%d", (int)shape[2]);
  763. }
  764. std::vector<float> v = get_attr_af(M);
  765. if (shape.size() != 3)
  766. {
  767. fwrite(v.data(), sizeof(float), v.size(), bp);
  768. }
  769. else
  770. {
  771. int w = (int)shape[1];
  772. int h = (int)shape[0];
  773. int c = (int)shape[2];
  774. float tmp;
  775. // h-w-c to c-h-w
  776. for (int p = 0; p < c; p++)
  777. {
  778. for (int i = 0; i < h; i++)
  779. {
  780. for (int j = 0; j < w; j++)
  781. {
  782. tmp = v[i * w * c + j * c + p];
  783. fwrite(&tmp, sizeof(float), 1, bp);
  784. }
  785. }
  786. }
  787. }
  788. }
  789. }
  790. else if (op == "tf.Conv2D")
  791. {
  792. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  793. const mlir::Attribute& W = weights[weight_name];
  794. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  795. // assert(shape.size() == 4)
  796. // kh-kw-inch-outch
  797. int kernel_size_h = shape[0];
  798. int kernel_size_w = shape[1];
  799. int num_input = shape[2];
  800. int num_output = shape[3];
  801. int weight_data_size = kernel_size_h * kernel_size_w * num_input * num_output;
  802. fprintf(pp, " 0=%d", num_output);
  803. fprintf(pp, " 1=%d", kernel_size_w);
  804. fprintf(pp, " 11=%d", kernel_size_h);
  805. fprintf(pp, " 6=%d", weight_data_size);
  806. std::vector<int> dilations = get_operation_attr_ai(operation, "dilations");
  807. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  808. std::string padding = get_operation_attr_s(operation, "padding");
  809. if (dilations.size() == 4)
  810. {
  811. fprintf(pp, " 2=%d", dilations[2]);
  812. fprintf(pp, " 12=%d", dilations[1]);
  813. }
  814. if (strides.size() == 4)
  815. {
  816. fprintf(pp, " 3=%d", strides[2]);
  817. fprintf(pp, " 13=%d", strides[1]);
  818. }
  819. if (padding == "EXPLICIT")
  820. {
  821. // nhwc = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
  822. std::vector<int> explicit_paddings = get_operation_attr_ai(operation, "explicit_paddings");
  823. fprintf(pp, " 4=%d", explicit_paddings[4]);
  824. fprintf(pp, " 15=%d", explicit_paddings[5]);
  825. fprintf(pp, " 14=%d", explicit_paddings[2]);
  826. fprintf(pp, " 16=%d", explicit_paddings[3]);
  827. }
  828. else if (padding == "VALID")
  829. {
  830. fprintf(pp, " 4=%d", 0);
  831. }
  832. else if (padding == "SAME")
  833. {
  834. fprintf(pp, " 4=%d", -233);
  835. }
  836. std::vector<float> v = get_attr_af(W);
  837. // reorder h-w-i-o to o-i-h-w
  838. {
  839. int quantize_tag = 0;
  840. fwrite(&quantize_tag, sizeof(int), 1, bp);
  841. float tmp;
  842. for (int p = 0; p < num_output; p++)
  843. {
  844. for (int q = 0; q < num_input; q++)
  845. {
  846. for (int i = 0; i < kernel_size_h; i++)
  847. {
  848. for (int j = 0; j < kernel_size_w; j++)
  849. {
  850. tmp = v[i * kernel_size_w * num_input * num_output + j * num_input * num_output + q * num_output + p];
  851. fwrite(&tmp, sizeof(float), 1, bp);
  852. }
  853. }
  854. }
  855. }
  856. }
  857. }
  858. else if (op == "tf.Conv2DBackpropInput")
  859. {
  860. std::string output_shape_name = get_mlir_value_uniq_id(operation.getOperand(0));
  861. const std::vector<int> output_shape = get_attr_ai(weights[output_shape_name]);
  862. // assert(output_shape.size() == 4)
  863. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  864. const mlir::Attribute& W = weights[weight_name];
  865. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  866. // assert(shape.size() == 4)
  867. // kh-kw-outch-inch
  868. int kernel_size_h = shape[0];
  869. int kernel_size_w = shape[1];
  870. int num_output = shape[2];
  871. int num_input = shape[3];
  872. int weight_data_size = kernel_size_h * kernel_size_w * num_input * num_output;
  873. fprintf(pp, " 0=%d", num_output);
  874. fprintf(pp, " 1=%d", kernel_size_w);
  875. fprintf(pp, " 11=%d", kernel_size_h);
  876. fprintf(pp, " 6=%d", weight_data_size);
  877. std::vector<int> dilations = get_operation_attr_ai(operation, "dilations");
  878. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  879. std::string padding = get_operation_attr_s(operation, "padding");
  880. if (dilations.size() == 4)
  881. {
  882. fprintf(pp, " 2=%d", dilations[2]);
  883. fprintf(pp, " 12=%d", dilations[1]);
  884. }
  885. if (strides.size() == 4)
  886. {
  887. fprintf(pp, " 3=%d", strides[2]);
  888. fprintf(pp, " 13=%d", strides[1]);
  889. }
  890. if (padding == "EXPLICIT")
  891. {
  892. // nhwc = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
  893. std::vector<int> explicit_paddings = get_operation_attr_ai(operation, "explicit_paddings");
  894. fprintf(pp, " 4=%d", explicit_paddings[4]);
  895. fprintf(pp, " 15=%d", explicit_paddings[5]);
  896. fprintf(pp, " 14=%d", explicit_paddings[2]);
  897. fprintf(pp, " 16=%d", explicit_paddings[3]);
  898. }
  899. else if (padding == "VALID")
  900. {
  901. fprintf(pp, " 4=%d", 0);
  902. }
  903. else if (padding == "SAME")
  904. {
  905. fprintf(pp, " 4=%d", -233);
  906. fprintf(pp, " 20=%d", output_shape[2]);
  907. fprintf(pp, " 21=%d", output_shape[1]);
  908. }
  909. std::vector<float> v = get_attr_af(W);
  910. // reorder h-w-o-i to o-i-h-w
  911. {
  912. int quantize_tag = 0;
  913. fwrite(&quantize_tag, sizeof(int), 1, bp);
  914. float tmp;
  915. for (int p = 0; p < num_output; p++)
  916. {
  917. for (int q = 0; q < num_input; q++)
  918. {
  919. for (int i = 0; i < kernel_size_h; i++)
  920. {
  921. for (int j = 0; j < kernel_size_w; j++)
  922. {
  923. tmp = v[i * kernel_size_w * num_output * num_input + j * num_output * num_input + p * num_input + q];
  924. fwrite(&tmp, sizeof(float), 1, bp);
  925. }
  926. }
  927. }
  928. }
  929. }
  930. }
  931. else if (op == "tf.DepthwiseConv2dNative")
  932. {
  933. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  934. const mlir::Attribute& W = weights[weight_name];
  935. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  936. // assert(shape.size() == 4)
  937. // kh-kw-inch-cm
  938. int kernel_size_h = shape[0];
  939. int kernel_size_w = shape[1];
  940. int num_input = shape[2];
  941. int channel_multiplier = shape[3];
  942. int num_output = num_input * channel_multiplier;
  943. int group = num_input;
  944. int weight_data_size = kernel_size_h * kernel_size_w * num_input * channel_multiplier;
  945. fprintf(pp, " 0=%d", num_output);
  946. fprintf(pp, " 1=%d", kernel_size_w);
  947. fprintf(pp, " 11=%d", kernel_size_h);
  948. fprintf(pp, " 6=%d", weight_data_size);
  949. fprintf(pp, " 7=%d", group);
  950. std::vector<int> dilations = get_operation_attr_ai(operation, "dilations");
  951. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  952. std::string padding = get_operation_attr_s(operation, "padding");
  953. if (dilations.size() == 4)
  954. {
  955. fprintf(pp, " 2=%d", dilations[2]);
  956. fprintf(pp, " 12=%d", dilations[1]);
  957. }
  958. if (strides.size() == 4)
  959. {
  960. fprintf(pp, " 3=%d", strides[2]);
  961. fprintf(pp, " 13=%d", strides[1]);
  962. }
  963. if (padding == "EXPLICIT")
  964. {
  965. // nhwc = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
  966. std::vector<int> explicit_paddings = get_operation_attr_ai(operation, "explicit_paddings");
  967. fprintf(pp, " 4=%d", explicit_paddings[4]);
  968. fprintf(pp, " 15=%d", explicit_paddings[5]);
  969. fprintf(pp, " 14=%d", explicit_paddings[2]);
  970. fprintf(pp, " 16=%d", explicit_paddings[3]);
  971. }
  972. else if (padding == "VALID")
  973. {
  974. fprintf(pp, " 4=%d", 0);
  975. }
  976. else if (padding == "SAME")
  977. {
  978. fprintf(pp, " 4=%d", -233);
  979. }
  980. std::vector<float> v = get_attr_af(W);
  981. // reorder h-w-i-cm to i-cm-h-w
  982. {
  983. int quantize_tag = 0;
  984. fwrite(&quantize_tag, sizeof(int), 1, bp);
  985. float tmp;
  986. for (int p = 0; p < num_input; p++)
  987. {
  988. for (int q = 0; q < channel_multiplier; q++)
  989. {
  990. for (int i = 0; i < kernel_size_h; i++)
  991. {
  992. for (int j = 0; j < kernel_size_w; j++)
  993. {
  994. tmp = v[i * kernel_size_w * channel_multiplier * num_input + j * channel_multiplier * num_input + p * channel_multiplier + q];
  995. fwrite(&tmp, sizeof(float), 1, bp);
  996. }
  997. }
  998. }
  999. }
  1000. }
  1001. }
  1002. else if (op == "tf.Identity")
  1003. {
  1004. }
  1005. else if (op == "tf.LeakyRelu")
  1006. {
  1007. float alpha = get_operation_attr_f(operation, "alpha");
  1008. fprintf(pp, " 0=%e", alpha);
  1009. }
  1010. else if (op == "tf.MatMul")
  1011. {
  1012. int transpose_a = get_operation_attr_b(operation, "transpose_a");
  1013. int transpose_b = get_operation_attr_b(operation, "transpose_b");
  1014. if (transpose_a == 0 && transpose_b == 1)
  1015. {
  1016. // InnerProduct-like A * B + C
  1017. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1018. const mlir::Attribute& W = weights[weight_name];
  1019. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  1020. // assert(shape.size() == 2)
  1021. // inch-outch
  1022. int num_input = shape[0];
  1023. int num_output = shape[1];
  1024. int weight_data_size = shape[0] * shape[1];
  1025. fprintf(pp, " 0=%d", num_output);
  1026. fprintf(pp, " 2=%d", weight_data_size);
  1027. std::vector<float> v = get_attr_af(W);
  1028. // reorder i-o to o-i
  1029. {
  1030. int quantize_tag = 0;
  1031. fwrite(&quantize_tag, sizeof(int), 1, bp);
  1032. float tmp;
  1033. for (int p = 0; p < num_output; p++)
  1034. {
  1035. for (int q = 0; q < num_input; q++)
  1036. {
  1037. tmp = v[q * num_output + p];
  1038. fwrite(&tmp, sizeof(float), 1, bp);
  1039. }
  1040. }
  1041. }
  1042. }
  1043. else
  1044. {
  1045. // gemm
  1046. fprintf(pp, " 0=1.0"); // alpha
  1047. fprintf(pp, " 1=1.0"); // beta
  1048. fprintf(pp, " 2=%d", transpose_a);
  1049. fprintf(pp, " 3=%d", transpose_b);
  1050. }
  1051. }
  1052. else if (op == "tf.Maximum")
  1053. {
  1054. int op_type = 4;
  1055. fprintf(pp, " 0=%d", op_type);
  1056. }
  1057. else if (op == "tf.MaxPool")
  1058. {
  1059. std::vector<int> ksize = get_operation_attr_ai(operation, "ksize");
  1060. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  1061. std::string padding = get_operation_attr_s(operation, "padding");
  1062. fprintf(pp, " 0=0"); // max pool
  1063. if (ksize.size() == 4)
  1064. {
  1065. fprintf(pp, " 1=%d", ksize[2]);
  1066. fprintf(pp, " 11=%d", ksize[1]);
  1067. }
  1068. if (strides.size() == 4)
  1069. {
  1070. fprintf(pp, " 2=%d", strides[2]);
  1071. fprintf(pp, " 12=%d", strides[1]);
  1072. }
  1073. int pad_mode = 1;
  1074. if (padding == "VALID")
  1075. {
  1076. pad_mode = 1;
  1077. }
  1078. else if (padding == "SAME")
  1079. {
  1080. pad_mode = 2;
  1081. }
  1082. fprintf(pp, " 5=%d", pad_mode);
  1083. }
  1084. else if (op == "tf.Mean")
  1085. {
  1086. std::string reduction_indices_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1087. const mlir::Attribute& R = weights[reduction_indices_name];
  1088. std::vector<int> v = get_attr_ai(R);
  1089. int keep_dims = get_operation_attr_b(operation, "keep_dims");
  1090. if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2)
  1091. {
  1092. // global avg pooling style nhwc -> nc
  1093. int pool = 1;
  1094. int global_pool = 1;
  1095. fprintf(pp, " 0=%d", pool);
  1096. fprintf(pp, " 4=%d", global_pool);
  1097. }
  1098. else
  1099. {
  1100. // TODO
  1101. }
  1102. }
  1103. else if (op == "tf.Minimum")
  1104. {
  1105. int op_type = 5;
  1106. fprintf(pp, " 0=%d", op_type);
  1107. }
  1108. else if (op == "tf.Mul")
  1109. {
  1110. int op_type = 2;
  1111. fprintf(pp, " 0=%d", op_type);
  1112. }
  1113. else if (op == "tf.Pad")
  1114. {
  1115. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1116. const mlir::Attribute& P = weights[weight_name];
  1117. std::vector<int> v = get_attr_ai(P);
  1118. // nhwc = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
  1119. fprintf(pp, " 0=%d", v[2]);
  1120. fprintf(pp, " 1=%d", v[3]);
  1121. fprintf(pp, " 2=%d", v[4]);
  1122. fprintf(pp, " 3=%d", v[5]);
  1123. }
  1124. else if (op == "tf.Placeholder")
  1125. {
  1126. }
  1127. else if (op == "tf.Relu")
  1128. {
  1129. }
  1130. else if (op == "tf.Relu6")
  1131. {
  1132. float min = 0.f;
  1133. float max = 6.f;
  1134. fprintf(pp, " 0=%e", min);
  1135. fprintf(pp, " 1=%e", max);
  1136. }
  1137. else if (op == "tf.Reshape")
  1138. {
  1139. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1140. const mlir::Attribute& S = weights[weight_name];
  1141. std::vector<int> v = get_attr_ai(S);
  1142. int size = v.size();
  1143. // n h w c
  1144. // n h c
  1145. // n c
  1146. if (size == 4)
  1147. {
  1148. fprintf(pp, " 0=%d 1=%d 2=%d", v[2], v[1], v[3]);
  1149. }
  1150. if (size == 3)
  1151. {
  1152. fprintf(pp, " 0=%d 1=%d 2=-233", v[1], v[2]);
  1153. }
  1154. if (size == 2)
  1155. {
  1156. fprintf(pp, " 0=%d 1=-233 2=-233", v[1]);
  1157. }
  1158. // FIXME may not always be the case
  1159. fprintf(pp, " 3=1");
  1160. }
  1161. else if (op == "tf.ResizeBilinear")
  1162. {
  1163. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1164. const mlir::Attribute& P = weights[weight_name];
  1165. std::vector<int> size = get_attr_ai(P);
  1166. int align_corners = get_operation_attr_b(operation, "align_corners");
  1167. int half_pixel_centers = get_operation_attr_b(operation, "half_pixel_centers");
  1168. if (!(align_corners == 0 && half_pixel_centers == 1))
  1169. {
  1170. fprintf(stderr, "Unsupported ResizeBilinear align_corners %d half_pixel_centers %d !\n", align_corners, half_pixel_centers);
  1171. }
  1172. fprintf(pp, " 0=2"); // bilinear
  1173. fprintf(pp, " 3=%d 4=%d", size[1], size[0]);
  1174. }
  1175. else if (op == "tf.ResizeNearestNeighbor")
  1176. {
  1177. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1178. const mlir::Attribute& P = weights[weight_name];
  1179. std::vector<int> size = get_attr_ai(P);
  1180. int align_corners = get_operation_attr_b(operation, "align_corners");
  1181. int half_pixel_centers = get_operation_attr_b(operation, "half_pixel_centers");
  1182. if (!(align_corners == 0 && half_pixel_centers == 1))
  1183. {
  1184. fprintf(stderr, "Unsupported ResizeNearestNeighbor align_corners %d half_pixel_centers %d !\n", align_corners, half_pixel_centers);
  1185. }
  1186. fprintf(pp, " 0=1"); // nearest
  1187. fprintf(pp, " 3=%d 4=%d", size[1], size[0]);
  1188. }
  1189. else if (op == "tf.Sigmoid")
  1190. {
  1191. }
  1192. else if (op == "tf.Softmax")
  1193. {
  1194. }
  1195. else if (op == "tf.StridedSlice")
  1196. {
  1197. std::string begin_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1198. std::string end_name = get_mlir_value_uniq_id(operation.getOperand(2));
  1199. std::string strides_name = get_mlir_value_uniq_id(operation.getOperand(3));
  1200. const mlir::Attribute& B = weights[begin_name];
  1201. const mlir::Attribute& E = weights[end_name];
  1202. const mlir::Attribute& S = weights[strides_name];
  1203. std::vector<int> begin = get_attr_ai(B);
  1204. std::vector<int> end = get_attr_ai(E);
  1205. std::vector<int> strides = get_attr_ai(S);
  1206. int begin_mask = get_operation_attr_i(operation, "begin_mask");
  1207. int end_mask = get_operation_attr_i(operation, "end_mask");
  1208. int ellipsis_mask = get_operation_attr_i(operation, "ellipsis_mask");
  1209. int new_axis_mask = get_operation_attr_i(operation, "new_axis_mask");
  1210. int shrink_axis_mask = get_operation_attr_i(operation, "shrink_axis_mask");
  1211. int dims = strides.size();
  1212. // assert strides == 1
  1213. for (int i = 0; i < dims; i++)
  1214. {
  1215. if (strides[i] != 1)
  1216. fprintf(stderr, "Unsupported StridedSlice strides !\n");
  1217. }
  1218. for (int i = 0; i < dims; i++)
  1219. {
  1220. // TODO strides[i] < 0
  1221. if (begin_mask & (1 << i))
  1222. {
  1223. begin[i] = 0;
  1224. }
  1225. if (end_mask & (1 << i))
  1226. {
  1227. end[i] = -233;
  1228. }
  1229. if (ellipsis_mask & (1 << i))
  1230. {
  1231. begin[i] = 0;
  1232. end[i] = -233;
  1233. }
  1234. }
  1235. if (new_axis_mask)
  1236. {
  1237. fprintf(stderr, "Unsupported StridedSlice new_axis_mask !\n");
  1238. }
  1239. if (shrink_axis_mask)
  1240. {
  1241. fprintf(stderr, "Unsupported StridedSlice shrink_axis_mask !\n");
  1242. }
  1243. // n h w c
  1244. // n h c
  1245. // n c
  1246. if (dims == 4)
  1247. {
  1248. fprintf(pp, " -23309=3,%d,%d,%d", begin[3], begin[1], begin[2]);
  1249. fprintf(pp, " -23310=3,%d,%d,%d", end[3], end[1], end[2]);
  1250. }
  1251. if (dims == 3)
  1252. {
  1253. fprintf(pp, " -23309=2,%d,%d", begin[2], begin[1]);
  1254. fprintf(pp, " -23310=2,%d,%d", end[2], end[1]);
  1255. }
  1256. if (dims == 2)
  1257. {
  1258. fprintf(pp, " -23309=1,%d", begin[1]);
  1259. fprintf(pp, " -23310=1,%d", end[1]);
  1260. }
  1261. }
  1262. else if (op == "tf.Sub")
  1263. {
  1264. int op_type = 1;
  1265. fprintf(pp, " 0=%d", op_type);
  1266. }
  1267. else if (op == "tf.Tanh")
  1268. {
  1269. }
  1270. #if 0
  1271. for (const mlir::NamedAttribute& attr : operation.getAttrs())
  1272. {
  1273. const mlir::Identifier& identifier = attr.first;
  1274. const mlir::Attribute& attr = attr.second;
  1275. fprintf(pp, " %s=", identifier.c_str());
  1276. if (attr.isa<mlir::AffineMapAttr>())
  1277. {
  1278. fprintf(pp, "AffineMap");
  1279. }
  1280. if (attr.isa<mlir::ArrayAttr>())
  1281. {
  1282. // fprintf(pp, "Array");
  1283. mlir::ArrayAttr a = attr.cast<mlir::ArrayAttr>();
  1284. int array_size = a.getValue().size();
  1285. for (int t=0; t<array_size; t++)
  1286. {
  1287. if (a[t].isa<mlir::IntegerAttr>())
  1288. {
  1289. int64_t ii = a[t].cast<mlir::IntegerAttr>().getInt();
  1290. fprintf(pp, "%lld,", ii);
  1291. }
  1292. }
  1293. }
  1294. if (attr.isa<mlir::BoolAttr>())
  1295. {
  1296. // fprintf(pp, "Bool");
  1297. mlir::BoolAttr a = attr.cast<mlir::BoolAttr>();
  1298. fprintf(pp, "%d", a.getValue() ? 1 : 0);
  1299. }
  1300. if (attr.isa<mlir::DictionaryAttr>())
  1301. {
  1302. fprintf(pp, "Dictionary");
  1303. }
  1304. if (attr.isa<mlir::FloatAttr>())
  1305. {
  1306. fprintf(pp, "Float");
  1307. }
  1308. if (attr.isa<mlir::IntegerAttr>())
  1309. {
  1310. fprintf(pp, "Integer");
  1311. }
  1312. if (attr.isa<mlir::IntegerSetAttr>())
  1313. {
  1314. fprintf(pp, "IntegerSet");
  1315. }
  1316. if (attr.isa<mlir::OpaqueAttr>())
  1317. {
  1318. fprintf(pp, "Opaque");
  1319. }
  1320. if (attr.isa<mlir::StringAttr>())
  1321. {
  1322. // fprintf(pp, "String");
  1323. mlir::StringAttr s = attr.cast<mlir::StringAttr>();
  1324. fprintf(pp, "%s", s.getValue().empty() ? "" : s.getValue().data());
  1325. }
  1326. if (attr.isa<mlir::SymbolRefAttr>())
  1327. {
  1328. fprintf(pp, "SymbolRef");
  1329. }
  1330. if (attr.isa<mlir::FlatSymbolRefAttr>())
  1331. {
  1332. fprintf(pp, "FlatSymbolRef");
  1333. }
  1334. if (attr.isa<mlir::TypeAttr>())
  1335. {
  1336. fprintf(pp, "Type");
  1337. }
  1338. if (attr.isa<mlir::UnitAttr>())
  1339. {
  1340. fprintf(pp, "Unit");
  1341. }
  1342. if (attr.isa<mlir::ElementsAttr>())
  1343. {
  1344. fprintf(pp, "Elements");
  1345. }
  1346. if (attr.isa<mlir::DenseElementsAttr>())
  1347. {
  1348. fprintf(pp, "DenseElements");
  1349. }
  1350. if (attr.isa<mlir::DenseFPElementsAttr>())
  1351. {
  1352. fprintf(pp, "DenseFPElements");
  1353. }
  1354. if (attr.isa<mlir::DenseIntElementsAttr>())
  1355. {
  1356. fprintf(pp, "DenseIntElements");
  1357. }
  1358. if (attr.isa<mlir::OpaqueElementsAttr>())
  1359. {
  1360. fprintf(pp, "OpaqueElements");
  1361. }
  1362. if (attr.isa<mlir::SparseElementsAttr>())
  1363. {
  1364. fprintf(pp, "SparseElements");
  1365. }
  1366. if (attr.isa<mlir::SplatElementsAttr>())
  1367. {
  1368. fprintf(pp, "SplatElements");
  1369. }
  1370. }
  1371. #endif
  1372. fprintf(pp, "\n");
  1373. for (int j = 0; j < num_output; j++)
  1374. {
  1375. std::string output_name = get_mlir_value_uniq_id(operation.getResult(j));
  1376. if (node_reference.find(output_name) != node_reference.end())
  1377. {
  1378. int refcount = node_reference[output_name];
  1379. if (refcount > 1)
  1380. {
  1381. char splitname[256];
  1382. sprintf(splitname, "splitncnn_%d", internal_split);
  1383. fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
  1384. fprintf(pp, " %s", output_name.c_str());
  1385. for (int k = 0; k < refcount; k++)
  1386. {
  1387. fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
  1388. }
  1389. fprintf(pp, "\n");
  1390. internal_split++;
  1391. }
  1392. }
  1393. }
  1394. }
  1395. fclose(pp);
  1396. fclose(bp);
  1397. return 0;
  1398. }