|
|
|
@@ -50,7 +50,41 @@ pnnx.Output output 1 0 out |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
class torch_stft_0 : public torch_stft |
|
|
|
{ |
|
|
|
public: |
|
|
|
const char* match_pattern_graph() const |
|
|
|
{ |
|
|
|
return R"PNNXIR(7767517 |
|
|
|
11 10 |
|
|
|
pnnx.Input input_0 0 1 input |
|
|
|
pnnx.Input input_1 0 1 window |
|
|
|
prim::Constant op_0 0 1 n_fft value=%n_fft |
|
|
|
prim::Constant op_1 0 1 hop_length value=%hop_length |
|
|
|
prim::Constant op_2 0 1 win_length value=%win_length |
|
|
|
prim::Constant op_3 0 1 normalized value=%normalized |
|
|
|
prim::Constant op_4 0 1 onesided value=%onesided |
|
|
|
prim::Constant op_5 0 1 return_complex value=%return_complex |
|
|
|
prim::Constant op_6 0 1 align_to_window value=%align_to_window |
|
|
|
aten::stft op_7 9 1 input n_fft hop_length win_length window normalized onesided return_complex align_to_window out |
|
|
|
pnnx.Output output 1 0 out |
|
|
|
)PNNXIR"; |
|
|
|
} |
|
|
|
|
|
|
|
void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const |
|
|
|
{ |
|
|
|
torch_stft::write(op, captured_params); |
|
|
|
|
|
|
|
// keep align_to_window param only when enabled |
|
|
|
if (captured_params.at("align_to_window").type != 1 || captured_params.at("align_to_window").b == false) |
|
|
|
{ |
|
|
|
op->params.erase("align_to_window"); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft, 80) |
|
|
|
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_0, 80) |
|
|
|
|
|
|
|
class torch_stft_1 : public GraphRewriterPass |
|
|
|
{ |
|
|
|
|