|
|
|
@@ -42,4 +42,31 @@ pnnx.Output output 1 0 out |
|
|
|
|
|
|
|
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention, 10) |
|
|
|
|
|
|
|
class F_scaled_dot_product_attention_1 : public GraphRewriterPass |
|
|
|
{ |
|
|
|
public: |
|
|
|
const char* match_pattern_graph() const |
|
|
|
{ |
|
|
|
return R"PNNXIR(7767517 |
|
|
|
9 8 |
|
|
|
pnnx.Input input_0 0 1 query |
|
|
|
pnnx.Input input_1 0 1 key |
|
|
|
pnnx.Input input_2 0 1 value |
|
|
|
pnnx.Input input_3 0 1 attn_mask |
|
|
|
prim::Constant op_0 0 1 dropout_p value=%dropout_p |
|
|
|
prim::Constant op_1 0 1 is_causal value=%is_causal |
|
|
|
prim::Constant op_2 0 1 scale value=%scale |
|
|
|
aten::scaled_dot_product_attention op_3 7 1 query key value attn_mask dropout_p is_causal scale out |
|
|
|
pnnx.Output output 1 0 out |
|
|
|
)PNNXIR"; |
|
|
|
} |
|
|
|
|
|
|
|
const char* type_str() const |
|
|
|
{ |
|
|
|
return "F.scaled_dot_product_attention"; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10) |
|
|
|
|
|
|
|
} // namespace pnnx |