You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf_dialect.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "tf_dialect.h"
  15. #include <mlir/Dialect/Traits.h>
  16. #include <mlir/IR/Attributes.h>
  17. #include <mlir/IR/Builders.h>
  18. #include <mlir/IR/Dialect.h>
  19. #include <mlir/IR/DialectImplementation.h>
  20. #include <mlir/IR/Location.h>
  21. #include <mlir/IR/Matchers.h>
  22. #include <mlir/IR/MLIRContext.h>
  23. #include <mlir/IR/OpDefinition.h>
  24. #include <mlir/IR/OpImplementation.h>
  25. #include <mlir/IR/Operation.h>
  26. #include <mlir/IR/OperationSupport.h>
  27. #include <mlir/IR/PatternMatch.h>
  28. #include <mlir/IR/TypeUtilities.h>
  29. #include <mlir/IR/Types.h>
  30. #include <mlir/IR/Value.h>
  31. #include <mlir/IR/Verifier.h>
  32. #include <mlir/Interfaces/CallInterfaces.h>
  33. #include <mlir/Interfaces/DerivedAttributeOpInterface.h>
  34. #include <mlir/Interfaces/InferTypeOpInterface.h>
  35. #include <mlir/Interfaces/LoopLikeInterface.h>
  36. #include <mlir/Interfaces/SideEffectInterfaces.h>
  37. #include <mlir/Parser.h>
  38. #include <mlir/Support/LogicalResult.h>
  39. #include <mlir/Transforms/InliningUtils.h>
  40. #include "tf_attributes.h"
  41. #include "tf_side_effects.h"
  42. #include "tf_traits.h"
  43. namespace mlir {
  44. static LogicalResult Verify(...)
  45. {
  46. return success();
  47. }
  48. static LogicalResult VerifyPartitionedCall(...)
  49. {
  50. return success();
  51. }
  52. static LogicalResult VerifyStridedSliceBase(...)
  53. {
  54. return success();
  55. }
  56. static LogicalResult VerifyUnsortedSegmentReduction(...)
  57. {
  58. return success();
  59. }
  60. namespace TF {
// Constructs the "tf" dialect and registers its operations, types and
// attributes with `context`.
TensorFlowDialect::TensorFlowDialect(MLIRContext* context)
    : Dialect(/*name=*/"tf", context, TypeID::get<TensorFlowDialect>())
{
    // The op class list is produced by expanding the generated
    // tf_all_ops.cc.inc with GET_OP_LIST defined.
    addOperations<
#define GET_OP_LIST
#include "tf_all_ops.cc.inc"
    >();
    // One `<tftype>Type` template argument per entry in tf_types.def.
    // HANDLE_LAST_TF_TYPE expands without the trailing comma so the list stays
    // well-formed (presumably the .def file uses it for its final entry —
    // confirm in tf_types.def).
    addTypes<
#define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
#include "tf_types.def"
    >();
    // Dialect interface registration is disabled (commented out) here.
    // addInterfaces<TFInlinerInterface, TFDecodeAttributesInterface,
    // TFConstantFoldInterface>();
    // Custom attributes parsed by parseAttribute(): #tf.shape<...> and
    // #tf.func<...>.
    addAttributes<ShapeAttr, FuncAttr>();
    // Support unknown operations because not all TensorFlow operations are
    // registered.
    allowUnknownOperations();
    // Additional operation hooks are not invoked (commented out) here.
    // for (const auto &hook : *TensorFlowDialect::additional_operation_hooks_) {
    // hook(*this);
    // }
}
  83. namespace {
  84. ShapeAttr ParseShapeAttr(MLIRContext* context, StringRef spec, Location loc)
  85. {
  86. auto emit_error = [&, spec]() {
  87. emitError(loc, "invalid TensorFlow shape attribute: ") << spec;
  88. return nullptr;
  89. };
  90. if (!spec.consume_front("shape<")) return emit_error();
  91. if (spec.consume_front("*>"))
  92. return mlir::TF::ShapeAttr::get(context, llvm::None);
  93. SmallVector<int64_t, 4> shape;
  94. while (!spec.consume_front(">"))
  95. {
  96. int64_t dim;
  97. if (spec.consume_front("?"))
  98. dim = -1;
  99. else if (spec.consumeInteger(10, dim) || dim < 0)
  100. return emit_error();
  101. spec.consume_front("x");
  102. shape.push_back(dim);
  103. }
  104. return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape));
  105. }
  106. // Parses a #tf.func attribute of the following format:
  107. //
  108. // #tf.func<@symbol, {attr = "value"}>
  109. //
  110. // where the first element is a SymbolRefAttr and the second element is a
  111. // DictionaryAttr.
  112. FuncAttr ParseFuncAttr(MLIRContext* context, StringRef spec, Location loc)
  113. {
  114. auto emit_error = [&, spec]() {
  115. emitError(loc, "invalid TensorFlow func attribute: ") << spec;
  116. return nullptr;
  117. };
  118. if (!spec.consume_front("func<")) return emit_error();
  119. size_t func_name_num_read = 0;
  120. Attribute func_name_attr = mlir::parseAttribute(spec, context, func_name_num_read);
  121. if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
  122. return emit_error();
  123. spec = spec.drop_front(func_name_num_read);
  124. if (!spec.consume_front(", ")) return emit_error();
  125. size_t func_attrs_num_read = 0;
  126. Attribute func_attrs_attr = mlir::parseAttribute(spec, context, func_attrs_num_read);
  127. if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
  128. return emit_error();
  129. spec = spec.drop_front(func_attrs_num_read);
  130. if (!spec.consume_front(">")) return emit_error();
  131. return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
  132. func_attrs_attr.cast<DictionaryAttr>());
  133. }
  134. } // namespace
  135. Attribute TensorFlowDialect::parseAttribute(DialectAsmParser& parser,
  136. Type type) const
  137. {
  138. auto spec = parser.getFullSymbolSpec();
  139. Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
  140. if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);
  141. if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);
  142. return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
  143. }
// Parses a type registered to this dialect.
//
// The leading keyword selects the type. Entries of tf_types.def that expand
// through HANDLE_TF_TYPE are matched by exact keyword and built with
// `<tftype>Type::get`. HANDLE_CUSTOM_TF_TYPE is defined to expand to nothing
// here; the resource and variant types are matched afterwards by keyword
// prefix, and their optional `<...>` subtype list is parsed by
// ParseResourceType / ParseVariantType.
// NOTE(review): HANDLE_TF_TYPE is also #defined in the constructor, so this
// presumably relies on tf_types.def #undef-ing its HANDLE_* macros — confirm.
// Returns a null Type on failure; an unknown keyword also emits a diagnostic.
Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
{
    StringRef data;
    if (parser.parseKeyword(&data)) return Type();
    Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
    if (data == name) return tftype##Type::get(getContext());
// Custom TensorFlow types are handled separately at the end as they do partial
// match.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tf_types.def"
    if (data.startswith("resource")) return ParseResourceType(parser, loc);
    if (data.startswith("variant")) return ParseVariantType(parser, loc);
    return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);
}
  161. namespace {
  162. template<typename TypeWithSubtype>
  163. Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser,
  164. Location loc)
  165. {
  166. // Default type without inferred subtypes.
  167. if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
  168. // Most types with subtypes have only one subtype.
  169. SmallVector<TensorType, 1> subtypes;
  170. do
  171. {
  172. TensorType tensor_ty;
  173. if (parser.parseType(tensor_ty)) return Type();
  174. subtypes.push_back(tensor_ty);
  175. } while (succeeded(parser.parseOptionalComma()));
  176. if (parser.parseGreater()) return Type();
  177. return TypeWithSubtype::getChecked(subtypes, context, loc);
  178. }
  179. } // anonymous namespace
  180. Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser,
  181. Location loc) const
  182. {
  183. return ParseTypeWithSubtype<ResourceType>(getContext(), parser, loc);
  184. }
  185. Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser,
  186. Location loc) const
  187. {
  188. return ParseTypeWithSubtype<VariantType>(getContext(), parser, loc);
  189. }
  190. Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder,
  191. Attribute value, Type type,
  192. Location loc)
  193. {
  194. return builder.create<ConstOp>(loc, type, value);
  195. }
  196. // Builds a constant op with the specified attribute `value`. The result
  197. // op's type is deduced from `value`; if `value` is of scalar type,
  198. // wraps it up with a tensor type of empty shape.
  199. // TODO(jpienaar): This one differs from the autogenerated one as it takes an
  200. // attribute but always creates an ElementsAttr internally.
  201. void ConstOp::build(OpBuilder& builder, OperationState& result,
  202. Attribute value)
  203. {
  204. ShapedType type;
  205. if (auto elem_attr = value.dyn_cast<ElementsAttr>())
  206. {
  207. return ConstOp::build(builder, result, elem_attr);
  208. }
  209. else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>())
  210. {
  211. // All TensorFlow types must be tensor types. In the build() method,
  212. // we want to provide more flexibility by allowing attributes of scalar
  213. // types. But we need to wrap it up with ElementsAttr to construct
  214. // valid TensorFlow constants.
  215. type = RankedTensorType::get(/*shape=*/ {}, value.getType());
  216. return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
  217. }
  218. // TODO(jpienaar): support other TensorFlow specific types.
  219. llvm_unreachable("unsupported attribute type for building tf.Const");
  220. }
  221. void ConstOp::build(OpBuilder& builder, OperationState& result, Type type,
  222. Attribute value)
  223. {
  224. // Handle the case where the type and value are already tensors.
  225. if (type.isa<TensorType>() && value.isa<ElementsAttr>())
  226. {
  227. result.addTypes(type);
  228. result.addAttribute("value", value);
  229. return;
  230. }
  231. // Otherwise, default to the attribute builder.
  232. ConstOp::build(builder, result, value);
  233. assert(type == result.types[0] && "type mismatch in construction");
  234. }
  235. Region& WhileRegionOp::getLoopBody()
  236. {
  237. return body();
  238. }
  239. bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
  240. {
  241. // If the Op defining the value exists and the defining op is outside the
  242. // scope of this WhileRegion, then we can infer that its defined outside.
  243. // The defining Op is outside the scope of this WhileRegion if this
  244. // WhileRegionOp is not an ancestor of the defining op in the parent chain.
  245. Operation* def_op = value.getDefiningOp();
  246. return def_op && !getOperation()->isAncestor(def_op);
  247. }
  248. LogicalResult WhileRegionOp::moveOutOfLoop(
  249. llvm::ArrayRef<mlir::Operation*> ops)
  250. {
  251. // Move the hoisted value to just before the while.
  252. Operation* while_op = this->getOperation();
  253. for (auto op : ops) op->moveBefore(while_op);
  254. return success();
  255. }
  256. } // namespace TF
  257. } // namespace mlir
  258. #define GET_OP_CLASSES
  259. #include "tf_all_ops.cc.inc"