diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index c87794a0f..56f7a3fc2 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -360,6 +360,7 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_multiheadattention.cpp pass_level5/fuse_scaled_dot_product_attention.cpp pass_level5/fuse_select_to_unbind.cpp + pass_level5/fuse_silu.cpp pass_level5/fuse_slice_copy.cpp pass_level5/fuse_slice_indices.cpp pass_level5/fuse_slice_to_tensor_split.cpp @@ -372,7 +373,6 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_static_instancenorm.cpp pass_level5/fuse_static_layernorm.cpp pass_level5/fuse_static_linear.cpp - pass_level5/fuse_swish.cpp pass_level5/normalize_einsum_equation.cpp pass_level5/unroll_rnn_op.cpp ) diff --git a/tools/pnnx/src/pass_level2/F_batch_norm.cpp b/tools/pnnx/src/pass_level2/F_batch_norm.cpp index 83efcdc4e..e942878be 100644 --- a/tools/pnnx/src/pass_level2/F_batch_norm.cpp +++ b/tools/pnnx/src/pass_level2/F_batch_norm.cpp @@ -100,11 +100,19 @@ pnnx.Output output 1 0 out { return "F.batch_norm"; } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + GraphRewriterPass::write(op, captured_params, captured_attrs); + + std::swap(op->inputs[1], op->inputs[3]); + std::swap(op->inputs[2], op->inputs[4]); + } }; REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_onnx, 10) -class F_batch_norm_onnx_1 : public GraphRewriterPass +class F_batch_norm_onnx_1 : public F_batch_norm_onnx { public: const char* match_pattern_graph() const @@ -120,11 +128,6 @@ BatchNormalization op_0 5 1 input weight bias running_mean running_v pnnx.Output output 1 0 out )PNNXIR"; } - - const char* type_str() const - { - return "F.batch_norm"; - } }; REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_onnx_1, 10) diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 4a1a2760a..4d483267d 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -46,6 +46,7 @@ #include "pass_level5/fuse_pad_conv2d.h" #include "pass_level5/fuse_scaled_dot_product_attention.h" #include "pass_level5/fuse_select_to_unbind.h" +#include "pass_level5/fuse_silu.h" #include "pass_level5/fuse_slice_copy.h" #include "pass_level5/fuse_slice_indices.h" #include "pass_level5/fuse_slice_to_tensor_split.h" @@ -58,7 +59,6 @@ #include "pass_level5/fuse_static_instancenorm.h" #include "pass_level5/fuse_static_layernorm.h" #include "pass_level5/fuse_static_linear.h" -#include "pass_level5/fuse_swish.h" #include "pass_level5/normalize_einsum_equation.h" #include "pass_level4/dead_code_elimination.h" #include "pass_level4/canonicalize.h" @@ -144,7 +144,7 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons fuse_multiheadattention(g); fuse_scaled_dot_product_attention(g); - fuse_swish(g); + fuse_silu(g); fuse_index_expression(g); diff --git a/tools/pnnx/src/pass_level5/fuse_swish.cpp b/tools/pnnx/src/pass_level5/fuse_silu.cpp similarity index 86% rename from tools/pnnx/src/pass_level5/fuse_swish.cpp rename to tools/pnnx/src/pass_level5/fuse_silu.cpp index dc038de3d..c7d7e10eb 100644 --- a/tools/pnnx/src/pass_level5/fuse_swish.cpp +++ b/tools/pnnx/src/pass_level5/fuse_silu.cpp @@ -12,13 +12,13 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#include "fuse_swish.h" +#include "fuse_silu.h" #include "pass_level2.h" namespace pnnx { -class fuse_swish_pass : public GraphRewriterPass +class fuse_silu_pass : public GraphRewriterPass { public: const char* match_pattern_graph() const @@ -34,16 +34,16 @@ pnnx.Output output 1 0 out const char* type_str() const { - return "F.swish"; + return "F.silu"; } const char* name_str() const { - return "swish"; + return "silu"; } }; -class fuse_swish_pass_1 : public fuse_swish_pass +class fuse_silu_pass_1 : public fuse_silu_pass { public: const char* match_pattern_graph() const @@ -58,10 +58,10 @@ pnnx.Output output 1 0 out } }; -void fuse_swish(Graph& graph) +void fuse_silu(Graph& graph) { - fuse_swish_pass a; - fuse_swish_pass_1 b; + fuse_silu_pass a; + fuse_silu_pass_1 b; int opindex = 0; pnnx_graph_rewrite(graph, &a, opindex); diff --git a/tools/pnnx/src/pass_level5/fuse_swish.h b/tools/pnnx/src/pass_level5/fuse_silu.h similarity index 96% rename from tools/pnnx/src/pass_level5/fuse_swish.h rename to tools/pnnx/src/pass_level5/fuse_silu.h index 7c31d6f98..dde78e60e 100644 --- a/tools/pnnx/src/pass_level5/fuse_swish.h +++ b/tools/pnnx/src/pass_level5/fuse_silu.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_swish(Graph& graph); +void fuse_silu(Graph& graph); } // namespace pnnx