diff --git a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp index c178788f2..5ac689704 100644 --- a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp +++ b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp @@ -1162,6 +1162,35 @@ pnnx.Output output 1 0 out } }; +class fuse_multiheadattention_pass_12_2 : public fuse_multiheadattention_pass_12 /* Pattern-only variant of fuse_multiheadattention_pass_12: the only member defined here is match_pattern_graph(), so everything else the pass needs comes from the base class. NOTE(review): the base pattern is not visible in this hunk; presumably this variant exists for graphs that insert Tensor.contiguous between the transposes and scaled_dot_product_attention - confirm against pass_12. */ +{ +public: + const char* match_pattern_graph() const /* Returns the PNNX IR text to match. The header 18 17 agrees with the 18 operator lines and 17 distinct operand names that follow. Matched graph: three nn.Linear ops (op_0, op_1, op_2) read the same input; each result is viewed as (batch,size,num_heads,feat_per_head), transposed on dims 1 and 2, then passed through Tensor.contiguous; F.scaled_dot_product_attention consumes the op_0, op_1 and op_2 branches in that order (operands 221 211 201) with attn_mask=None, dropout_p=0 and is_causal=False; its output is transposed on dims 1 and 2 again, reshaped to (batch,size,embed_dim) and fed to the out_proj nn.Linear. */ + { + return R"PNNXIR(7767517 +18 17 +pnnx.Input input_0 0 1 input +nn.Linear op_0 1 1 input 14 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_1 1 1 input 15 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight +nn.Linear op_2 1 1 input 16 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight +Tensor.view op_3 1 1 14 17 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.view op_4 1 1 15 18 shape=(%batch,%size,%num_heads,%feat_per_head) +Tensor.view op_5 1 1 16 19 shape=(%batch,%size,%num_heads,%feat_per_head) +torch.transpose op_6 1 1 19 20 dim0=1 dim1=2 +torch.transpose op_7 1 1 18 21 dim0=1 dim1=2 +torch.transpose op_8 1 1 17 22 dim0=1 dim1=2 +Tensor.contiguous op_9 1 1 20 201 memory_format=* +Tensor.contiguous op_10 1 1 21 211 memory_format=* +Tensor.contiguous op_11 1 1 22 221 memory_format=* +F.scaled_dot_product_attention op_12 3 1 221 211 201 23 attn_mask=None dropout_p=0.000000e+00 is_causal=False +torch.transpose op_13 1 1 23 24 dim0=1 dim1=2 +Tensor.reshape op_14 1 1 24 25 shape=(%batch,%size,%embed_dim) +nn.Linear out_proj 1 1 25 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + class fuse_multiheadattention_pass_13 : public fuse_multiheadattention_pass_qkv { public: @@ -2145,6 +2174,7 @@ void fuse_multiheadattention(Graph& graph) fuse_multiheadattention_pass_10 j; fuse_multiheadattention_pass_12 k; fuse_multiheadattention_pass_12_1 
k1; + fuse_multiheadattention_pass_12_2 k2; /* matcher instance for the pass_12_2 pattern */ fuse_multiheadattention_pass_13 l; fuse_multiheadattention_pass_14 m; fuse_multiheadattention_pass_15 n; @@ -2186,6 +2216,7 @@ void fuse_multiheadattention(Graph& graph) pnnx_graph_rewrite(graph, &j, opindex); pnnx_graph_rewrite(graph, &k, opindex); pnnx_graph_rewrite(graph, &k1, opindex); + pnnx_graph_rewrite(graph, &k2, opindex); /* rewrite with k2 right after k1 and before l */ pnnx_graph_rewrite(graph, &l, opindex); pnnx_graph_rewrite(graph, &m, opindex); pnnx_graph_rewrite(graph, &n, opindex);