Browse Source

adapt mlir changes

tags/20210322
nihuini 5 years ago
parent
commit
c8ccccf045
9 changed files with 2714 additions and 1422 deletions
  1. +5
    -7
      docs/how-to-build/build-mlir2ncnn.md
  2. +1
    -1
      tools/mlir/tf_attributes.cc
  3. +29
    -14
      tools/mlir/tf_dialect.cpp
  4. +2
    -2
      tools/mlir/tf_dialect.h
  5. +2631
    -1345
      tools/mlir/tf_generated_ops.td
  6. +12
    -13
      tools/mlir/tf_op_base.td
  7. +17
    -29
      tools/mlir/tf_ops.td
  8. +3
    -3
      tools/mlir/tf_types.cc
  9. +14
    -8
      tools/mlir/tf_types.h

+ 5
- 7
docs/how-to-build/build-mlir2ncnn.md View File

@@ -7,17 +7,15 @@
https://github.com/llvm/llvm-project.git
git checkout -b mlir <a_working_commit_id>
```
Current working commit id is 7c15e0f64ccc79a53ed2db258f1cb58ec452a957:
Current working commit id is 74e6030bcbcc8e628f9a99a424342a0c656456f9:
```
$ git log

commit 7c15e0f64ccc79a53ed2db258f1cb58ec452a957 (HEAD -> 01-26)
Author: MaheshRavishankar <ravishankarm@google.com>
Date: Tue Jan 26 23:21:33 2021 -0800
commit 74e6030bcbcc8e628f9a99a424342a0c656456f9 (HEAD -> main, origin/main, origin/HEAD)
Author: Craig Topper <craig.topper@sifive.com>
Date: Thu Mar 4 22:30:38 2021 -0800

[mlir][Linalg] Add canonicalization for init_tensor -> subtensor op.
Differential Revision: https://reviews.llvm.org/D95305
[TargetLowering] Use HandleSDNodes to prevent nodes from being deleted by recursive calls in getNegatedExpression.
```

It is determined by query lastest git commit date of `tools/mlir` directory.


+ 1
- 1
tools/mlir/tf_attributes.cc View File

@@ -139,7 +139,7 @@ bool ShapeAttr::hasStaticShape() const
FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name,
DictionaryAttr attr)
{
auto symbol = SymbolRefAttr::get(name, context);
auto symbol = SymbolRefAttr::get(context, name);
return Base::get(context, symbol, attr);
}



+ 29
- 14
tools/mlir/tf_dialect.cpp View File

@@ -178,8 +178,6 @@ Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
StringRef data;
if (parser.parseKeyword(&data)) return Type();

Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());

#define HANDLE_TF_TYPE(tftype, enumerant, name) \
if (data == name) return tftype##Type::get(getContext());
// Custom TensorFlow types are handled separately at the end as they do partial
@@ -188,15 +186,25 @@ Type TensorFlowDialect::parseType(DialectAsmParser& parser) const
// NOLINTNEXTLINE
#include "tf_types.def"

if (data.startswith("resource")) return ParseResourceType(parser, loc);
if (data.startswith("variant")) return ParseVariantType(parser, loc);
return (emitError(loc, "unknown TensorFlow type: " + data), nullptr);
llvm::SMLoc loc = parser.getNameLoc();
if (data.startswith("resource"))
{
Type ret = ParseResourceType(parser);
if (!ret) parser.emitError(loc, "invalid resource type");
return ret;
}
if (data.startswith("variant"))
{
Type ret = ParseVariantType(parser);
if (!ret) parser.emitError(loc, "invalid variant type");
return ret;
}
return (parser.emitError(loc, "unknown TensorFlow type: " + data), nullptr);
}

namespace {
template<typename TypeWithSubtype>
Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser,
Location loc)
Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser)
{
// Default type without inferred subtypes.
if (failed(parser.parseOptionalLess())) return TypeWithSubtype::get(context);
@@ -207,24 +215,31 @@ Type ParseTypeWithSubtype(MLIRContext* context, DialectAsmParser& parser,
{
TensorType tensor_ty;
if (parser.parseType(tensor_ty)) return Type();

// Each of the subtypes should be a valid TensorFlow type.
// TODO(jpienaar): Remove duplication.
if (!IsValidTFTensorType(tensor_ty))
{
parser.emitError(parser.getNameLoc()) << "invalid subtype: " << tensor_ty;
return Type();
}
subtypes.push_back(tensor_ty);
} while (succeeded(parser.parseOptionalComma()));

if (parser.parseGreater()) return Type();
return TypeWithSubtype::getChecked(subtypes, context, loc);

return TypeWithSubtype::get(subtypes, context);
}
} // anonymous namespace

Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser,
Location loc) const
Type TensorFlowDialect::ParseResourceType(DialectAsmParser& parser) const
{
return ParseTypeWithSubtype<ResourceType>(getContext(), parser, loc);
return ParseTypeWithSubtype<ResourceType>(getContext(), parser);
}

Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser,
Location loc) const
Type TensorFlowDialect::ParseVariantType(DialectAsmParser& parser) const
{
return ParseTypeWithSubtype<VariantType>(getContext(), parser, loc);
return ParseTypeWithSubtype<VariantType>(getContext(), parser);
}

Operation* TensorFlowDialect::materializeConstant(OpBuilder& builder,


+ 2
- 2
tools/mlir/tf_dialect.h View File

@@ -47,11 +47,11 @@ public:
Type parseType(DialectAsmParser& parser) const override;

// Parses resource type with potential subtypes.
Type ParseResourceType(DialectAsmParser& parser, Location loc) const;
Type ParseResourceType(DialectAsmParser& parser) const;

// Parse and print variant type. It may have subtypes inferred using shape
// inference.
Type ParseVariantType(DialectAsmParser& parser, Location loc) const;
Type ParseVariantType(DialectAsmParser& parser) const;

// Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type.


+ 2631
- 1345
tools/mlir/tf_generated_ops.td
File diff suppressed because it is too large
View File


+ 12
- 13
tools/mlir/tf_op_base.td View File

@@ -24,7 +24,6 @@ limitations under the License.

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//include "tf_op_interfaces.td"

//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
@@ -477,13 +476,13 @@ class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
"return {mlir::OperandElementTypeIterator(values.begin()), "
"mlir::OperandElementTypeIterator(values.end())};",
[{
ArrayAttr::get(
ArrayAttr::get($_ctx,
[&]() {
llvm::SmallVector<Attribute, 4> ret;
for (auto t : $_self)
ret.push_back(TypeAttr::get(t));
return ret;
}(), $_ctx)
}())
}]
>;

@@ -497,13 +496,13 @@ class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
"return {mlir::TF::OperandShapeIterator(values.begin()), "
"mlir::TF::OperandShapeIterator(values.end())};",
[{
ArrayAttr::get(
ArrayAttr::get($_ctx,
[&](){
llvm::SmallVector<Attribute, 4> ret;
for (auto shape : $_self)
ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
return ret;
}(), $_ctx)
}())
}]
>;

@@ -533,13 +532,13 @@ class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
"return {mlir::ResultElementTypeIterator(values.begin()), "
"mlir::ResultElementTypeIterator(values.end())};",
[{
ArrayAttr::get(
ArrayAttr::get($_ctx,
[&]() {
llvm::SmallVector<Attribute, 4> ret;
for (auto t : $_self)
ret.push_back(TypeAttr::get(t));
return ret;
}(), $_ctx)
}())
}]
>;

@@ -553,13 +552,13 @@ class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
"return {mlir::TF::ResultShapeIterator(values.begin()), "
"mlir::TF::ResultShapeIterator(values.end())};",
[{
ArrayAttr::get(
ArrayAttr::get($_ctx,
[&](){
llvm::SmallVector<Attribute, 4> ret;
for (auto shape : $_self)
ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
return ret;
}(), $_ctx)
}())
}]
>;

@@ -579,8 +578,8 @@ def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
// Mixin class defining a builder for binary ops supporting broadcast
// behavior. The result type has the same element type as both operands.
class WithBroadcastableBinOpBuilder {
list<OpBuilderDAG> builders = [
OpBuilderDAG<(ins "Value":$x, "Value":$y),
list<OpBuilder> builders = [
OpBuilder<(ins "Value":$x, "Value":$y),
[{
auto resultType =
OpTrait::util::getBroadcastedType(x.getType(), y.getType());
@@ -593,8 +592,8 @@ class WithBroadcastableBinOpBuilder {
// Mixin class defining a builder for comparison ops supporting broadcast
// behavior. The result type has bool element type.
class WithBroadcastableCmpOpBuilder {
list<OpBuilderDAG> builders = [
OpBuilderDAG<(ins "Value":$x, "Value":$y),
list<OpBuilder> builders = [
OpBuilder<(ins "Value":$x, "Value":$y),
[{
Type resultType;
if (x.getType().isa<UnrankedTensorType>() ||


+ 17
- 29
tools/mlir/tf_ops.td View File

@@ -34,6 +34,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"

class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
let results = (outs
@@ -112,7 +113,6 @@ An n-way switch statement, implementing the following:
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;


let verifier = [{
return Verify(*this);
}];
@@ -179,7 +179,6 @@ An n-way switch statement, implementing the following:
return Verify(*this);
}];


}

// In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
@@ -198,8 +197,8 @@ def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect]> {
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;

let builders = [
OpBuilderDAG<(ins "Attribute":$value)>,
OpBuilderDAG<(ins "Type":$type, "Attribute":$value)>,
OpBuilder<(ins "Attribute":$value)>,
OpBuilder<(ins "Type":$type, "Attribute":$value)>,
];
}

@@ -322,11 +321,6 @@ else_branch: A function that takes 'inputs' and returns a list of
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

let verifier = [{
return Verify(*this);
}];


let extraClassDeclaration = [{
// Get the then branch function.
FuncOp then_function() {
@@ -381,7 +375,12 @@ else_branch: A region that computes the outputs of the op if cond = false.
0DTensorOf<[I1]>:$cond,

// Used to map StatelessIf and If op defined in TensorFlow to a common op.
BoolAttr:$is_stateless
BoolAttr:$is_stateless,
// Used to maintain function name when round-tripping
// between functional and regional control flow. This can be removed if
// the runtime does not require globally unique then/else branch function names.
OptionalAttr<StrAttr>:$_then_func_name,
OptionalAttr<StrAttr>:$_else_func_name
);

let results = (outs
@@ -395,14 +394,13 @@ else_branch: A region that computes the outputs of the op if cond = false.
}];

let builders = [
OpBuilderDAG<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
"llvm::ArrayRef<::mlir::NamedAttribute>":$attributes,
"unsigned":$numRegions),
[{
assert(numRegions == 2u && "mismatched number of regions");
build($_builder, $_state, resultTypes, operands, attributes);
}]>];

}

def TF_LegacyCallOp : TF_Op<"LegacyCall",
@@ -699,10 +697,6 @@ body: A function that takes a list of tensors and returns another
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;

let verifier = [{
return Verify(*this);
}];

let extraClassDeclaration = [{
// Get the condition function.
FuncOp cond_function() {
@@ -776,7 +770,6 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

let verifier = [{ return Verify(*this); }];

}

def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
@@ -825,7 +818,7 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co
let results = (outs);
}

def TF_VarHandleOp : TF_Op<"VarHandleOp"> {
def TF_VarHandleOp : TF_Op<"VarHandleOp", []> {
let summary = "Creates a handle to a Variable resource from its name.";

let description = [{
@@ -967,6 +960,7 @@ An op which shards the input based on the given sharding attribute.
let arguments = (ins
TF_Tensor:$input,

DefaultValuedAttr<StrAttr, "">:$sharding,
OptionalAttr<StrAttr>:$_XlaSharding
);

@@ -1168,12 +1162,11 @@ as true/false for a branch condition.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

let builders = [
OpBuilderDAG<(ins "Value":$value),
OpBuilder<(ins "Value":$value),
[{
build($_builder, $_state, RankedTensorType::get({}, $_builder.getI1Type()),
value);
}]>];

}

def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
@@ -1333,17 +1326,15 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
}];

let arguments = (ins
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$x,
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$y
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>:$x,
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>:$y
);

let results = (outs
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint32]>:$z
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>:$z
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;


}

def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
@@ -1409,8 +1400,6 @@ If `x` and `y` are reals, this will return the floating-point division.
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;


}

def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
@@ -1436,7 +1425,6 @@ Both input and output have a range `(-inf, inf)`.
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

}

def TF_StatefulStandardNormalV2Op : TF_Op<"StatefulStandardNormalV2", []> {
@@ -1665,7 +1653,7 @@ event: A string containing a binary-encoded tf.Event proto.
let results = (outs);
}

def TF_SummaryWriterOp : TF_Op<"SummaryWriter"> {
def TF_SummaryWriterOp : TF_Op<"SummaryWriter", []> {
let summary = "Returns a handle to be used to access a summary writer.";

let description = [{


+ 3
- 3
tools/mlir/tf_types.cc View File

@@ -230,7 +230,7 @@ ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes()

// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
// similar structure that could be extracted into helper method.
bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs)
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs)
{
if (lhs.size() != rhs.size()) return false;
for (auto types : llvm::zip(lhs, rhs))
@@ -395,7 +395,7 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs,
return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
}

bool AreCastCompatible(ArrayRef<Type> types)
bool AreCastCompatible(TypeRange types)
{
Type common = types.front();
for (auto type : types.drop_front())
@@ -407,7 +407,7 @@ bool AreCastCompatible(ArrayRef<Type> types)
return true;
}

bool ArraysAreCastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs)
bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs)
{
if (lhs.size() != rhs.size()) return false;
for (auto pair : llvm::zip(lhs, rhs))


+ 14
- 8
tools/mlir/tf_types.h View File

@@ -123,14 +123,14 @@ public:
static TensorFlowType getChecked(Type type, MLIRContext* context,
Location loc)
{
if (failed(verifyConstructionInvariants(loc, type)))
if (failed(verify(loc, type)))
{
return TensorFlowRefType();
}
return get(type);
}

static LogicalResult verifyConstructionInvariants(Location loc, Type type)
static LogicalResult verify(Location loc, Type type)
{
// type should be a valid TensorFlow type.
if (!IsValidTFTensorType(type))
@@ -237,21 +237,27 @@ public:
{
return Base::getChecked(loc, subtypes);
}
static Derived getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext* context,
ArrayRef<TensorType> subtypes)
{
return Base::getChecked(emitError, context, subtypes);
}

static Derived get(MLIRContext* context)
{
return get({}, context);
}

static LogicalResult verifyConstructionInvariants(
Location loc, ArrayRef<TensorType> subtypes)
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<TensorType> subtypes)
{
// Each of the subtypes should be a valid TensorFlow type.
for (TensorType subtype : subtypes)
{
if (!IsValidTFTensorType(subtype))
{
return emitError(loc) << "invalid " << Derived::getTypeName()
return emitError() << "invalid " << Derived::getTypeName()
<< " subtype: " << subtype;
}
}
@@ -328,7 +334,7 @@ mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
bool may_ignore_ref_type_a);

// Returns whether two arrays of Type are broadcast compatible.
bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs);

// Returns whether the two elemental types are compatible. Shapes are compatible
// if:
@@ -346,11 +352,11 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs,
// another. In other words, a single run-time value is legal for both the types.
// For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
// compatible.
bool AreCastCompatible(ArrayRef<Type> types);
bool AreCastCompatible(TypeRange types);

// Returns true if corresponding elements of lhs and rhs AreCastCompatible and
// lhs and rhs are the same length.
bool ArraysAreCastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs);

// If `ty` is a tensor type and its element type has subtypes, then returns a
// new type of same shape but dropped subtypes for the element type.


Loading…
Cancel
Save