You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tf_dialect.cpp 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "tf_dialect.h"
  15. #include <mlir/Dialect/Traits.h>
  16. #include <mlir/IR/Attributes.h>
  17. #include <mlir/IR/Builders.h>
  18. #include <mlir/IR/Dialect.h>
  19. #include <mlir/IR/DialectImplementation.h>
  20. #include <mlir/IR/Function.h>
  21. #include <mlir/IR/Location.h>
  22. #include <mlir/IR/Matchers.h>
  23. #include <mlir/IR/MLIRContext.h>
  24. #include <mlir/IR/OpDefinition.h>
  25. #include <mlir/IR/OpImplementation.h>
  26. #include <mlir/IR/Operation.h>
  27. #include <mlir/IR/OperationSupport.h>
  28. #include <mlir/IR/PatternMatch.h>
  29. #include <mlir/IR/StandardTypes.h>
  30. #include <mlir/IR/TypeUtilities.h>
  31. #include <mlir/IR/Types.h>
  32. #include <mlir/IR/Value.h>
  33. #include <mlir/IR/Verifier.h>
  34. #include <mlir/Interfaces/CallInterfaces.h>
  35. #include <mlir/Interfaces/DerivedAttributeOpInterface.h>
  36. #include <mlir/Interfaces/InferTypeOpInterface.h>
  37. #include <mlir/Interfaces/LoopLikeInterface.h>
  38. #include <mlir/Interfaces/SideEffectInterfaces.h>
  39. #include <mlir/Parser.h>
  40. #include <mlir/Support/LogicalResult.h>
  41. #include <mlir/Transforms/InliningUtils.h>
  42. #include "tf_attributes.h"
  43. #include "tf_side_effects.h"
  44. #include "tf_traits.h"
  45. namespace mlir {
  46. static LogicalResult Verify(...)
  47. {
  48. return success();
  49. }
  50. static LogicalResult VerifyPartitionedCall(...)
  51. {
  52. return success();
  53. }
  54. static LogicalResult VerifyStridedSliceBase(...)
  55. {
  56. return success();
  57. }
  58. static LogicalResult VerifyUnsortedSegmentReduction(...)
  59. {
  60. return success();
  61. }
  62. namespace TF {
  63. TensorFlowDialect::TensorFlowDialect(MLIRContext* context)
  64. : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>())
  65. {
  66. addOperations<
  67. #define GET_OP_LIST
  68. #include "tf_all_ops.cc.inc"
  69. >();
  70. addTypes<
  71. #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
  72. #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
  73. #include "tf_types.def"
  74. >();
  75. // addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
  76. // TFConstantFoldInterface>();
  77. addAttributes<ShapeAttr, FuncAttr>();
  78. // Support unknown operations because not all TensorFlow operations are
  79. // registered.
  80. allowUnknownOperations();
  81. // for (const auto &hook : *TensorFlowDialect::additional_operation_hooks_) {
  82. // hook(*this);
  83. // }
  84. }
  85. namespace {
  86. ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc)
  87. {
  88. auto emit_error = [&, spec]() {
  89. emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
  90. return nullptr;
  91. };
  92. if (!spec.consume_front("shape<")) return emit_error();
  93. if (spec.consume_front("*>"))
  94. return mlir::TF::ShapeAttr::get(context, llvm::None);
  95. SmallVector<int64_t, 4> shape;
  96. while (!spec.consume_front(">"))
  97. {
  98. int64_t dim;
  99. if (spec.consume_front("?"))
  100. dim = -1;
  101. else if (spec.consumeInteger(10, dim) || dim < 0)
  102. return emit_error();
  103. spec.consume_front("x");
  104. shape.push_back(dim);
  105. }
  106. return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
  107. }
  108. // Parses a #tf.func attribute of the following format:
  109. //
  110. // #tf.func<@symbol, {attr = "value"}>
  111. //
  112. // where the first element is a SymbolRefAttr and the second element is a
  113. // DictionaryAttr.
  114. FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc)
  115. {
  116. auto emit_error = [&, spec]() {
  117. emitError(loc, "invalid TensorFlow func attribute: ") << spec;
  118. return nullptr;
  119. };
  120. if (!spec.consume_front("func<")) return emit_error();
  121. size_t func_name_num_read = 0;
  122. Attribute func_name_attr = mlir::parseAttribute(spec, context, func_name_num_read);
  123. if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
  124. return emit_error();
  125. spec = spec.drop_front(func_name_num_read);
  126. if (!spec.consume_front(", ")) return emit_error();
  127. size_t func_attrs_num_read = 0;
  128. Attribute func_attrs_attr = mlir::parseAttribute(spec, context, func_attrs_num_read);
  129. if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
  130. return emit_error();
  131. spec = spec.drop_front(func_attrs_num_read);
  132. if (!spec.consume_front(">")) return emit_error();
  133. return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
  134. func_attrs_attr.cast<DictionaryAttr>());
  135. }
  136. } // namespace
  137. Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser,
  138. Type type) const
  139. {
  140. auto spec = parser.getFullSymbolSpec();
  141. Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
  142. if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);
  143. if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);
  144. return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
  145. }
  146. // Parses a type registered to this dialect.
  147. Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
  148. {
  149. StringRef data;
  150. if (parser.parseKeyword(&data)) return Type();
  151. Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
  152. #define HANDLE_TF_TYPE(tftype, enumerant, name) \
  153. if (data == name) return tftype##Type::get(getContext());
  154. // Custom TensorFlow types are handled separately at the end as they do partial
  155. // match.
  156. #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
  157. // NOLINTNEXTLINE
  158. #include "tf_types.def"
  159. if (data.startswith("resource")) return ParseResourceType(parser, loc);
  160. if (data.startswith("variant")) return ParseVariantType(parser, loc);
  161. return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);
  162. }
  163. namespace {
  164. template<typename TypeWithSubtype>
  165. Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser,
  166. Location loc)
  167. {
  168. // Default type without inferred subtypes.
  169. if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
  170. // Most types with subtypes have only one subtype.
  171. SmallVector<TensorType, 1> subtypes;
  172. do
  173. {
  174. TensorType tensor_ty;
  175. if (parser.parseType(tensor_ty)) return Type();
  176. subtypes.push_back(tensor_ty);
  177. } while (succeeded(parser.parseOptionalComma()));
  178. if (parser.parseGreater()) return Type();
  179. return TypeWithSubtype::getChecked(subtypes, context, loc);
  180. }
  181. } // anonymous namespace
  182. Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser,
  183. Location loc) const
  184. {
  185. return ParseTypeWithSubtype<ResourceType>(getContext(), parser, loc);
  186. }
  187. Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser,
  188. Location loc) const
  189. {
  190. return ParseTypeWithSubtype<VariantType>(getContext(), parser, loc);
  191. }
  192. Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder,
  193. Attribute value, Type type,
  194. Location loc)
  195. {
  196. return builder.create<ConstOp>(loc, type, value);
  197. }
  198. // Builds a constant op with the specified attribute `value`. The result
  199. // op's type is deduced from `value`; if `value` is of scalar type,
  200. // wraps it up with a tensor type of empty shape.
  201. // TODO(jpienaar): This one differs from the autogenerated one as it takes an
  202. // attribute but always creates an ElementsAttr internally.
  203. void ConstOp::build(OpBuilder& builder, OperationState& result,
  204. Attribute value)
  205. {
  206. ShapedType type;
  207. if (auto elem_attr = value.dyn_cast<ElementsAttr>())
  208. {
  209. return ConstOp::build(builder, result, elem_attr);
  210. }
  211. else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>())
  212. {
  213. // All TensorFlow types must be tensor types. In the build() method,
  214. // we want to provide more flexibility by allowing attributes of scalar
  215. // types. But we need to wrap it up with ElementsAttr to construct
  216. // valid TensorFlow constants.
  217. type = RankedTensorType::get(/*shape=*/ {}, value.getType());
  218. return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
  219. }
  220. // TODO(jpienaar): support other TensorFlow specific types.
  221. llvm_unreachable("unsupported attribute type for building tf.Const");
  222. }
  223. void ConstOp::build(OpBuilder& builder, OperationState& result, Type type,
  224. Attribute value)
  225. {
  226. // Handle the case where the type and value are already tensors.
  227. if (type.isa<TensorType>() && value.isa<ElementsAttr>())
  228. {
  229. result.addTypes(type);
  230. result.addAttribute("value", value);
  231. return;
  232. }
  233. // Otherwise, default to the attribute builder.
  234. ConstOp::build(builder, result, value);
  235. assert(type == result.types[0] && "type mismatch in construction");
  236. }
  237. LogicalResult ConstOp::inferReturnTypes(
  238. MLIRContext* context, Optional<Location> location, ValueRange operands,
  239. DictionaryAttr attributes, RegionRange regions,
  240. SmallVectorImpl<Type>& inferredReturnTypes)
  241. {
  242. auto value = attributes.get("value");
  243. if (!value) return emitOptionalError(location, "missing attribute 'value'");
  244. if (auto elem_attr = value.dyn_cast<ElementsAttr>())
  245. {
  246. inferredReturnTypes.assign({elem_attr.getType()});
  247. return success();
  248. }
  249. return emitOptionalError(location,
  250. "attribute 'value' failed to satisfy constraint: "
  251. "constant vector/tensor");
  252. }
  253. int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type,
  254. const TensorType paddings_type)
  255. {
  256. if (block_shape_type.hasStaticShape())
  257. {
  258. return block_shape_type.getShape()[0];
  259. }
  260. else if (paddings_type.hasStaticShape())
  261. {
  262. return paddings_type.getShape()[0];
  263. }
  264. else
  265. {
  266. return -1;
  267. }
  268. }
  269. // Infers returned rank if possible. Further, infers returned dimension sizes
  270. // when possible. For all dimensions sizes to be inferred, the arguments
  271. // block_shape and paddings must be constant.
  272. LogicalResult SpaceToBatchNDOp::inferReturnTypes(
  273. MLIRContext* context, Optional<Location> location, ValueRange operands,
  274. DictionaryAttr attributes, RegionRange regions,
  275. SmallVectorImpl<Type>& inferredReturnTypes)
  276. {
  277. const Value input = operands[0];
  278. const Value block_shape_val = operands[1];
  279. const Value paddings_val = operands[2];
  280. const auto input_type = input.getType().cast<TensorType>();
  281. const auto block_shape_type = block_shape_val.getType().cast<TensorType>();
  282. const auto paddings_type = paddings_val.getType().cast<TensorType>();
  283. // The return is unranked when the input is unranked.
  284. if (!input_type.hasRank())
  285. {
  286. inferredReturnTypes.assign(
  287. {UnrankedTensorType::get(input_type.getElementType())});
  288. return success();
  289. }
  290. const int64_t input_rank = input_type.getRank();
  291. const ArrayRef<int64_t> input_shape = input_type.getShape();
  292. const int64_t block_rank = SpaceToBatchNDBlockRank(block_shape_type, paddings_type);
  293. SmallVector<int64_t, 4> return_shape(input_rank, ShapedType::kDynamicSize);
  294. // The return has all dimension sizes unknown when block_rank is unknown.
  295. if (block_rank == ShapedType::kDynamicSize)
  296. {
  297. inferredReturnTypes.assign(
  298. {RankedTensorType::get(return_shape, input_type.getElementType())});
  299. return success();
  300. }
  301. // The return preserves the remaining dimensions after blocked dimensions.
  302. for (uint64_t i = 1 + block_rank; i < input_rank; ++i)
  303. {
  304. return_shape[i] = input_shape[i];
  305. }
  306. // The rest of the dimension sizes can be calculated when block_shape and
  307. // paddings arguments are constant.
  308. ElementsAttr block_shape_attr;
  309. ElementsAttr paddings_attr;
  310. if (matchPattern(block_shape_val, m_Constant(&block_shape_attr)) && matchPattern(paddings_val, m_Constant(&paddings_attr)))
  311. {
  312. int64_t return_batch = input_shape[0];
  313. for (uint64_t i = 0; i < block_rank; ++i)
  314. {
  315. // Propagate dynamic dimension.
  316. if (input_shape[i + 1] == ShapedType::kDynamicSize)
  317. {
  318. return_batch = ShapedType::kDynamicSize;
  319. }
  320. if (return_batch == ShapedType::kDynamicSize)
  321. {
  322. return_shape[1 + i] = ShapedType::kDynamicSize;
  323. continue;
  324. }
  325. int64_t paddings_sum = paddings_attr.getValue({i, 0}).cast<IntegerAttr>().getInt() + paddings_attr.getValue({i, 1}).cast<IntegerAttr>().getInt();
  326. int64_t block_shape_i = block_shape_attr.getValue({i}).cast<IntegerAttr>().getInt();
  327. return_batch *= block_shape_i;
  328. return_shape[1 + i] = (paddings_sum + input_shape[i + 1]) / block_shape_i;
  329. }
  330. return_shape[0] = return_batch;
  331. }
  332. inferredReturnTypes.assign(
  333. {RankedTensorType::get(return_shape, input_type.getElementType())});
  334. return success();
  335. }
  336. Region& WhileRegionOp::getLoopBody()
  337. {
  338. return body();
  339. }
  340. bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
  341. {
  342. // If the Op defining the value exists and the defining op is outside the
  343. // scope of this WhileRegion, then we can infer that its defined outside.
  344. // The defining Op is outside the scope of this WhileRegion if this
  345. // WhileRegionOp is not an ancestor of the defining op in the parent chain.
  346. Operation* def_op = value.getDefiningOp();
  347. return def_op && !getOperation()->isAncestor(def_op);
  348. }
  349. LogicalResult WhileRegionOp::moveOutOfLoop(
  350. llvm::ArrayRef<mlir::Operation*> ops)
  351. {
  352. // Move the hoisted value to just before the while.
  353. Operation* while_op = this->getOperation();
  354. for (auto op : ops) op->moveBefore(while_op);
  355. return success();
  356. }
  357. } // namespace TF
  358. } // namespace mlir
  359. #define GET_OP_CLASSES
  360. #include "tf_all_ops.cc.inc"