You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tf_dialect.cpp 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. // Copyright 2020 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "tf_dialect.h"
  4. #include <mlir/Dialect/Traits.h>
  5. #include <mlir/IR/Attributes.h>
  6. #include <mlir/IR/Builders.h>
  7. #include <mlir/IR/Dialect.h>
  8. #include <mlir/IR/DialectImplementation.h>
  9. #include <mlir/IR/Location.h>
  10. #include <mlir/IR/Matchers.h>
  11. #include <mlir/IR/MLIRContext.h>
  12. #include <mlir/IR/OpDefinition.h>
  13. #include <mlir/IR/OpImplementation.h>
  14. #include <mlir/IR/Operation.h>
  15. #include <mlir/IR/OperationSupport.h>
  16. #include <mlir/IR/PatternMatch.h>
  17. #include <mlir/IR/TypeUtilities.h>
  18. #include <mlir/IR/Types.h>
  19. #include <mlir/IR/Value.h>
  20. #include <mlir/IR/Verifier.h>
  21. #include <mlir/Interfaces/CallInterfaces.h>
  22. #include <mlir/Interfaces/DerivedAttributeOpInterface.h>
  23. #include <mlir/Interfaces/InferTypeOpInterface.h>
  24. #include <mlir/Interfaces/LoopLikeInterface.h>
  25. #include <mlir/Interfaces/SideEffectInterfaces.h>
  26. #include <mlir/Parser.h>
  27. #include <mlir/Support/LogicalResult.h>
  28. #include <mlir/Transforms/InliningUtils.h>
  29. #include "tf_attributes.h"
  30. #include "tf_side_effects.h"
  31. #include "tf_traits.h"
  32. namespace mlir {
  33. static LogicalResult Verify(...)
  34. {
  35. return success();
  36. }
  37. static LogicalResult VerifyPartitionedCall(...)
  38. {
  39. return success();
  40. }
  41. static LogicalResult VerifyStridedSliceBase(...)
  42. {
  43. return success();
  44. }
  45. static LogicalResult VerifyUnsortedSegmentReduction(...)
  46. {
  47. return success();
  48. }
  49. namespace TF {
  50. TensorFlowDialect::TensorFlowDialect(MLIRContext* context)
  51. : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>())
  52. {
  53. addOperations<
  54. #define GET_OP_LIST
  55. #include "tf_all_ops.cc.inc"
  56. >();
  57. addTypes<
  58. #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
  59. #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
  60. #include "tf_types.def"
  61. >();
  62. // addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
  63. // TFConstantFoldInterface>();
  64. addAttributes<ShapeAttr, FuncAttr>();
  65. // Support unknown operations because not all TensorFlow operations are
  66. // registered.
  67. allowUnknownOperations();
  68. // for (const auto &hook : *TensorFlowDialect::additional_operation_hooks_) {
  69. // hook(*this);
  70. // }
  71. }
  72. namespace {
  73. ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc)
  74. {
  75. auto emit_error = [&, spec]() {
  76. emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
  77. return nullptr;
  78. };
  79. if (!spec.consume_front("shape<")) return emit_error();
  80. if (spec.consume_front("*>"))
  81. return mlir::TF::ShapeAttr::get(context, llvm::None);
  82. SmallVector<int64_t, 4> shape;
  83. while (!spec.consume_front(">"))
  84. {
  85. int64_t dim;
  86. if (spec.consume_front("?"))
  87. dim = -1;
  88. else if (spec.consumeInteger(10, dim) || dim < 0)
  89. return emit_error();
  90. spec.consume_front("x");
  91. shape.push_back(dim);
  92. }
  93. return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
  94. }
  95. // Parses a #tf.func attribute of the following format:
  96. //
  97. // #tf.func<@symbol, {attr = "value"}>
  98. //
  99. // where the first element is a SymbolRefAttr and the second element is a
  100. // DictionaryAttr.
  101. FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc)
  102. {
  103. auto emit_error = [&, spec]() {
  104. emitError(loc, "invalid TensorFlow func attribute: ") << spec;
  105. return nullptr;
  106. };
  107. if (!spec.consume_front("func<")) return emit_error();
  108. size_t func_name_num_read = 0;
  109. Attribute func_name_attr = mlir::parseAttribute(spec, context, func_name_num_read);
  110. if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
  111. return emit_error();
  112. spec = spec.drop_front(func_name_num_read);
  113. if (!spec.consume_front(", ")) return emit_error();
  114. size_t func_attrs_num_read = 0;
  115. Attribute func_attrs_attr = mlir::parseAttribute(spec, context, func_attrs_num_read);
  116. if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
  117. return emit_error();
  118. spec = spec.drop_front(func_attrs_num_read);
  119. if (!spec.consume_front(">")) return emit_error();
  120. return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
  121. func_attrs_attr.cast<DictionaryAttr>());
  122. }
  123. } // namespace
  124. Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser,
  125. Type type) const
  126. {
  127. auto spec = parser.getFullSymbolSpec();
  128. Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
  129. if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);
  130. if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);
  131. return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
  132. }
  133. // Parses a type registered to this dialect.
  134. Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
  135. {
  136. StringRef data;
  137. if (parser.parseKeyword(&data)) return Type();
  138. #define HANDLE_TF_TYPE(tftype, enumerant, name) \
  139. if (data == name) return tftype##Type::get(getContext());
  140. // Custom TensorFlow types are handled separately at the end as they do partial
  141. // match.
  142. #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
  143. // NOLINTNEXTLINE
  144. #include "tf_types.def"
  145. llvm::SMLoc loc = parser.getNameLoc();
  146. if (data.startswith("resource"))
  147. {
  148. Type ret = ParseResourceType(parser);
  149. if (!ret) parser.emitError(loc, "invalid resource type");
  150. return ret;
  151. }
  152. if (data.startswith("variant"))
  153. {
  154. Type ret = ParseVariantType(parser);
  155. if (!ret) parser.emitError(loc, "invalid variant type");
  156. return ret;
  157. }
  158. return (parser.emitError(loc, "unknown TensorFlow type: " + data), nullptr);
  159. }
  160. namespace {
  161. template<typename TypeWithSubtype>
  162. Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser)
  163. {
  164. // Default type without inferred subtypes.
  165. if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
  166. // Most types with subtypes have only one subtype.
  167. SmallVector<TensorType, 1> subtypes;
  168. do
  169. {
  170. TensorType tensor_ty;
  171. if (parser.parseType(tensor_ty)) return Type();
  172. // Each of the subtypes should be a valid TensorFlow type.
  173. // TODO(jpienaar): Remove duplication.
  174. if (!IsValidTFTensorType(tensor_ty))
  175. {
  176. parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
  177. return Type();
  178. }
  179. subtypes.push_back(tensor_ty);
  180. } while (succeeded(parser.parseOptionalComma()));
  181. if (parser.parseGreater()) return Type();
  182. return TypeWithSubtype::get(subtypes, context);
  183. }
  184. } // anonymous namespace
  185. Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser) const
  186. {
  187. return ParseTypeWithSubtype<ResourceType>(getContext(), parser);
  188. }
  189. Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser) const
  190. {
  191. return ParseTypeWithSubtype<VariantType>(getContext(), parser);
  192. }
  193. Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder,
  194. Attribute value, Type type,
  195. Location loc)
  196. {
  197. return builder.create<ConstOp>(loc, type, value);
  198. }
  199. // Builds a constant op with the specified attribute `value`. The result
  200. // op's type is deduced from `value`; if `value` is of scalar type,
  201. // wraps it up with a tensor type of empty shape.
  202. // TODO(jpienaar): This one differs from the autogenerated one as it takes an
  203. // attribute but always creates an ElementsAttr internally.
  204. void ConstOp::build(OpBuilder& builder, OperationState& result,
  205. Attribute value)
  206. {
  207. ShapedType type;
  208. if (auto elem_attr = value.dyn_cast<ElementsAttr>())
  209. {
  210. return ConstOp::build(builder, result, elem_attr);
  211. }
  212. else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>())
  213. {
  214. // All TensorFlow types must be tensor types. In the build() method,
  215. // we want to provide more flexibility by allowing attributes of scalar
  216. // types. But we need to wrap it up with ElementsAttr to construct
  217. // valid TensorFlow constants.
  218. type = RankedTensorType::get(/*shape=*/ {}, value.getType());
  219. return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
  220. }
  221. // TODO(jpienaar): support other TensorFlow specific types.
  222. llvm_unreachable("unsupported attribute type for building tf.Const");
  223. }
  224. void ConstOp::build(OpBuilder& builder, OperationState& result, Type type,
  225. Attribute value)
  226. {
  227. // Handle the case where the type and value are already tensors.
  228. if (type.isa<TensorType>() && value.isa<ElementsAttr>())
  229. {
  230. result.addTypes(type);
  231. result.addAttribute("value", value);
  232. return;
  233. }
  234. // Otherwise, default to the attribute builder.
  235. ConstOp::build(builder, result, value);
  236. assert(type == result.types[0] && "type mismatch in construction");
  237. }
  238. Region& WhileRegionOp::getLoopBody()
  239. {
  240. return body();
  241. }
  242. bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
  243. {
  244. // If the Op defining the value exists and the defining op is outside the
  245. // scope of this WhileRegion, then we can infer that its defined outside.
  246. // The defining Op is outside the scope of this WhileRegion if this
  247. // WhileRegionOp is not an ancestor of the defining op in the parent chain.
  248. Operation* def_op = value.getDefiningOp();
  249. return def_op && !getOperation()->isAncestor(def_op);
  250. }
  251. LogicalResult WhileRegionOp::moveOutOfLoop(
  252. llvm::ArrayRef<mlir::Operation*> ops)
  253. {
  254. // Move the hoisted value to just before the while.
  255. Operation* while_op = this->getOperation();
  256. for (auto op : ops) op->moveBefore(while_op);
  257. return success();
  258. }
  259. } // namespace TF
  260. } // namespace mlir
  261. #define GET_OP_CLASSES
  262. #include "tf_all_ops.cc.inc"