Browse Source

pnnx fuse wav2vec style mha (#6004)

tags/20250428
nihui GitHub 1 year ago
parent
commit
3d16b657bc
No known key found for this signature in database. GPG Key ID: B5690EEEBB952194
1 changed file with 31 additions and 0 deletions
  1. +31
    -0
      tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp

+ 31
- 0
tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp View File

@@ -1162,6 +1162,35 @@ pnnx.Output output 1 0 out
}
};

// Variant of fuse_multiheadattention_pass_12 that supplies its own pattern
// graph; nothing else is redefined here, so all remaining behaviour comes
// from the base class.
//
// The pattern (18 operator lines, 17 operands) describes, in order:
//   - three nn.Linear projections of the same input, each
//     %embed_dim -> %embed_dim, capturing %qbias / %kbias / %vbias
//   - Tensor.view of each projection to
//     (%batch,%size,%num_heads,%feat_per_head)
//   - torch.transpose(dim0=1, dim1=2) on each view, followed by
//     Tensor.contiguous with any memory_format
//   - F.scaled_dot_product_attention fed with the q, k, v branches in that
//     order, attn_mask=None, dropout_p=0, is_causal=False
//   - torch.transpose(dim0=1, dim1=2) and Tensor.reshape back to
//     (%batch,%size,%embed_dim)
//   - a final nn.Linear named out_proj, capturing %outbias
class fuse_multiheadattention_pass_12_2 : public fuse_multiheadattention_pass_12
{
public:
    // Returns the PNNXIR text of the subgraph this pass should match.
    // The returned pointer refers to static storage.
    const char* match_pattern_graph() const
    {
        static const char pattern_ir[] = R"PNNXIR(7767517
18 17
pnnx.Input input_0 0 1 input
nn.Linear op_0 1 1 input 14 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input 15 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input 16 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight
Tensor.view op_3 1 1 14 17 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.view op_4 1 1 15 18 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.view op_5 1 1 16 19 shape=(%batch,%size,%num_heads,%feat_per_head)
torch.transpose op_6 1 1 19 20 dim0=1 dim1=2
torch.transpose op_7 1 1 18 21 dim0=1 dim1=2
torch.transpose op_8 1 1 17 22 dim0=1 dim1=2
Tensor.contiguous op_9 1 1 20 201 memory_format=*
Tensor.contiguous op_10 1 1 21 211 memory_format=*
Tensor.contiguous op_11 1 1 22 221 memory_format=*
F.scaled_dot_product_attention op_12 3 1 221 211 201 23 attn_mask=None dropout_p=0.000000e+00 is_causal=False
torch.transpose op_13 1 1 23 24 dim0=1 dim1=2
Tensor.reshape op_14 1 1 24 25 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 25 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";

        return pattern_ir;
    }
};

class fuse_multiheadattention_pass_13 : public fuse_multiheadattention_pass_qkv
{
public:
@@ -2145,6 +2174,7 @@ void fuse_multiheadattention(Graph& graph)
fuse_multiheadattention_pass_10 j;
fuse_multiheadattention_pass_12 k;
fuse_multiheadattention_pass_12_1 k1;
fuse_multiheadattention_pass_12_2 k2;
fuse_multiheadattention_pass_13 l;
fuse_multiheadattention_pass_14 m;
fuse_multiheadattention_pass_15 n;
@@ -2186,6 +2216,7 @@ void fuse_multiheadattention(Graph& graph)
pnnx_graph_rewrite(graph, &j, opindex);
pnnx_graph_rewrite(graph, &k, opindex);
pnnx_graph_rewrite(graph, &k1, opindex);
pnnx_graph_rewrite(graph, &k2, opindex);
pnnx_graph_rewrite(graph, &l, opindex);
pnnx_graph_rewrite(graph, &m, opindex);
pnnx_graph_rewrite(graph, &n, opindex);


Loading…
Cancel
Save