|
|
@@ -65,6 +65,17 @@ pnnx.Output output 1 0 out |
|
|
{ |
|
|
{ |
|
|
return "F.scaled_dot_product_attention"; |
|
|
return "F.scaled_dot_product_attention"; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const |
|
|
|
|
|
{ |
|
|
|
|
|
GraphRewriterPass::write(op, captured_params, captured_attrs); |
|
|
|
|
|
|
|
|
|
|
|
if (captured_params.at("scale").type == 0) |
|
|
|
|
|
{ |
|
|
|
|
|
// drop scale=None for compatiblity with old torch |
|
|
|
|
|
op->params.erase("scale"); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10) |
|
|
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10) |
|
|
|