diff --git a/tools/mlir/mlir2ncnn.cpp b/tools/mlir/mlir2ncnn.cpp index 51f744edd..20cf277d2 100644 --- a/tools/mlir/mlir2ncnn.cpp +++ b/tools/mlir/mlir2ncnn.cpp @@ -420,6 +420,10 @@ int main(int argc, char** argv) { fprintf(pp, "%-16s", "InstanceNorm"); } + else if (op == "ncnn.InstanceNormAffine") + { + fprintf(pp, "%-16s", "InstanceNorm"); + } else if (op == "tf.AddN") { fprintf(pp, "%-16s", "Eltwise"); @@ -617,6 +621,14 @@ int main(int argc, char** argv) { float eps = get_operation_attr_f(operation, "epsilon"); + fprintf(pp, " 0=0"); // channels + fprintf(pp, " 1=%e", eps); + fprintf(pp, " 2=0"); // affine + } + else if (op == "ncnn.InstanceNormAffine") + { + float eps = get_operation_attr_f(operation, "epsilon"); + std::string gamma_name = get_mlir_value_uniq_id(operation.getOperand(1)); std::string beta_name = get_mlir_value_uniq_id(operation.getOperand(2)); const mlir::Attribute& G = weights[gamma_name]; @@ -629,6 +641,7 @@ int main(int argc, char** argv) fprintf(pp, " 0=%d", channels); fprintf(pp, " 1=%e", eps); + fprintf(pp, " 2=1"); // affine fwrite(gv.data(), sizeof(float), gv.size(), bp); fwrite(bv.data(), sizeof(float), bv.size(), bp); diff --git a/tools/mlir/ncnn_ops.td b/tools/mlir/ncnn_ops.td index d71f2856c..086554488 100644 --- a/tools/mlir/ncnn_ops.td +++ b/tools/mlir/ncnn_ops.td @@ -36,6 +36,20 @@ class NCNN_Op traits = []> : def NCNN_InstanceNormOp : NCNN_Op<"InstanceNorm", [NoSideEffect]> { + let arguments = (ins + F32Tensor:$x, + F32Attr:$epsilon + ); + + let results = (outs + F32Tensor:$y + ); + + let hasCanonicalizer = 1; +} + +def NCNN_InstanceNormAffineOp : NCNN_Op<"InstanceNormAffine", [NoSideEffect]> { + let arguments = (ins F32Tensor:$x, F32Tensor:$gamma, diff --git a/tools/mlir/ncnn_rewriter.cpp b/tools/mlir/ncnn_rewriter.cpp index b733e29ce..97063b701 100644 --- a/tools/mlir/ncnn_rewriter.cpp +++ b/tools/mlir/ncnn_rewriter.cpp @@ -30,7 +30,13 @@ namespace ncnn { void InstanceNormOp::getCanonicalizationPatterns(OwningRewritePatternList& results, MLIRContext* context) { - results.insert(context); + results.insert(context); + results.insert(context); +} + +void InstanceNormAffineOp::getCanonicalizationPatterns(OwningRewritePatternList& results, MLIRContext* context) +{ + results.insert(context); } } // namespace ncnn diff --git a/tools/mlir/ncnn_rewriter.td b/tools/mlir/ncnn_rewriter.td index c2f123d0e..a9bcdfc41 100644 --- a/tools/mlir/ncnn_rewriter.td +++ b/tools/mlir/ncnn_rewriter.td @@ -22,7 +22,71 @@ def get_attr_f : NativeCodeCall<"$0.getValue(0)">; def EqualOperands : Constraint>; -def FuseInstanceNormPattern : Pat< +def FuseInstanceNormPattern0 : Pat< + (TF_MulOp + (TF_RsqrtOp + (TF_AddV2Op + (TF_MeanOp + (TF_SquaredDifferenceOp + (TF_MeanOp:$mean + $x, + (TF_ConstOp:$reduce_axis ElementsAttr), + ConstBoolAttrTrue // keep_dims + ), + $x_ + ), + $reduce_axis_, + ConstBoolAttrTrue // keep_dims + ), + (TF_ConstOp ElementsAttr:$epsilon) + ) + ), + (TF_SubOp $x__, $mean_) + ), + + (NCNN_InstanceNormOp $x, (get_attr_f $epsilon)), + + [ + (EqualOperands $x, $x_), + (EqualOperands $x, $x__), + (EqualOperands $reduce_axis, $reduce_axis_), + (EqualOperands $mean, $mean_) + ] +>; + +def FuseInstanceNormPattern1 : Pat< + (TF_MulOp + (TF_RsqrtOp + (TF_AddV2Op + (TF_MeanOp + (TF_SquaredDifferenceOp + $x_, + (TF_MeanOp:$mean + $x, + (TF_ConstOp:$reduce_axis ElementsAttr), + ConstBoolAttrTrue // keep_dims + ) + ), + $reduce_axis_, + ConstBoolAttrTrue // keep_dims + ), + (TF_ConstOp ElementsAttr:$epsilon) + ) + ), + (TF_SubOp $x__, $mean_) + ), + + (NCNN_InstanceNormOp $x, (get_attr_f $epsilon)), + + [ + (EqualOperands $x, $x_), + (EqualOperands $x, $x__), + (EqualOperands $reduce_axis, $reduce_axis_), + (EqualOperands $mean, $mean_) + ] +>; + +def FuseInstanceNormAffinePattern : Pat< (TF_ReshapeOp (TF_AddV2Op (TF_MulOp @@ -56,7 +120,7 @@ def FuseInstanceNormPattern : Pat< (TF_ConstOp ElementsAttr) ), - (NCNN_InstanceNormOp $x, $gamma, $beta, (get_attr_f $epsilon)), + (NCNN_InstanceNormAffineOp $x, $gamma, $beta, (get_attr_f $epsilon)), [ (EqualOperands $reshaped, $reshaped_),