Browse Source

build mlir2ncnn with latest mlir with updated tf2 dialect (#1925)

tags/20200727
nihui GitHub 6 years ago
parent
commit
afb09cccbb
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1445 additions and 402 deletions
  1. +10
    -0
      tools/mlir/fix_td.sh
  2. +26
    -0
      tools/mlir/mlir2ncnn.cpp
  3. +1141
    -382
      tools/mlir/tf_generated_ops.td
  4. +20
    -0
      tools/mlir/tf_op_base.td
  5. +191
    -18
      tools/mlir/tf_ops.td
  6. +54
    -0
      tools/mlir/tf_side_effects.h
  7. +3
    -2
      tools/mlir/tf_types.h

+ 10
- 0
tools/mlir/fix_td.sh View File

@@ -0,0 +1,10 @@
#!/bin/sh

# This dirty script eat td files :P
# https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tensorflow/ir

sed -i '/let hasCanonicalizer = 1;/d' *.td
sed -i '/let hasFolder = 1;/d' *.td
sed -i '/StringRef GetOptimalLayout(const RuntimeDevices& devices);/d' *.td
sed -i '/LogicalResult UpdateDataFormat(StringRef data_format);/d' *.td
sed -i '/LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);/d' *.td

+ 26
- 0
tools/mlir/mlir2ncnn.cpp View File

@@ -48,12 +48,14 @@
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Parser.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/InliningUtils.h>

#include "tf_attributes.h"
#include "tf_side_effects.h"
#include "tf_traits.h"

namespace mlir {
@@ -384,6 +386,30 @@ LogicalResult ConstOp::inferReturnTypes(
"constant vector/tensor");
}

Region& WhileRegionOp::getLoopBody()
{
return body();
}

bool WhileRegionOp::isDefinedOutsideOfLoop(Value value)
{
// If the Op defining the value exists and the defining op is outside the
// scope of this WhileRegion, then we can infer that its defined outside.
// The defining Op is outside the scope of this WhileRegion if this
// WhileRegionOp is not an ancestor of the defining op in the parent chain.
Operation* def_op = value.getDefiningOp();
return def_op && !getOperation()->isAncestor(def_op);
}

LogicalResult WhileRegionOp::moveOutOfLoop(
llvm::ArrayRef<mlir::Operation*> ops)
{
// Move the hoisted value to just before the while.
Operation* while_op = this->getOperation();
for (auto op : ops) op->moveBefore(while_op);
return success();
}

} // namespace TF

} // namespace mlir


+ 1141
- 382
tools/mlir/tf_generated_ops.td
File diff suppressed because it is too large
View File


+ 20
- 0
tools/mlir/tf_op_base.td View File

@@ -80,6 +80,26 @@ class TF_AllTypesMatch<list<string> names> :
TF_AllTypesMatchPred<
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;

//===----------------------------------------------------------------------===//
// TensorFlow op side effects
//===----------------------------------------------------------------------===//

class TF_ResourceBase<string resourceKind> :
Resource<!strconcat("::mlir::TF::ResourceEffects::", resourceKind)> {
}

def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;

def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;

def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;

//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//


+ 191
- 18
tools/mlir/tf_ops.td View File

@@ -31,6 +31,7 @@ include "tf_generated_ops.td"
include "tf_op_base.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/OpBase.td"

class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
@@ -90,7 +91,6 @@ def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,
"OpBuilder &builder, OperationState &result, Type type, Attribute value">,
];

// let hasFolder = 1;

let extraClassDeclaration = [{
static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
@@ -99,6 +99,30 @@ def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect,
}];
}

def TF_CollectivePermuteOp : TF_Op<"CollectivePermute", []> {
let summary = "An Op to permute tensors across replicated TPU instances.";

let description = [{
Each instance supplies its own input.

For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs:
`[D, A, B, C]`.
}];

let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
I32Tensor:$source_target_pairs
);

let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}


def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Permute input tensor from `src_format` to `dst_format`";

@@ -206,8 +230,10 @@ else_branch: A function that takes 'inputs' and returns a list of
}];
}

def TF_YieldOp : TF_Op<"Yield", [Terminator]> {
def TF_YieldOp : TF_Op<"Yield",
[Terminator, ParentOneOf<["IfRegionOp", "WhileRegionOp"]>]> {
let summary = "Yield operation";

let description = [{
The "yield" operation represents a return operation within the conditional
and body of structured control flow (e.g., if and while). The operation
@@ -217,10 +243,6 @@ def TF_YieldOp : TF_Op<"Yield", [Terminator]> {
}];

let arguments = (ins Variadic<AnyType>:$operands);

let verifier = [{
return Verify(*this);
}];
}

def TF_IfRegionOp : TF_Op<"IfRegion",
@@ -294,7 +316,6 @@ retained with length 1.
// TF_FoldOperandsTransposeInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
}];
}

@@ -334,6 +355,41 @@ def TF_LegacyCallOp : TF_Op<"LegacyCall",
}];
}

def TF_ParseExampleOp : TF_Op<"ParseExample",
[NoSideEffect,
AttrSizedResultSegments,
AttrSizedOperandSegments]> {

let summary =
"Transforms a vector of tf.Example protos (as strings) into typed tensors.";

let arguments = (ins
TF_StrTensor:$serialized,
TF_StrTensor:$names,
Variadic<TF_StrTensor>:$sparse_keys,
Variadic<TF_StrTensor>:$dense_keys,
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_defaults,

TF_ShapeAttrArray:$dense_shapes,
I32ElementsAttr:$result_segment_sizes,
I32ElementsAttr:$operand_segment_sizes
);

let results = (outs
Variadic<I64Tensor>:$sparse_indices, // len(sparse_types)
Variadic<TensorOf<[F32, I64, TF_Str]>>:$sparse_values, // len(sparse_types)
Variadic<I64Tensor>:$sparse_shapes, // len(sparse_types)
Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_values // len(Tdense)
);

TF_DerivedOperandSizeAttr Nsparse = TF_DerivedOperandSizeAttr<2>;
TF_DerivedOperandSizeAttr Ndense = TF_DerivedOperandSizeAttr<3>;
TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<4>;
TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;

let verifier = ?;
}

def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
[NoSideEffect,
AttrSizedResultSegments]> {
@@ -438,6 +494,7 @@ Inserts a placeholder for a tensor that will be always fed.

def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]> {
let summary = "Placeholder op";

let description = [{
A placeholder op that passes through input when its output is not fed.
}];
@@ -534,7 +591,7 @@ output = input; While (Cond(output)) { output = Body(output) }

input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A function takes 'input' and returns a tensor. If the tensor is
cond: A function that takes 'input' and returns a tensor. If the tensor is
a scalar of non-boolean, the scalar is converted to a boolean
according to the following rule: if the scalar is a numerical
value, non-zero means True and zero means False; if the scalar is
@@ -570,6 +627,58 @@ body: A function that takes a list of tensors and returns another
}];
}

def TL_WhileRegionOp : TF_Op<"WhileRegion",
[DeclareOpInterfaceMethods<LoopLikeOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = "while operation";
let description = [{
The tf.WhileRegion op represents a while loop using 2 regions and a set of
iteration variables. The iteration variables maintained by this Op have the
same types as the inputs. The Op executes a while loop described by the
following pseudo code:

```
func WhileRegionOp(inputs) {
iteration_vars = inputs;
while (cond(iteration_vars)) {
iteration_vars = body(iteration_vars);
}
return iteration_vars;
}
```

`cond` is the condition region and `body` is the body region. Both these
regions accept the current value of the iteration variables as inputs. The
condition region returns a tensor<i1> which, if false, will exit the loop.
The body region computes new values of the iteration variables. The iteration
variables are initialized to the Op input, and the results of the
tf.WhileRegion op are the final values of the iteration variables.

This implies that the operand and result types for tf.WhileRegion should be
the same. Note that the condition and body regions can implicitly capture
loop invariant values directly. In canonical form, iteration variables that
pass through the loop body unmodified are converted to implicitly captured
references to their values outside the loop.
}];

let arguments = (ins
Variadic<AnyTensor>:$input,

// Used to map StatelessWhile and While op defined in TensorFlow to a common
// op.
DefaultValuedAttr<BoolAttr, "false">:$is_stateless,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations
);
let results = (outs Variadic<AnyTensor>:$output);

TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;

let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

let verifier = [{ return Verify(*this); }];

}

def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
let summary = "List of the given size with empty elements.";

@@ -609,7 +718,8 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co
DefaultValuedAttr<StrArrayAttr, "{}">:$host_compute_core,
DefaultValuedAttr<StrArrayAttr, "{}">:$padding_map,
DefaultValuedAttr<StrAttr, "STEP_MARK_AT_ENTRY">:$step_marker_location,
DefaultValuedAttr<BoolAttr, "false">:$allow_soft_placement
DefaultValuedAttr<BoolAttr, "false">:$allow_soft_placement,
DefaultValuedAttr<BoolAttr, "false">:$use_spmd_for_xla_partitioning
);

let results = (outs);
@@ -780,9 +890,6 @@ def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> {
An op which shards the input based on the given sharding attribute.
}];

let description = [{
}];

let arguments = (ins
TF_Tensor:$input,

@@ -799,9 +906,6 @@ An op which shards the input based on the given sharding attribute.
def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> {
let summary = "Fetches multiple values from infeed as an XLA tuple.";

let description = [{
}];

let arguments = (ins
OptionalAttr<StrAttr>:$_XlaSharding
);
@@ -845,9 +949,6 @@ def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> {
Creates a dataset that batches `batch_size` elements from `input_dataset`.
}];

let description = [{
}];

let arguments = (ins
TF_VariantTensor:$input_dataset,
I64Tensor:$batch_size,
@@ -989,4 +1090,76 @@ operation create / operate on a copy of `x`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the Bessel i0e function of `x` element-wise.";

let description = [{
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.

This function is faster and numerically stabler than `bessel_i0(x)`.
}];

let arguments = (ins
TF_FpTensor:$x
);

let results = (outs
TF_FpTensor:$y
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the Bessel i1e function of `x` element-wise.";

let description = [{
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.

This function is faster and numerically stabler than `bessel_i1(x)`.
}];

let arguments = (ins
TF_FpTensor:$x
);

let results = (outs
TF_FpTensor:$y
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_StringToHashBucketFastOp : TF_Op<"StringToHashBucketFast", [NoSideEffect]> {
let summary = [{
Converts each string in the input Tensor to its hash mod by a number of buckets.
}];

let description = [{
The hash function is deterministic on the content of the string within the
process and will never change. However, it is not suitable for cryptography.
This function may be used when CPU time is scarce and inputs are trusted or
unimportant. There is a risk of adversaries constructing inputs that all hash
to the same bucket. To prevent this problem, use a strong hash function with
`tf.string_to_hash_bucket_strong`.

Examples:

>>> tf.strings.to_hash_bucket_fast(["Hello", "TensorFlow", "2.x"], 3).numpy()
array([0, 2, 2])
}];

let arguments = (ins
TF_StrTensor:$input,

Confined<I64Attr, [IntMinValue<1>]>:$num_buckets
);

let results = (outs
I64Tensor:$output
);
}

#endif // TF_OPS

+ 54
- 0
tools/mlir/tf_side_effects.h View File

@@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the side effect definition file for TensorFlow.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_

#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project

namespace mlir {
namespace TF {
namespace ResourceEffects {

struct Variable : ::mlir::SideEffects::Resource::Base<Variable>
{
StringRef getName() final
{
return "Variable";
}
};

struct Stack : ::mlir::SideEffects::Resource::Base<Stack>
{
StringRef getName() final
{
return "Stack";
}
};

struct TensorArray : ::mlir::SideEffects::Resource::Base<TensorArray>
{
StringRef getName() final
{
return "TensorArray";
}
};

} // namespace ResourceEffects
} // namespace TF
} // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_

+ 3
- 2
tools/mlir/tf_types.h View File

@@ -115,10 +115,11 @@ namespace detail {
// - `static unsigned getTypeKind()` that returns the (fixed) kind of the
// type.
template<typename Derived>
class TensorFlowTypeImpl : public Type::TypeBase<Derived, TensorFlowType>
class TensorFlowTypeImpl
: public Type::TypeBase<Derived, TensorFlowType, TypeStorage>
{
public:
using Base = typename Type::TypeBase<Derived, TensorFlowType>;
using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
using TFBase = TensorFlowTypeImpl<Derived>;
using Base::Base;



Loading…
Cancel
Save