|
|
|
@@ -1162,6 +1162,35 @@ pnnx.Output output 1 0 out |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
class fuse_multiheadattention_pass_12_2 : public fuse_multiheadattention_pass_12 |
|
|
|
{ |
|
|
|
public: |
|
|
|
const char* match_pattern_graph() const |
|
|
|
{ |
|
|
|
return R"PNNXIR(7767517 |
|
|
|
18 17 |
|
|
|
pnnx.Input input_0 0 1 input |
|
|
|
nn.Linear op_0 1 1 input 14 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight |
|
|
|
nn.Linear op_1 1 1 input 15 bias=%kbias in_features=%embed_dim out_features=%embed_dim @bias @weight |
|
|
|
nn.Linear op_2 1 1 input 16 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight |
|
|
|
Tensor.view op_3 1 1 14 17 shape=(%batch,%size,%num_heads,%feat_per_head) |
|
|
|
Tensor.view op_4 1 1 15 18 shape=(%batch,%size,%num_heads,%feat_per_head) |
|
|
|
Tensor.view op_5 1 1 16 19 shape=(%batch,%size,%num_heads,%feat_per_head) |
|
|
|
torch.transpose op_6 1 1 19 20 dim0=1 dim1=2 |
|
|
|
torch.transpose op_7 1 1 18 21 dim0=1 dim1=2 |
|
|
|
torch.transpose op_8 1 1 17 22 dim0=1 dim1=2 |
|
|
|
Tensor.contiguous op_9 1 1 20 201 memory_format=* |
|
|
|
Tensor.contiguous op_10 1 1 21 211 memory_format=* |
|
|
|
Tensor.contiguous op_11 1 1 22 221 memory_format=* |
|
|
|
F.scaled_dot_product_attention op_12 3 1 221 211 201 23 attn_mask=None dropout_p=0.000000e+00 is_causal=False |
|
|
|
torch.transpose op_13 1 1 23 24 dim0=1 dim1=2 |
|
|
|
Tensor.reshape op_14 1 1 24 25 shape=(%batch,%size,%embed_dim) |
|
|
|
nn.Linear out_proj 1 1 25 out bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight |
|
|
|
pnnx.Output output 1 0 out |
|
|
|
)PNNXIR"; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
class fuse_multiheadattention_pass_13 : public fuse_multiheadattention_pass_qkv |
|
|
|
{ |
|
|
|
public: |
|
|
|
@@ -2145,6 +2174,7 @@ void fuse_multiheadattention(Graph& graph) |
|
|
|
fuse_multiheadattention_pass_10 j; |
|
|
|
fuse_multiheadattention_pass_12 k; |
|
|
|
fuse_multiheadattention_pass_12_1 k1; |
|
|
|
fuse_multiheadattention_pass_12_2 k2; |
|
|
|
fuse_multiheadattention_pass_13 l; |
|
|
|
fuse_multiheadattention_pass_14 m; |
|
|
|
fuse_multiheadattention_pass_15 n; |
|
|
|
@@ -2186,6 +2216,7 @@ void fuse_multiheadattention(Graph& graph) |
|
|
|
pnnx_graph_rewrite(graph, &j, opindex); |
|
|
|
pnnx_graph_rewrite(graph, &k, opindex); |
|
|
|
pnnx_graph_rewrite(graph, &k1, opindex); |
|
|
|
pnnx_graph_rewrite(graph, &k2, opindex); |
|
|
|
pnnx_graph_rewrite(graph, &l, opindex); |
|
|
|
pnnx_graph_rewrite(graph, &m, opindex); |
|
|
|
pnnx_graph_rewrite(graph, &n, opindex); |
|
|
|
|