You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mlir2ncnn.cpp 56 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <stdio.h>
  15. #include <map>
  16. #include <set>
  17. #include <llvm/ADT/APFloat.h>
  18. #include <llvm/ADT/APInt.h>
  19. #include <llvm/ADT/ArrayRef.h>
  20. #include <llvm/ADT/STLExtras.h>
  21. #include <llvm/ADT/SmallVector.h>
  22. #include <llvm/ADT/StringRef.h>
  23. #include <llvm/Support/FormatVariadic.h>
  24. #include <llvm/Support/MathExtras.h>
  25. #include <mlir/Dialect/StandardOps/IR/Ops.h>
  26. #include <mlir/Dialect/Traits.h>
  27. #include <mlir/IR/Attributes.h>
  28. #include <mlir/IR/Builders.h>
  29. #include <mlir/IR/Dialect.h>
  30. #include <mlir/IR/DialectImplementation.h>
  31. #include <mlir/IR/Function.h>
  32. #include <mlir/IR/Location.h>
  33. #include <mlir/IR/MLIRContext.h>
  34. #include <mlir/IR/Module.h>
  35. #include <mlir/IR/OpDefinition.h>
  36. #include <mlir/IR/OpImplementation.h>
  37. #include <mlir/IR/Operation.h>
  38. #include <mlir/IR/OperationSupport.h>
  39. #include <mlir/IR/PatternMatch.h>
  40. #include <mlir/IR/StandardTypes.h>
  41. #include <mlir/IR/TypeUtilities.h>
  42. #include <mlir/IR/Types.h>
  43. #include <mlir/IR/Value.h>
  44. #include <mlir/IR/Verifier.h>
  45. #include <mlir/Interfaces/CallInterfaces.h>
  46. #include <mlir/Interfaces/DerivedAttributeOpInterface.h>
  47. #include <mlir/Interfaces/InferTypeOpInterface.h>
  48. #include <mlir/Interfaces/LoopLikeInterface.h>
  49. #include <mlir/Interfaces/SideEffectInterfaces.h>
  50. #include <mlir/Parser.h>
  51. #include <mlir/Support/LogicalResult.h>
  52. #include <mlir/Transforms/InliningUtils.h>
  53. #include "tf_attributes.h"
  54. #include "tf_side_effects.h"
  55. #include "tf_traits.h"
  56. namespace mlir {
  57. static LogicalResult Verify(...)
  58. {
  59. return success();
  60. }
  61. static LogicalResult VerifyPartitionedCall(...)
  62. {
  63. return success();
  64. }
  65. static LogicalResult VerifyStridedSliceBase(...)
  66. {
  67. return success();
  68. }
  69. static LogicalResult VerifyUnsortedSegmentReduction(...)
  70. {
  71. return success();
  72. }
  73. namespace TF {
  74. #include "tf_op_interfaces.h.inc"
  75. class TensorFlowDialect : public mlir::Dialect
  76. {
  77. public:
  78. TensorFlowDialect(mlir::MLIRContext* context);
  79. Attribute parseAttribute(DialectAsmParser& parser, Type type) const override;
  80. // Parse a type registered to this dialect.
  81. Type parseType(DialectAsmParser& parser) const override;
  82. // Parses resource type with potential subtypes.
  83. Type ParseResourceType(DialectAsmParser& parser, Location loc) const;
  84. // Parse and print variant type. It may have subtypes inferred using shape
  85. // inference.
  86. Type ParseVariantType(DialectAsmParser& parser, Location loc) const;
  87. // Registered hook to materialize a constant operation from a given attribute
  88. // value with the desired resultant type.
  89. Operation* materializeConstant(OpBuilder& builder, Attribute value, Type type,
  90. Location loc) override;
  91. };
  92. #define GET_OP_CLASSES
  93. #include "tf_ops.h.inc"
  94. namespace {
  95. struct TFInlinerInterface : public DialectInlinerInterface
  96. {
  97. using DialectInlinerInterface::DialectInlinerInterface;
  98. //===--------------------------------------------------------------------===//
  99. // Analysis Hooks
  100. //===--------------------------------------------------------------------===//
  101. // Defines the legality of inlining TF operations.
  102. bool isLegalToInline(Operation*, Region*,
  103. BlockAndValueMapping&) const final
  104. {
  105. // TODO(riverriddle) For now, enable inlining all operations. This isn't
  106. // correct in the face of operations that cannot be duplicated, but this
  107. // requires more intricate side-effect modeling.
  108. return true;
  109. }
  110. //===--------------------------------------------------------------------===//
  111. // Transformation Hooks
  112. //===--------------------------------------------------------------------===//
  113. // Attempts to materialize a conversion for a type mismatch between a call
  114. // from this dialect, and a callable region. This method should generate an
  115. // operation that takes 'input' as the only operand, and produces a single
  116. // result of 'resultType'. If a conversion can not be generated, nullptr
  117. // should be returned.
  118. Operation* materializeCallConversion(OpBuilder& builder, Value input,
  119. Type result_type,
  120. Location conversion_loc) const final
  121. {
  122. if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
  123. return nullptr;
  124. return builder.create<TF::CastOp>(conversion_loc, result_type, input,
  125. /*truncate=*/builder.getBoolAttr(false));
  126. }
  127. };
  128. } // end anonymous namespace
  129. TensorFlowDialect::TensorFlowDialect(mlir::MLIRContext* context)
  130. : mlir::Dialect("tf", context)
  131. {
  132. addOperations<
  133. #define GET_OP_LIST
  134. #include "tf_ops.cpp.inc"
  135. >();
  136. addTypes<
  137. #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
  138. #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
  139. #include "tf_types.def"
  140. >();
  141. addInterfaces<TFInlinerInterface>();
  142. addAttributes<ShapeAttr, FuncAttr>();
  143. // Support unknown operations because not all TensorFlow operations are
  144. // registered.
  145. allowUnknownOperations();
  146. }
  147. ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc)
  148. {
  149. auto emit_error = [&, spec]() {
  150. emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
  151. return nullptr;
  152. };
  153. if (!spec.consume_front("shape<")) return emit_error();
  154. if (spec.consume_front("*>"))
  155. return mlir::TF::ShapeAttr::get(context, llvm::None);
  156. SmallVector<int64_t, 4> shape;
  157. while (!spec.consume_front(">"))
  158. {
  159. int64_t dim;
  160. if (spec.consume_front("?"))
  161. dim = -1;
  162. else if (spec.consumeInteger(10, dim) || dim < 0)
  163. return emit_error();
  164. spec.consume_front("x");
  165. shape.push_back(dim);
  166. }
  167. return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
  168. }
  169. // Parses a #tf.func attribute of the following format:
  170. //
  171. // #tf.func<@symbol, {attr = "value"}>
  172. //
  173. // where the first element is a SymbolRefAttr and the second element is a
  174. // DictionaryAttr.
  175. FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc)
  176. {
  177. auto emit_error = [&, spec]() {
  178. emitError(loc, "invalid TensorFlow func attribute: ") << spec;
  179. return nullptr;
  180. };
  181. if (!spec.consume_front("func<")) return emit_error();
  182. size_t func_name_num_read = 0;
  183. Attribute func_name_attr = mlir::parseAttribute(spec, context, func_name_num_read);
  184. if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
  185. return emit_error();
  186. spec = spec.drop_front(func_name_num_read);
  187. if (!spec.consume_front(", ")) return emit_error();
  188. size_t func_attrs_num_read = 0;
  189. Attribute func_attrs_attr = mlir::parseAttribute(spec, context, func_attrs_num_read);
  190. if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
  191. return emit_error();
  192. spec = spec.drop_front(func_attrs_num_read);
  193. if (!spec.consume_front(">")) return emit_error();
  194. return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
  195. func_attrs_attr.cast<DictionaryAttr>());
  196. }
  197. Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser,
  198. Type type) const
  199. {
  200. auto spec = parser.getFullSymbolSpec();
  201. Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
  202. if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);
  203. if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);
  204. return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
  205. }
  206. // Parses a type registered to this dialect.
  207. Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
  208. {
  209. StringRef data;
  210. if (parser.parseKeyword(&data)) return Type();
  211. Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
  212. auto typeKind = llvm::StringSwitch<unsigned>(data)
  213. #define HANDLE_TF_TYPE(tftype, enumerant, name) \
  214. .Case(name, TensorFlowTypes::enumerant)
  215. // Custom TensorFlow types are handled separately at the end as they do partial
  216. // match.
  217. #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
  218. // NOLINTNEXTLINE
  219. #include "tf_types.def"
  220. .StartsWith("resource", TensorFlowTypes::RESOURCE)
  221. .StartsWith("variant", TensorFlowTypes::VARIANT)
  222. .Default(0);
  223. switch (typeKind)
  224. {
  225. default:
  226. return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);
  227. #define HANDLE_TF_TYPE(tftype, enumerant, name) \
  228. case TensorFlowTypes::enumerant: \
  229. return tftype##Type::get(getContext());
  230. #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
  231. // NOLINTNEXTLINE
  232. #include "tf_types.def"
  233. case TensorFlowTypes::RESOURCE:
  234. return ParseResourceType(parser, loc);
  235. case TensorFlowTypes::VARIANT:
  236. return ParseVariantType(parser, loc);
  237. }
  238. }
  239. namespace {
  240. template<typename TypeWithSubtype>
  241. Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser,
  242. Location loc)
  243. {
  244. // Default type without inferred subtypes.
  245. if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
  246. // Most types with subtypes have only one subtype.
  247. SmallVector<TensorType, 1> subtypes;
  248. do
  249. {
  250. TensorType tensor_ty;
  251. if (parser.parseType(tensor_ty)) return Type();
  252. subtypes.push_back(tensor_ty);
  253. } while (succeeded(parser.parseOptionalComma()));
  254. if (parser.parseGreater()) return Type();
  255. return TypeWithSubtype::getChecked(subtypes, context, loc);
  256. }
  257. } // anonymous namespace
  258. Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser,
  259. Location loc) const
  260. {
  261. return ParseTypeWithSubtype<ResourceType>(getContext(), parser, loc);
  262. }
  263. Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser,
  264. Location loc) const
  265. {
  266. return ParseTypeWithSubtype<VariantType>(getContext(), parser, loc);
  267. }
  268. Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder,
  269. Attribute value, Type type,
  270. Location loc)
  271. {
  272. return builder.create<ConstOp>(loc, type, value);
  273. }
  274. #define GET_OP_CLASSES
  275. #include "tf_ops.cpp.inc"
  276. // Builds a constant op with the specified attribute `value`. The result
  277. // op's type is deduced from `value`; if `value` is of scalar type,
  278. // wraps it up with a tensor type of empty shape.
  279. // TODO(jpienaar): This one differs from the autogenerated one as it takes an
  280. // attribute but always creates an ElementsAttr internally.
  281. void ConstOp::build(OpBuilder& builder, OperationState& result,
  282. Attribute value)
  283. {
  284. ShapedType type;
  285. if (auto elem_attr = value.dyn_cast<ElementsAttr>())
  286. {
  287. return ConstOp::build(builder, result, elem_attr);
  288. }
  289. else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() || value.isa<IntegerAttr>())
  290. {
  291. // All TensorFlow types must be tensor types. In the build() method,
  292. // we want to provide more flexibility by allowing attributes of scalar
  293. // types. But we need to wrap it up with ElementsAttr to construct
  294. // valid TensorFlow constants.
  295. type = RankedTensorType::get(/*shape=*/ {}, value.getType());
  296. return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
  297. }
  298. // TODO(jpienaar): support other TensorFlow specific types.
  299. llvm_unreachable("unsupported attribute type for building tf.Const");
  300. }
  301. void ConstOp::build(OpBuilder& builder, OperationState& result, Type type,
  302. Attribute value)
  303. {
  304. // Handle the case where the type and value are already tensors.
  305. if (type.isa<TensorType>() && value.isa<ElementsAttr>())
  306. {
  307. result.addTypes(type);
  308. result.addAttribute("value", value);
  309. return;
  310. }
  311. // Otherwise, default to the attribute builder.
  312. ConstOp::build(builder, result, value);
  313. assert(type == result.types[0] && "type mismatch in construction");
  314. }
  315. LogicalResult ConstOp::inferReturnTypes(
  316. MLIRContext* context, Optional<Location> location, ValueRange operands,
  317. DictionaryAttr attributes, RegionRange regions,
  318. SmallVectorImpl<Type>& inferredReturnTypes)
  319. {
  320. auto value = attributes.get("value");
  321. if (!value) return emitOptionalError(location, "missing attribute 'value'");
  322. if (auto elem_attr = value.dyn_cast<ElementsAttr>())
  323. {
  324. inferredReturnTypes.assign({elem_attr.getType()});
  325. return success();
  326. }
  327. return emitOptionalError(location,
  328. "attribute 'value' failed to satisfy constraint: "
  329. "constant vector/tensor");
  330. }
  331. Region& WhileRegionOp::getLoopBody()
  332. {
  333. return body();
  334. }
  335. bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
  336. {
  337. // If the Op defining the value exists and the defining op is outside the
  338. // scope of this WhileRegion, then we can infer that its defined outside.
  339. // The defining Op is outside the scope of this WhileRegion if this
  340. // WhileRegionOp is not an ancestor of the defining op in the parent chain.
  341. Operation* def_op = value.getDefiningOp();
  342. return def_op && !getOperation()->isAncestor(def_op);
  343. }
  344. LogicalResult WhileRegionOp::moveOutOfLoop(
  345. llvm::ArrayRef<mlir::Operation*> ops)
  346. {
  347. // Move the hoisted value to just before the while.
  348. Operation* while_op = this->getOperation();
  349. for (auto op : ops) op->moveBefore(while_op);
  350. return success();
  351. }
  352. } // namespace TF
  353. } // namespace mlir
  354. static std::string get_mlir_value_uniq_id(const mlir::Value& value)
  355. {
  356. if (value.getLoc().isa<mlir::FileLineColLoc>())
  357. {
  358. mlir::FileLineColLoc floc = value.getLoc().cast<mlir::FileLineColLoc>();
  359. return floc.getFilename().str() + ":" + std::to_string(floc.getLine()) + ":" + std::to_string(floc.getColumn());
  360. }
  361. fprintf(stderr, "unhandled get_mlir_value_uniq_id\n");
  362. return std::string();
  363. }
  364. static std::string get_attr_s(const mlir::Attribute& attr)
  365. {
  366. std::string s;
  367. if (attr.isa<mlir::StringAttr>())
  368. {
  369. mlir::StringAttr a = attr.cast<mlir::StringAttr>();
  370. s = a.getValue().str();
  371. }
  372. return s;
  373. }
  374. static int get_attr_b(const mlir::Attribute& attr)
  375. {
  376. int i;
  377. if (attr.isa<mlir::BoolAttr>())
  378. {
  379. mlir::BoolAttr a = attr.cast<mlir::BoolAttr>();
  380. i = a.getValue() ? 1 : 0;
  381. }
  382. else
  383. {
  384. fprintf(stderr, "not BoolAttr\n");
  385. }
  386. return i;
  387. }
  388. static int get_attr_i(const mlir::Attribute& attr)
  389. {
  390. int i;
  391. if (attr.isa<mlir::IntegerAttr>())
  392. {
  393. mlir::IntegerAttr a = attr.cast<mlir::IntegerAttr>();
  394. i = (int)a.getInt();
  395. }
  396. else
  397. {
  398. fprintf(stderr, "not IntegerAttr\n");
  399. }
  400. return i;
  401. }
  402. static float get_attr_f(const mlir::Attribute& attr)
  403. {
  404. float f;
  405. if (attr.isa<mlir::FloatAttr>())
  406. {
  407. mlir::FloatAttr a = attr.cast<mlir::FloatAttr>();
  408. f = (float)a.getValueAsDouble();
  409. }
  410. else
  411. {
  412. fprintf(stderr, "not FloatAttr\n");
  413. }
  414. return f;
  415. }
  416. static std::vector<int> get_attr_ai(const mlir::Attribute& attr)
  417. {
  418. std::vector<int> v;
  419. if (attr.isa<mlir::ArrayAttr>())
  420. {
  421. mlir::ArrayAttr a = attr.cast<mlir::ArrayAttr>();
  422. const int array_size = a.getValue().size();
  423. v.resize(array_size);
  424. for (int j = 0; j < array_size; j++)
  425. {
  426. if (a[j].isa<mlir::IntegerAttr>())
  427. {
  428. int64_t ii = a[j].cast<mlir::IntegerAttr>().getInt();
  429. v[j] = std::max(std::min(ii, (int64_t)INT_MAX), (int64_t)INT_MIN);
  430. }
  431. }
  432. }
  433. else if (attr.isa<mlir::DenseIntElementsAttr>())
  434. {
  435. mlir::DenseIntElementsAttr ai = attr.cast<mlir::DenseIntElementsAttr>();
  436. for (auto ii : ai.getIntValues())
  437. {
  438. v.push_back(ii.getSExtValue());
  439. }
  440. }
  441. else
  442. {
  443. fprintf(stderr, "not ArrayAttr or DenseIntElementsAttr\n");
  444. }
  445. return v;
  446. }
  447. static std::vector<float> get_attr_af(const mlir::Attribute& attr)
  448. {
  449. std::vector<float> v;
  450. if (attr.isa<mlir::ArrayAttr>())
  451. {
  452. mlir::ArrayAttr a = attr.cast<mlir::ArrayAttr>();
  453. const int array_size = a.getValue().size();
  454. v.resize(array_size);
  455. for (int j = 0; j < array_size; j++)
  456. {
  457. if (a[j].isa<mlir::FloatAttr>())
  458. {
  459. double ff = a[j].cast<mlir::FloatAttr>().getValueAsDouble();
  460. v[j] = ff;
  461. }
  462. }
  463. }
  464. else if (attr.isa<mlir::DenseFPElementsAttr>())
  465. {
  466. mlir::DenseFPElementsAttr af = attr.cast<mlir::DenseFPElementsAttr>();
  467. for (auto ff : af.getFloatValues())
  468. {
  469. v.push_back(ff.convertToFloat());
  470. }
  471. }
  472. else
  473. {
  474. fprintf(stderr, "not ArrayAttr or DenseFPElementsAttr\n");
  475. }
  476. return v;
  477. }
  478. static std::string get_operation_attr_s(const mlir::Operation& _operation, const char* key)
  479. {
  480. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  481. mlir::Attribute attr = operation.getAttr(key);
  482. return get_attr_s(attr);
  483. }
  484. static int get_operation_attr_b(const mlir::Operation& _operation, const char* key)
  485. {
  486. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  487. mlir::Attribute attr = operation.getAttr(key);
  488. return get_attr_b(attr);
  489. }
  490. static int get_operation_attr_i(const mlir::Operation& _operation, const char* key)
  491. {
  492. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  493. mlir::Attribute attr = operation.getAttr(key);
  494. return get_attr_i(attr);
  495. }
  496. static float get_operation_attr_f(const mlir::Operation& _operation, const char* key)
  497. {
  498. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  499. mlir::Attribute attr = operation.getAttr(key);
  500. return get_attr_f(attr);
  501. }
  502. static std::vector<int> get_operation_attr_ai(const mlir::Operation& _operation, const char* key)
  503. {
  504. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  505. mlir::Attribute attr = operation.getAttr(key);
  506. return get_attr_ai(attr);
  507. }
  508. static std::vector<float> get_operation_attr_af(const mlir::Operation& _operation, const char* key)
  509. {
  510. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  511. mlir::Attribute attr = operation.getAttr(key);
  512. return get_attr_af(attr);
  513. }
  514. int main(int argc, char** argv)
  515. {
  516. const char* mlirpath = argv[1];
  517. const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.param";
  518. const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
  519. mlir::registerDialect<mlir::StandardOpsDialect>();
  520. mlir::registerDialect<mlir::TF::TensorFlowDialect>();
  521. mlir::MLIRContext context;
  522. mlir::OwningModuleRef m = mlir::parseSourceFile(mlirpath, &context);
  523. // m->dump();
  524. mlir::FuncOp main_fn = m->lookupSymbol<mlir::FuncOp>("main");
  525. auto& bb = main_fn.getBlocks().front();
  526. // bb.dump();
  527. FILE* pp = fopen(ncnn_prototxt, "wb");
  528. FILE* bp = fopen(ncnn_modelbin, "wb");
  529. // node reference
  530. std::map<std::string, int> node_reference;
  531. // weight node and weight reshape node
  532. std::map<std::string, mlir::Attribute> weights;
  533. // weight node before BinaryOp
  534. std::map<std::string, mlir::Attribute> binaryop_weights;
  535. fprintf(pp, "7767517\n");
  536. const mlir::Block::OpListType& operations = bb.getOperations();
  537. int node_count = operations.size();
  538. // global definition line
  539. // [layer count] [blob count]
  540. std::set<std::string> blob_names;
  541. for (const mlir::Operation& _operation : operations)
  542. {
  543. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  544. std::string op = operation.getName().getStringRef().str();
  545. int num_input = (int)operation.getNumOperands();
  546. int num_output = (int)operation.getNumResults();
  547. if (op == "tf.Const")
  548. {
  549. // weight
  550. std::string output_name = get_mlir_value_uniq_id(operation.getResult(0));
  551. weights[output_name] = operation.getAttr("value");
  552. continue;
  553. }
  554. else
  555. {
  556. bool isBinaryOp = false;
  557. // TODO add more binaryop
  558. if (op == "tf.BiasAdd" || op == "tf.AddV2" || op == "tf.Sub" || op == "tf.Mul")
  559. {
  560. isBinaryOp = true;
  561. }
  562. if (isBinaryOp)
  563. {
  564. // check weights
  565. for (int j = 0; j < num_input; j++)
  566. {
  567. std::string input_name = get_mlir_value_uniq_id(operation.getOperand(j));
  568. std::map<std::string, mlir::Attribute>::iterator it = weights.find(input_name);
  569. if (it != weights.end())
  570. {
  571. // binary op with weight, insert MemoryData layer and const blob
  572. binaryop_weights[input_name] = it->second;
  573. weights.erase(it);
  574. }
  575. }
  576. }
  577. }
  578. for (int j = 0; j < num_input; j++)
  579. {
  580. std::string input_name = get_mlir_value_uniq_id(operation.getOperand(j));
  581. // check weight
  582. if (weights.find(input_name) != weights.end())
  583. {
  584. continue;
  585. }
  586. blob_names.insert(input_name);
  587. if (node_reference.find(input_name) == node_reference.end())
  588. {
  589. node_reference[input_name] = 1;
  590. }
  591. else
  592. {
  593. node_reference[input_name] = node_reference[input_name] + 1;
  594. }
  595. }
  596. for (int j = 0; j < num_output; j++)
  597. {
  598. std::string output_name = get_mlir_value_uniq_id(operation.getResult(j));
  599. blob_names.insert(output_name);
  600. }
  601. }
  602. // remove node_reference entry with reference equals to one
  603. int splitncnn_blob_count = 0;
  604. std::map<std::string, int>::iterator it = node_reference.begin();
  605. while (it != node_reference.end())
  606. {
  607. if (it->second == 1)
  608. {
  609. node_reference.erase(it++);
  610. }
  611. else
  612. {
  613. splitncnn_blob_count += it->second;
  614. // fprintf(stderr, "%s %d\n", it->first.c_str(), it->second);
  615. ++it;
  616. }
  617. }
  618. fprintf(pp, "%lu %lu\n", node_count + node_reference.size() - weights.size(), blob_names.size() + splitncnn_blob_count);
  619. int internal_split = 0;
  620. // model op
  621. int g_opid = 0;
  622. for (const mlir::Operation& _operation : operations)
  623. {
  624. mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
  625. std::string op = operation.getName().getStringRef().str();
  626. int opid = g_opid++;
  627. int num_input = (int)operation.getNumOperands();
  628. int num_output = (int)operation.getNumResults();
  629. for (int i = 0; i < (int)operation.getNumOperands(); i++)
  630. {
  631. std::string input_name = get_mlir_value_uniq_id(operation.getOperand(i));
  632. // check weight
  633. if (weights.find(input_name) != weights.end())
  634. {
  635. num_input--;
  636. }
  637. }
  638. if (op == "std.return")
  639. {
  640. fprintf(pp, "%-16s", "Noop");
  641. }
  642. else if (op == "tf.AddN")
  643. {
  644. fprintf(pp, "%-16s", "Eltwise");
  645. }
  646. else if (op == "tf.AddV2")
  647. {
  648. fprintf(pp, "%-16s", "BinaryOp");
  649. }
  650. else if (op == "tf.AvgPool")
  651. {
  652. fprintf(pp, "%-16s", "Pooling");
  653. }
  654. else if (op == "tf.BiasAdd")
  655. {
  656. fprintf(pp, "%-16s", "BinaryOp");
  657. }
  658. else if (op == "tf.ConcatV2")
  659. {
  660. fprintf(pp, "%-16s", "Concat");
  661. }
  662. else if (op == "tf.Const")
  663. {
  664. // check weight before BinaryOp
  665. std::string output_name = get_mlir_value_uniq_id(operation.getResult(0));
  666. if (binaryop_weights.find(output_name) != binaryop_weights.end())
  667. {
  668. fprintf(pp, "%-16s", "MemoryData");
  669. }
  670. else
  671. {
  672. continue;
  673. }
  674. }
  675. else if (op == "tf.Conv2D")
  676. {
  677. fprintf(pp, "%-16s", "Convolution");
  678. }
  679. else if (op == "tf.Conv2DBackpropInput")
  680. {
  681. fprintf(pp, "%-16s", "Deconvolution");
  682. }
  683. else if (op == "tf.DepthwiseConv2dNative")
  684. {
  685. fprintf(pp, "%-16s", "ConvolutionDepthWise");
  686. }
  687. else if (op == "tf.Identity")
  688. {
  689. fprintf(pp, "%-16s", "Noop");
  690. }
  691. else if (op == "tf.LeakyRelu")
  692. {
  693. fprintf(pp, "%-16s", "ReLU");
  694. }
  695. else if (op == "tf.MatMul")
  696. {
  697. fprintf(pp, "%-16s", "InnerProduct");
  698. }
  699. else if (op == "tf.MaxPool")
  700. {
  701. fprintf(pp, "%-16s", "Pooling");
  702. }
  703. else if (op == "tf.Mean")
  704. {
  705. std::string reduction_indices_name = get_mlir_value_uniq_id(operation.getOperand(1));
  706. const mlir::Attribute& R = weights[reduction_indices_name];
  707. std::vector<int> v = get_attr_ai(R);
  708. int keep_dims = get_operation_attr_b(operation, "keep_dims");
  709. if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2)
  710. {
  711. // global avg pooling style nhwc -> nc
  712. fprintf(pp, "%-16s", "Pooling");
  713. }
  714. else
  715. {
  716. fprintf(stderr, "tf.Mean is not global avg pooling\n");
  717. fprintf(pp, "%-16s", "Reduction");
  718. }
  719. }
  720. else if (op == "tf.Mul")
  721. {
  722. fprintf(pp, "%-16s", "BinaryOp");
  723. }
  724. else if (op == "tf.Pad")
  725. {
  726. fprintf(pp, "%-16s", "Padding");
  727. }
  728. else if (op == "tf.Placeholder")
  729. {
  730. fprintf(pp, "%-16s", "Input");
  731. }
  732. else if (op == "tf.Relu")
  733. {
  734. fprintf(pp, "%-16s", "ReLU");
  735. }
  736. else if (op == "tf.Relu6")
  737. {
  738. fprintf(pp, "%-16s", "Clip");
  739. }
  740. else if (op == "tf.Reshape")
  741. {
  742. fprintf(pp, "%-16s", "Reshape");
  743. }
  744. else if (op == "tf.ResizeNearestNeighbor")
  745. {
  746. fprintf(pp, "%-16s", "Interp");
  747. }
  748. else if (op == "tf.Sigmoid")
  749. {
  750. fprintf(pp, "%-16s", "Sigmoid");
  751. }
  752. else if (op == "tf.Softmax")
  753. {
  754. fprintf(pp, "%-16s", "Softmax");
  755. }
  756. else if (op == "tf.StridedSlice")
  757. {
  758. fprintf(pp, "%-16s", "Crop");
  759. }
  760. else if (op == "tf.Sub")
  761. {
  762. fprintf(pp, "%-16s", "BinaryOp");
  763. }
  764. else if (op == "tf.Tanh")
  765. {
  766. fprintf(pp, "%-16s", "TanH");
  767. }
  768. else
  769. {
  770. // TODO
  771. fprintf(stderr, "%s not supported yet!\n", op.c_str());
  772. fprintf(pp, "%-16s", op.c_str());
  773. }
  774. fprintf(pp, " op_%d %d %d", opid, num_input, num_output);
  775. for (int i = 0; i < (int)operation.getNumOperands(); i++)
  776. {
  777. std::string input_name = get_mlir_value_uniq_id(operation.getOperand(i));
  778. // check weight
  779. if (weights.find(input_name) != weights.end())
  780. {
  781. continue;
  782. }
  783. if (node_reference.find(input_name) != node_reference.end())
  784. {
  785. int refidx = node_reference[input_name] - 1;
  786. node_reference[input_name] = refidx;
  787. char splitsuffix[256];
  788. sprintf(splitsuffix, "_splitncnn_%d", refidx);
  789. input_name = input_name + splitsuffix;
  790. }
  791. fprintf(pp, " %s", input_name.c_str());
  792. }
  793. for (int i = 0; i < num_output; i++)
  794. {
  795. std::string output_name = get_mlir_value_uniq_id(operation.getResult(i));
  796. fprintf(pp, " %s", output_name.c_str());
  797. }
  798. if (op == "std.return")
  799. {
  800. }
  801. else if (op == "tf.AddN")
  802. {
  803. int op_type = 1;
  804. fprintf(pp, " 0=%d", op_type);
  805. }
  806. else if (op == "tf.AddV2")
  807. {
  808. int op_type = 0;
  809. fprintf(pp, " 0=%d", op_type);
  810. }
  811. else if (op == "tf.AvgPool")
  812. {
  813. std::vector<int> ksize = get_operation_attr_ai(operation, "ksize");
  814. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  815. std::string padding = get_operation_attr_s(operation, "padding");
  816. if (ksize.size() == 4)
  817. {
  818. fprintf(pp, " 1=%d", ksize[2]);
  819. fprintf(pp, " 11=%d", ksize[1]);
  820. }
  821. if (strides.size() == 4)
  822. {
  823. fprintf(pp, " 2=%d", strides[2]);
  824. fprintf(pp, " 12=%d", strides[1]);
  825. }
  826. int pad_mode = 1;
  827. if (padding == "VALID")
  828. {
  829. pad_mode = 1;
  830. }
  831. else if (padding == "SAME")
  832. {
  833. pad_mode = 2;
  834. }
  835. fprintf(pp, " 5=%d", pad_mode);
  836. }
  837. else if (op == "tf.ConcatV2")
  838. {
  839. std::string axis_name = get_mlir_value_uniq_id(operation.getOperand(operation.getNumOperands() - 1));
  840. const mlir::Attribute& A = weights[axis_name];
  841. int axis = get_attr_ai(A)[0];
  842. // axis nhc to nhw
  843. // axis nhwc to nchw
  844. int dims = operation.getOperand(0).getType().cast<mlir::RankedTensorType>().getShape().size();
  845. if (dims == 2 && axis == 1)
  846. {
  847. axis = 0;
  848. }
  849. if (dims == 3 && axis == 1)
  850. {
  851. axis = 1;
  852. }
  853. if (dims == 3 && axis == 2)
  854. {
  855. axis = 0;
  856. }
  857. if (dims == 4 && axis == 1)
  858. {
  859. axis = 1;
  860. }
  861. if (dims == 4 && axis == 2)
  862. {
  863. axis = 2;
  864. }
  865. if (dims == 4 && axis == 3)
  866. {
  867. axis = 0;
  868. }
  869. fprintf(pp, " 0=%d", axis);
  870. }
  871. else if (op == "tf.Const")
  872. {
  873. // check weight before BinaryOp
  874. std::string output_name = get_mlir_value_uniq_id(operation.getResult(0));
  875. if (binaryop_weights.find(output_name) != binaryop_weights.end())
  876. {
  877. const mlir::Attribute& M = binaryop_weights[output_name];
  878. llvm::ArrayRef<int64_t> shape = M.getType().cast<mlir::RankedTensorType>().getShape();
  879. // c wc hwc
  880. if (shape.size() == 0)
  881. {
  882. // scalar
  883. fprintf(pp, " 0=1");
  884. }
  885. else if (shape.size() == 1)
  886. {
  887. fprintf(pp, " 0=%d", (int)shape[0]);
  888. }
  889. else if (shape.size() == 2)
  890. {
  891. fprintf(pp, " 0=%d", (int)shape[1]);
  892. fprintf(pp, " 1=%d", (int)shape[0]);
  893. }
  894. else if (shape.size() == 3)
  895. {
  896. fprintf(pp, " 0=%d", (int)shape[1]);
  897. fprintf(pp, " 1=%d", (int)shape[0]);
  898. fprintf(pp, " 2=%d", (int)shape[2]);
  899. }
  900. std::vector<float> v = get_attr_af(M);
  901. if (shape.size() != 3)
  902. {
  903. fwrite(v.data(), sizeof(float), v.size(), bp);
  904. }
  905. else
  906. {
  907. int w = (int)shape[1];
  908. int h = (int)shape[0];
  909. int c = (int)shape[2];
  910. float tmp;
  911. // h-w-c to c-h-w
  912. for (int p = 0; p < c; p++)
  913. {
  914. for (int i = 0; i < h; i++)
  915. {
  916. for (int j = 0; j < w; j++)
  917. {
  918. tmp = v[i * w * c + j * c + p];
  919. fwrite(&tmp, sizeof(float), 1, bp);
  920. }
  921. }
  922. }
  923. }
  924. }
  925. }
  926. else if (op == "tf.Conv2D")
  927. {
  928. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  929. const mlir::Attribute& W = weights[weight_name];
  930. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  931. // assert(shape.size() == 4)
  932. // kh-kw-inch-outch
  933. int kernel_size_h = shape[0];
  934. int kernel_size_w = shape[1];
  935. int num_input = shape[2];
  936. int num_output = shape[3];
  937. int weight_data_size = kernel_size_h * kernel_size_w * num_input * num_output;
  938. fprintf(pp, " 0=%d", num_output);
  939. fprintf(pp, " 1=%d", kernel_size_w);
  940. fprintf(pp, " 11=%d", kernel_size_h);
  941. fprintf(pp, " 6=%d", weight_data_size);
  942. std::vector<int> dilations = get_operation_attr_ai(operation, "dilations");
  943. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  944. std::string padding = get_operation_attr_s(operation, "padding");
  945. if (dilations.size() == 4)
  946. {
  947. fprintf(pp, " 2=%d", dilations[2]);
  948. fprintf(pp, " 12=%d", dilations[1]);
  949. }
  950. if (strides.size() == 4)
  951. {
  952. fprintf(pp, " 3=%d", strides[2]);
  953. fprintf(pp, " 13=%d", strides[1]);
  954. }
  955. if (padding == "EXPLICIT")
  956. {
  957. // nhwc = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
  958. std::vector<int> explicit_paddings = get_operation_attr_ai(operation, "explicit_paddings");
  959. fprintf(pp, " 4=%d", explicit_paddings[4]);
  960. fprintf(pp, " 15=%d", explicit_paddings[5]);
  961. fprintf(pp, " 14=%d", explicit_paddings[2]);
  962. fprintf(pp, " 16=%d", explicit_paddings[3]);
  963. }
  964. else if (padding == "VALID")
  965. {
  966. fprintf(pp, " 4=%d", 0);
  967. }
  968. else if (padding == "SAME")
  969. {
  970. fprintf(pp, " 4=%d", -233);
  971. }
  972. std::vector<float> v = get_attr_af(W);
  973. // reorder h-w-i-o to o-i-h-w
  974. {
  975. int quantize_tag = 0;
  976. fwrite(&quantize_tag, sizeof(int), 1, bp);
  977. float tmp;
  978. for (int p = 0; p < num_output; p++)
  979. {
  980. for (int q = 0; q < num_input; q++)
  981. {
  982. for (int i = 0; i < kernel_size_h; i++)
  983. {
  984. for (int j = 0; j < kernel_size_w; j++)
  985. {
  986. tmp = v[i * kernel_size_w * num_input * num_output + j * num_input * num_output + q * num_output + p];
  987. fwrite(&tmp, sizeof(float), 1, bp);
  988. }
  989. }
  990. }
  991. }
  992. }
  993. }
  994. else if (op == "tf.Conv2DBackpropInput")
  995. {
  996. std::string output_shape_name = get_mlir_value_uniq_id(operation.getOperand(0));
  997. const std::vector<int> output_shape = get_attr_ai(weights[output_shape_name]);
  998. // assert(output_shape.size() == 4)
  999. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1000. const mlir::Attribute& W = weights[weight_name];
  1001. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  1002. // assert(shape.size() == 4)
  1003. // kh-kw-outch-inch
  1004. int kernel_size_h = shape[0];
  1005. int kernel_size_w = shape[1];
  1006. int num_output = shape[2];
  1007. int num_input = shape[3];
  1008. int weight_data_size = kernel_size_h * kernel_size_w * num_input * num_output;
  1009. fprintf(pp, " 0=%d", num_output);
  1010. fprintf(pp, " 1=%d", kernel_size_w);
  1011. fprintf(pp, " 11=%d", kernel_size_h);
  1012. fprintf(pp, " 6=%d", weight_data_size);
  1013. std::vector<int> dilations = get_operation_attr_ai(operation, "dilations");
  1014. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  1015. std::string padding = get_operation_attr_s(operation, "padding");
  1016. if (dilations.size() == 4)
  1017. {
  1018. fprintf(pp, " 2=%d", dilations[2]);
  1019. fprintf(pp, " 12=%d", dilations[1]);
  1020. }
  1021. if (strides.size() == 4)
  1022. {
  1023. fprintf(pp, " 3=%d", strides[2]);
  1024. fprintf(pp, " 13=%d", strides[1]);
  1025. }
  1026. if (padding == "EXPLICIT")
  1027. {
  1028. // nhwc = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
  1029. std::vector<int> explicit_paddings = get_operation_attr_ai(operation, "explicit_paddings");
  1030. fprintf(pp, " 4=%d", explicit_paddings[4]);
  1031. fprintf(pp, " 15=%d", explicit_paddings[5]);
  1032. fprintf(pp, " 14=%d", explicit_paddings[2]);
  1033. fprintf(pp, " 16=%d", explicit_paddings[3]);
  1034. }
  1035. else if (padding == "VALID")
  1036. {
  1037. fprintf(pp, " 4=%d", 0);
  1038. }
  1039. else if (padding == "SAME")
  1040. {
  1041. fprintf(pp, " 4=%d", -233);
  1042. fprintf(pp, " 20=%d", output_shape[2]);
  1043. fprintf(pp, " 21=%d", output_shape[1]);
  1044. }
  1045. std::vector<float> v = get_attr_af(W);
  1046. // reorder h-w-o-i to o-i-h-w
  1047. {
  1048. int quantize_tag = 0;
  1049. fwrite(&quantize_tag, sizeof(int), 1, bp);
  1050. float tmp;
  1051. for (int p = 0; p < num_output; p++)
  1052. {
  1053. for (int q = 0; q < num_input; q++)
  1054. {
  1055. for (int i = 0; i < kernel_size_h; i++)
  1056. {
  1057. for (int j = 0; j < kernel_size_w; j++)
  1058. {
  1059. tmp = v[i * kernel_size_w * num_output * num_input + j * num_output * num_input + p * num_input + q];
  1060. fwrite(&tmp, sizeof(float), 1, bp);
  1061. }
  1062. }
  1063. }
  1064. }
  1065. }
  1066. }
  1067. else if (op == "tf.DepthwiseConv2dNative")
  1068. {
  1069. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1070. const mlir::Attribute& W = weights[weight_name];
  1071. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  1072. // assert(shape.size() == 4)
  1073. // kh-kw-inch-cm
  1074. int kernel_size_h = shape[0];
  1075. int kernel_size_w = shape[1];
  1076. int num_input = shape[2];
  1077. int channel_multiplier = shape[3];
  1078. int num_output = num_input * channel_multiplier;
  1079. int group = num_input;
  1080. int weight_data_size = kernel_size_h * kernel_size_w * num_input * channel_multiplier;
  1081. fprintf(pp, " 0=%d", num_output);
  1082. fprintf(pp, " 1=%d", kernel_size_w);
  1083. fprintf(pp, " 11=%d", kernel_size_h);
  1084. fprintf(pp, " 6=%d", weight_data_size);
  1085. fprintf(pp, " 7=%d", group);
  1086. std::vector<int> dilations = get_operation_attr_ai(operation, "dilations");
  1087. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  1088. std::string padding = get_operation_attr_s(operation, "padding");
  1089. if (dilations.size() == 4)
  1090. {
  1091. fprintf(pp, " 2=%d", dilations[2]);
  1092. fprintf(pp, " 12=%d", dilations[1]);
  1093. }
  1094. if (strides.size() == 4)
  1095. {
  1096. fprintf(pp, " 3=%d", strides[2]);
  1097. fprintf(pp, " 13=%d", strides[1]);
  1098. }
  1099. if (padding == "EXPLICIT")
  1100. {
  1101. // nhwc = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
  1102. std::vector<int> explicit_paddings = get_operation_attr_ai(operation, "explicit_paddings");
  1103. fprintf(pp, " 4=%d", explicit_paddings[4]);
  1104. fprintf(pp, " 15=%d", explicit_paddings[5]);
  1105. fprintf(pp, " 14=%d", explicit_paddings[2]);
  1106. fprintf(pp, " 16=%d", explicit_paddings[3]);
  1107. }
  1108. else if (padding == "VALID")
  1109. {
  1110. fprintf(pp, " 4=%d", 0);
  1111. }
  1112. else if (padding == "SAME")
  1113. {
  1114. fprintf(pp, " 4=%d", -233);
  1115. }
  1116. std::vector<float> v = get_attr_af(W);
  1117. // reorder h-w-i-cm to i-cm-h-w
  1118. {
  1119. int quantize_tag = 0;
  1120. fwrite(&quantize_tag, sizeof(int), 1, bp);
  1121. float tmp;
  1122. for (int p = 0; p < num_input; p++)
  1123. {
  1124. for (int q = 0; q < channel_multiplier; q++)
  1125. {
  1126. for (int i = 0; i < kernel_size_h; i++)
  1127. {
  1128. for (int j = 0; j < kernel_size_w; j++)
  1129. {
  1130. tmp = v[i * kernel_size_w * channel_multiplier * num_input + j * channel_multiplier * num_input + p * channel_multiplier + q];
  1131. fwrite(&tmp, sizeof(float), 1, bp);
  1132. }
  1133. }
  1134. }
  1135. }
  1136. }
  1137. }
  1138. else if (op == "tf.Identity")
  1139. {
  1140. }
  1141. else if (op == "tf.LeakyRelu")
  1142. {
  1143. float alpha = get_operation_attr_f(operation, "alpha");
  1144. fprintf(pp, " 0=%e", alpha);
  1145. }
  1146. else if (op == "tf.MatMul")
  1147. {
  1148. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1149. const mlir::Attribute& W = weights[weight_name];
  1150. llvm::ArrayRef<int64_t> shape = W.getType().cast<mlir::RankedTensorType>().getShape();
  1151. // assert(shape.size() == 2)
  1152. // inch-outch
  1153. int num_input = shape[0];
  1154. int num_output = shape[1];
  1155. int weight_data_size = shape[0] * shape[1];
  1156. fprintf(pp, " 0=%d", num_output);
  1157. fprintf(pp, " 2=%d", weight_data_size);
  1158. std::vector<float> v = get_attr_af(W);
  1159. // reorder i-o to o-i
  1160. {
  1161. int quantize_tag = 0;
  1162. fwrite(&quantize_tag, sizeof(int), 1, bp);
  1163. float tmp;
  1164. for (int p = 0; p < num_output; p++)
  1165. {
  1166. for (int q = 0; q < num_input; q++)
  1167. {
  1168. tmp = v[q * num_output + p];
  1169. fwrite(&tmp, sizeof(float), 1, bp);
  1170. }
  1171. }
  1172. }
  1173. }
  1174. else if (op == "tf.MaxPool")
  1175. {
  1176. std::vector<int> ksize = get_operation_attr_ai(operation, "ksize");
  1177. std::vector<int> strides = get_operation_attr_ai(operation, "strides");
  1178. std::string padding = get_operation_attr_s(operation, "padding");
  1179. if (ksize.size() == 4)
  1180. {
  1181. fprintf(pp, " 1=%d", ksize[2]);
  1182. fprintf(pp, " 11=%d", ksize[1]);
  1183. }
  1184. if (strides.size() == 4)
  1185. {
  1186. fprintf(pp, " 2=%d", strides[2]);
  1187. fprintf(pp, " 12=%d", strides[1]);
  1188. }
  1189. int pad_mode = 1;
  1190. if (padding == "VALID")
  1191. {
  1192. pad_mode = 1;
  1193. }
  1194. else if (padding == "SAME")
  1195. {
  1196. pad_mode = 2;
  1197. }
  1198. fprintf(pp, " 5=%d", pad_mode);
  1199. }
  1200. else if (op == "tf.Mean")
  1201. {
  1202. std::string reduction_indices_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1203. const mlir::Attribute& R = weights[reduction_indices_name];
  1204. std::vector<int> v = get_attr_ai(R);
  1205. int keep_dims = get_operation_attr_b(operation, "keep_dims");
  1206. if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2)
  1207. {
  1208. // global avg pooling style nhwc -> nc
  1209. int pool = 1;
  1210. int global_pool = 1;
  1211. fprintf(pp, " 0=%d", pool);
  1212. fprintf(pp, " 4=%d", global_pool);
  1213. }
  1214. else
  1215. {
  1216. // TODO
  1217. }
  1218. }
  1219. else if (op == "tf.Mul")
  1220. {
  1221. int op_type = 2;
  1222. fprintf(pp, " 0=%d", op_type);
  1223. }
  1224. else if (op == "tf.Pad")
  1225. {
  1226. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1227. const mlir::Attribute& P = weights[weight_name];
  1228. std::vector<int> v = get_attr_ai(P);
  1229. // nhwc = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
  1230. fprintf(pp, " 0=%d", v[2]);
  1231. fprintf(pp, " 1=%d", v[3]);
  1232. fprintf(pp, " 2=%d", v[4]);
  1233. fprintf(pp, " 3=%d", v[5]);
  1234. }
  1235. else if (op == "tf.Placeholder")
  1236. {
  1237. }
  1238. else if (op == "tf.Relu")
  1239. {
  1240. }
  1241. else if (op == "tf.Relu6")
  1242. {
  1243. float min = 0.f;
  1244. float max = 6.f;
  1245. fprintf(pp, " 0=%e", min);
  1246. fprintf(pp, " 1=%e", max);
  1247. }
  1248. else if (op == "tf.Reshape")
  1249. {
  1250. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1251. const mlir::Attribute& S = weights[weight_name];
  1252. std::vector<int> v = get_attr_ai(S);
  1253. int size = v.size();
  1254. // n h w c
  1255. // n h c
  1256. // n c
  1257. if (size == 4)
  1258. {
  1259. fprintf(pp, " 0=%d 1=%d 2=%d", v[2], v[1], v[3]);
  1260. }
  1261. if (size == 3)
  1262. {
  1263. fprintf(pp, " 0=%d 1=%d 2=-233", v[1], v[2]);
  1264. }
  1265. if (size == 2)
  1266. {
  1267. fprintf(pp, " 0=%d 1=-233 2=-233", v[1]);
  1268. }
  1269. // FIXME may not always be the case
  1270. fprintf(pp, " 3=1");
  1271. }
  1272. else if (op == "tf.ResizeNearestNeighbor")
  1273. {
  1274. std::string weight_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1275. const mlir::Attribute& P = weights[weight_name];
  1276. std::vector<int> size = get_attr_ai(P);
  1277. int align_corners = get_operation_attr_b(operation, "align_corners");
  1278. int half_pixel_centers = get_operation_attr_b(operation, "half_pixel_centers");
  1279. if (!(align_corners == 0 && half_pixel_centers == 1))
  1280. {
  1281. fprintf(stderr, "Unsupported ResizeNearestNeighbor align_corners %d half_pixel_centers %d !\n", align_corners, half_pixel_centers);
  1282. }
  1283. fprintf(pp, " 0=1"); // nearest
  1284. fprintf(pp, " 3=%d 4=%d", size[1], size[0]);
  1285. }
  1286. else if (op == "tf.Sigmoid")
  1287. {
  1288. }
  1289. else if (op == "tf.Softmax")
  1290. {
  1291. }
  1292. else if (op == "tf.StridedSlice")
  1293. {
  1294. std::string begin_name = get_mlir_value_uniq_id(operation.getOperand(1));
  1295. std::string end_name = get_mlir_value_uniq_id(operation.getOperand(2));
  1296. std::string strides_name = get_mlir_value_uniq_id(operation.getOperand(3));
  1297. const mlir::Attribute& B = weights[begin_name];
  1298. const mlir::Attribute& E = weights[end_name];
  1299. const mlir::Attribute& S = weights[strides_name];
  1300. std::vector<int> begin = get_attr_ai(B);
  1301. std::vector<int> end = get_attr_ai(E);
  1302. std::vector<int> strides = get_attr_ai(S);
  1303. int begin_mask = get_operation_attr_i(operation, "begin_mask");
  1304. int end_mask = get_operation_attr_i(operation, "end_mask");
  1305. int ellipsis_mask = get_operation_attr_i(operation, "ellipsis_mask");
  1306. int new_axis_mask = get_operation_attr_i(operation, "new_axis_mask");
  1307. int shrink_axis_mask = get_operation_attr_i(operation, "shrink_axis_mask");
  1308. int dims = strides.size();
  1309. // assert strides == 1
  1310. for (int i = 0; i < dims; i++)
  1311. {
  1312. if (strides[i] != 1)
  1313. fprintf(stderr, "Unsupported StridedSlice strides !\n");
  1314. }
  1315. for (int i = 0; i < dims; i++)
  1316. {
  1317. // TODO strides[i] < 0
  1318. if (begin_mask & (1 << i))
  1319. {
  1320. begin[i] = 0;
  1321. }
  1322. if (end_mask & (1 << i))
  1323. {
  1324. end[i] = -233;
  1325. }
  1326. if (ellipsis_mask & (1 << i))
  1327. {
  1328. begin[i] = 0;
  1329. end[i] = -233;
  1330. }
  1331. }
  1332. if (new_axis_mask)
  1333. {
  1334. fprintf(stderr, "Unsupported StridedSlice new_axis_mask !\n");
  1335. }
  1336. if (shrink_axis_mask)
  1337. {
  1338. fprintf(stderr, "Unsupported StridedSlice shrink_axis_mask !\n");
  1339. }
  1340. // n h w c
  1341. // n h c
  1342. // n c
  1343. if (dims == 4)
  1344. {
  1345. fprintf(pp, " -23309=3,%d,%d,%d", begin[3], begin[1], begin[2]);
  1346. fprintf(pp, " -23310=3,%d,%d,%d", end[3], end[1], end[2]);
  1347. }
  1348. if (dims == 3)
  1349. {
  1350. fprintf(pp, " -23309=2,%d,%d", begin[2], begin[1]);
  1351. fprintf(pp, " -23310=2,%d,%d", end[2], end[1]);
  1352. }
  1353. if (dims == 2)
  1354. {
  1355. fprintf(pp, " -23309=1,%d", begin[1]);
  1356. fprintf(pp, " -23310=1,%d", end[1]);
  1357. }
  1358. }
  1359. else if (op == "tf.Sub")
  1360. {
  1361. int op_type = 1;
  1362. fprintf(pp, " 0=%d", op_type);
  1363. }
  1364. else if (op == "tf.Tanh")
  1365. {
  1366. }
  1367. #if 0
  1368. for (const mlir::NamedAttribute& attr : operation.getAttrs())
  1369. {
  1370. const mlir::Identifier& identifier = attr.first;
  1371. const mlir::Attribute& attr = attr.second;
  1372. fprintf(pp, " %s=", identifier.c_str());
  1373. if (attr.isa<mlir::AffineMapAttr>())
  1374. {
  1375. fprintf(pp, "AffineMap");
  1376. }
  1377. if (attr.isa<mlir::ArrayAttr>())
  1378. {
  1379. // fprintf(pp, "Array");
  1380. mlir::ArrayAttr a = attr.cast<mlir::ArrayAttr>();
  1381. int array_size = a.getValue().size();
  1382. for (int t=0; t<array_size; t++)
  1383. {
  1384. if (a[t].isa<mlir::IntegerAttr>())
  1385. {
  1386. int64_t ii = a[t].cast<mlir::IntegerAttr>().getInt();
  1387. fprintf(pp, "%lld,", ii);
  1388. }
  1389. }
  1390. }
  1391. if (attr.isa<mlir::BoolAttr>())
  1392. {
  1393. // fprintf(pp, "Bool");
  1394. mlir::BoolAttr a = attr.cast<mlir::BoolAttr>();
  1395. fprintf(pp, "%d", a.getValue() ? 1 : 0);
  1396. }
  1397. if (attr.isa<mlir::DictionaryAttr>())
  1398. {
  1399. fprintf(pp, "Dictionary");
  1400. }
  1401. if (attr.isa<mlir::FloatAttr>())
  1402. {
  1403. fprintf(pp, "Float");
  1404. }
  1405. if (attr.isa<mlir::IntegerAttr>())
  1406. {
  1407. fprintf(pp, "Integer");
  1408. }
  1409. if (attr.isa<mlir::IntegerSetAttr>())
  1410. {
  1411. fprintf(pp, "IntegerSet");
  1412. }
  1413. if (attr.isa<mlir::OpaqueAttr>())
  1414. {
  1415. fprintf(pp, "Opaque");
  1416. }
  1417. if (attr.isa<mlir::StringAttr>())
  1418. {
  1419. // fprintf(pp, "String");
  1420. mlir::StringAttr s = attr.cast<mlir::StringAttr>();
  1421. fprintf(pp, "%s", s.getValue().empty() ? "" : s.getValue().data());
  1422. }
  1423. if (attr.isa<mlir::SymbolRefAttr>())
  1424. {
  1425. fprintf(pp, "SymbolRef");
  1426. }
  1427. if (attr.isa<mlir::FlatSymbolRefAttr>())
  1428. {
  1429. fprintf(pp, "FlatSymbolRef");
  1430. }
  1431. if (attr.isa<mlir::TypeAttr>())
  1432. {
  1433. fprintf(pp, "Type");
  1434. }
  1435. if (attr.isa<mlir::UnitAttr>())
  1436. {
  1437. fprintf(pp, "Unit");
  1438. }
  1439. if (attr.isa<mlir::ElementsAttr>())
  1440. {
  1441. fprintf(pp, "Elements");
  1442. }
  1443. if (attr.isa<mlir::DenseElementsAttr>())
  1444. {
  1445. fprintf(pp, "DenseElements");
  1446. }
  1447. if (attr.isa<mlir::DenseFPElementsAttr>())
  1448. {
  1449. fprintf(pp, "DenseFPElements");
  1450. }
  1451. if (attr.isa<mlir::DenseIntElementsAttr>())
  1452. {
  1453. fprintf(pp, "DenseIntElements");
  1454. }
  1455. if (attr.isa<mlir::OpaqueElementsAttr>())
  1456. {
  1457. fprintf(pp, "OpaqueElements");
  1458. }
  1459. if (attr.isa<mlir::SparseElementsAttr>())
  1460. {
  1461. fprintf(pp, "SparseElements");
  1462. }
  1463. if (attr.isa<mlir::SplatElementsAttr>())
  1464. {
  1465. fprintf(pp, "SplatElements");
  1466. }
  1467. }
  1468. #endif
  1469. fprintf(pp, "\n");
  1470. for (int j = 0; j < num_output; j++)
  1471. {
  1472. std::string output_name = get_mlir_value_uniq_id(operation.getResult(j));
  1473. if (node_reference.find(output_name) != node_reference.end())
  1474. {
  1475. int refcount = node_reference[output_name];
  1476. if (refcount > 1)
  1477. {
  1478. char splitname[256];
  1479. sprintf(splitname, "splitncnn_%d", internal_split);
  1480. fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
  1481. fprintf(pp, " %s", output_name.c_str());
  1482. for (int k = 0; k < refcount; k++)
  1483. {
  1484. fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
  1485. }
  1486. fprintf(pp, "\n");
  1487. internal_split++;
  1488. }
  1489. }
  1490. }
  1491. }
  1492. fclose(pp);
  1493. fclose(bp);
  1494. return 0;
  1495. }