Browse Source

mlir fuse instancenorm without affine

tags/20200916
nihuini 5 years ago
parent
commit
d279e1f514
4 changed files with 100 additions and 3 deletions
  1. +13
    -0
      tools/mlir/mlir2ncnn.cpp
  2. +14
    -0
      tools/mlir/ncnn_ops.td
  3. +7
    -1
      tools/mlir/ncnn_rewriter.cpp
  4. +66
    -2
      tools/mlir/ncnn_rewriter.td

+ 13
- 0
tools/mlir/mlir2ncnn.cpp View File

@@ -420,6 +420,10 @@ int main(int argc, char** argv)
{
fprintf(pp, "%-16s", "InstanceNorm");
}
else if (op == "ncnn.InstanceNormAffine")
{
fprintf(pp, "%-16s", "InstanceNorm");
}
else if (op == "tf.AddN")
{
fprintf(pp, "%-16s", "Eltwise");
@@ -617,6 +621,14 @@ int main(int argc, char** argv)
{
float eps = get_operation_attr_f(operation, "epsilon");

fprintf(pp, " 0=0"); // channels
fprintf(pp, " 1=%e", eps);
fprintf(pp, " 2=0"); // affine
}
else if (op == "ncnn.InstanceNormAffine")
{
float eps = get_operation_attr_f(operation, "epsilon");

std::string gamma_name = get_mlir_value_uniq_id(operation.getOperand(1));
std::string beta_name = get_mlir_value_uniq_id(operation.getOperand(2));
const mlir::Attribute& G = weights[gamma_name];
@@ -629,6 +641,7 @@ int main(int argc, char** argv)

fprintf(pp, " 0=%d", channels);
fprintf(pp, " 1=%e", eps);
fprintf(pp, " 2=1"); // affine

fwrite(gv.data(), sizeof(float), gv.size(), bp);
fwrite(bv.data(), sizeof(float), bv.size(), bp);


+ 14
- 0
tools/mlir/ncnn_ops.td View File

@@ -36,6 +36,20 @@ class NCNN_Op<string mnemonic, list<OpTrait> traits = []> :

def NCNN_InstanceNormOp : NCNN_Op<"InstanceNorm", [NoSideEffect]> {

let arguments = (ins
F32Tensor:$x,
F32Attr:$epsilon
);

let results = (outs
F32Tensor:$y
);

let hasCanonicalizer = 1;
}

def NCNN_InstanceNormAffineOp : NCNN_Op<"InstanceNormAffine", [NoSideEffect]> {

let arguments = (ins
F32Tensor:$x,
F32Tensor:$gamma,


+ 7
- 1
tools/mlir/ncnn_rewriter.cpp View File

@@ -30,7 +30,13 @@ namespace ncnn {

void InstanceNormOp::getCanonicalizationPatterns(OwningRewritePatternList& results, MLIRContext* context)
{
results.insert<FuseInstanceNormPattern>(context);
results.insert<FuseInstanceNormPattern0>(context);
results.insert<FuseInstanceNormPattern1>(context);
}

void InstanceNormAffineOp::getCanonicalizationPatterns(OwningRewritePatternList& results, MLIRContext* context)
{
results.insert<FuseInstanceNormAffinePattern>(context);
}

} // namespace ncnn


+ 66
- 2
tools/mlir/ncnn_rewriter.td View File

@@ -22,7 +22,71 @@ def get_attr_f : NativeCodeCall<"$0.getValue<FloatAttr>(0)">;

def EqualOperands : Constraint<CPred<"$0 == $1">>;

def FuseInstanceNormPattern : Pat<
def FuseInstanceNormPattern0 : Pat<
(TF_MulOp
(TF_RsqrtOp
(TF_AddV2Op
(TF_MeanOp
(TF_SquaredDifferenceOp
(TF_MeanOp:$mean
$x,
(TF_ConstOp:$reduce_axis ElementsAttr),
ConstBoolAttrTrue // keep_dims
),
$x_
),
$reduce_axis_,
ConstBoolAttrTrue // keep_dims
),
(TF_ConstOp ElementsAttr:$epsilon)
)
),
(TF_SubOp $x__, $mean_)
),

(NCNN_InstanceNormOp $x, (get_attr_f $epsilon)),

[
(EqualOperands $x, $x_),
(EqualOperands $x, $x__),
(EqualOperands $reduce_axis, $reduce_axis_),
(EqualOperands $mean, $mean_)
]
>;

def FuseInstanceNormPattern1 : Pat<
(TF_MulOp
(TF_RsqrtOp
(TF_AddV2Op
(TF_MeanOp
(TF_SquaredDifferenceOp
$x_,
(TF_MeanOp:$mean
$x,
(TF_ConstOp:$reduce_axis ElementsAttr),
ConstBoolAttrTrue // keep_dims
)
),
$reduce_axis_,
ConstBoolAttrTrue // keep_dims
),
(TF_ConstOp ElementsAttr:$epsilon)
)
),
(TF_SubOp $x__, $mean_)
),

(NCNN_InstanceNormOp $x, (get_attr_f $epsilon)),

[
(EqualOperands $x, $x_),
(EqualOperands $x, $x__),
(EqualOperands $reduce_axis, $reduce_axis_),
(EqualOperands $mean, $mean_)
]
>;

def FuseInstanceNormAffinePattern : Pat<
(TF_ReshapeOp
(TF_AddV2Op
(TF_MulOp
@@ -56,7 +120,7 @@ def FuseInstanceNormPattern : Pat<
(TF_ConstOp ElementsAttr)
),

(NCNN_InstanceNormOp $x, $gamma, $beta, (get_attr_f $epsilon)),
(NCNN_InstanceNormAffineOp $x, $gamma, $beta, (get_attr_f $epsilon)),

[
(EqualOperands $reshaped, $reshaped_),


Loading…
Cancel
Save