| @@ -420,6 +420,10 @@ int main(int argc, char** argv) | |||
| { | |||
| fprintf(pp, "%-16s", "InstanceNorm"); | |||
| } | |||
| else if (op == "ncnn.InstanceNormAffine") | |||
| { | |||
| fprintf(pp, "%-16s", "InstanceNorm"); | |||
| } | |||
| else if (op == "tf.AddN") | |||
| { | |||
| fprintf(pp, "%-16s", "Eltwise"); | |||
| @@ -617,6 +621,14 @@ int main(int argc, char** argv) | |||
| { | |||
| float eps = get_operation_attr_f(operation, "epsilon"); | |||
| fprintf(pp, " 0=0"); // channels | |||
| fprintf(pp, " 1=%e", eps); | |||
| fprintf(pp, " 2=0"); // affine | |||
| } | |||
| else if (op == "ncnn.InstanceNormAffine") | |||
| { | |||
| float eps = get_operation_attr_f(operation, "epsilon"); | |||
| std::string gamma_name = get_mlir_value_uniq_id(operation.getOperand(1)); | |||
| std::string beta_name = get_mlir_value_uniq_id(operation.getOperand(2)); | |||
| const mlir::Attribute& G = weights[gamma_name]; | |||
| @@ -629,6 +641,7 @@ int main(int argc, char** argv) | |||
| fprintf(pp, " 0=%d", channels); | |||
| fprintf(pp, " 1=%e", eps); | |||
| fprintf(pp, " 2=1"); // affine | |||
| fwrite(gv.data(), sizeof(float), gv.size(), bp); | |||
| fwrite(bv.data(), sizeof(float), bv.size(), bp); | |||
| @@ -36,6 +36,20 @@ class NCNN_Op<string mnemonic, list<OpTrait> traits = []> : | |||
| def NCNN_InstanceNormOp : NCNN_Op<"InstanceNorm", [NoSideEffect]> { | |||
| let arguments = (ins | |||
| F32Tensor:$x, | |||
| F32Attr:$epsilon | |||
| ); | |||
| let results = (outs | |||
| F32Tensor:$y | |||
| ); | |||
| let hasCanonicalizer = 1; | |||
| } | |||
| def NCNN_InstanceNormAffineOp : NCNN_Op<"InstanceNormAffine", [NoSideEffect]> { | |||
| let arguments = (ins | |||
| F32Tensor:$x, | |||
| F32Tensor:$gamma, | |||
| @@ -30,7 +30,13 @@ namespace ncnn { | |||
| void InstanceNormOp::getCanonicalizationPatterns(OwningRewritePatternList& results, MLIRContext* context) | |||
| { | |||
| results.insert<FuseInstanceNormPattern>(context); | |||
| results.insert<FuseInstanceNormPattern0>(context); | |||
| results.insert<FuseInstanceNormPattern1>(context); | |||
| } | |||
| void InstanceNormAffineOp::getCanonicalizationPatterns(OwningRewritePatternList& results, MLIRContext* context) | |||
| { | |||
| results.insert<FuseInstanceNormAffinePattern>(context); | |||
| } | |||
| } // namespace ncnn | |||
| @@ -22,7 +22,71 @@ def get_attr_f : NativeCodeCall<"$0.getValue<FloatAttr>(0)">; | |||
| def EqualOperands : Constraint<CPred<"$0 == $1">>; | |||
| def FuseInstanceNormPattern : Pat< | |||
| def FuseInstanceNormPattern0 : Pat< | |||
| (TF_MulOp | |||
| (TF_RsqrtOp | |||
| (TF_AddV2Op | |||
| (TF_MeanOp | |||
| (TF_SquaredDifferenceOp | |||
| (TF_MeanOp:$mean | |||
| $x, | |||
| (TF_ConstOp:$reduce_axis ElementsAttr), | |||
| ConstBoolAttrTrue // keep_dims | |||
| ), | |||
| $x_ | |||
| ), | |||
| $reduce_axis_, | |||
| ConstBoolAttrTrue // keep_dims | |||
| ), | |||
| (TF_ConstOp ElementsAttr:$epsilon) | |||
| ) | |||
| ), | |||
| (TF_SubOp $x__, $mean_) | |||
| ), | |||
| (NCNN_InstanceNormOp $x, (get_attr_f $epsilon)), | |||
| [ | |||
| (EqualOperands $x, $x_), | |||
| (EqualOperands $x, $x__), | |||
| (EqualOperands $reduce_axis, $reduce_axis_), | |||
| (EqualOperands $mean, $mean_) | |||
| ] | |||
| >; | |||
| def FuseInstanceNormPattern1 : Pat< | |||
| (TF_MulOp | |||
| (TF_RsqrtOp | |||
| (TF_AddV2Op | |||
| (TF_MeanOp | |||
| (TF_SquaredDifferenceOp | |||
| $x_, | |||
| (TF_MeanOp:$mean | |||
| $x, | |||
| (TF_ConstOp:$reduce_axis ElementsAttr), | |||
| ConstBoolAttrTrue // keep_dims | |||
| ) | |||
| ), | |||
| $reduce_axis_, | |||
| ConstBoolAttrTrue // keep_dims | |||
| ), | |||
| (TF_ConstOp ElementsAttr:$epsilon) | |||
| ) | |||
| ), | |||
| (TF_SubOp $x__, $mean_) | |||
| ), | |||
| (NCNN_InstanceNormOp $x, (get_attr_f $epsilon)), | |||
| [ | |||
| (EqualOperands $x, $x_), | |||
| (EqualOperands $x, $x__), | |||
| (EqualOperands $reduce_axis, $reduce_axis_), | |||
| (EqualOperands $mean, $mean_) | |||
| ] | |||
| >; | |||
| def FuseInstanceNormAffinePattern : Pat< | |||
| (TF_ReshapeOp | |||
| (TF_AddV2Op | |||
| (TF_MulOp | |||
| @@ -56,7 +120,7 @@ def FuseInstanceNormPattern : Pat< | |||
| (TF_ConstOp ElementsAttr) | |||
| ), | |||
| (NCNN_InstanceNormOp $x, $gamma, $beta, (get_attr_f $epsilon)), | |||
| (NCNN_InstanceNormAffineOp $x, $gamma, $beta, (get_attr_f $epsilon)), | |||
| [ | |||
| (EqualOperands $reshaped, $reshaped_), | |||