diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index 03841989b..a50e1e5b7 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -19,10 +19,10 @@ concurrency: variables: protobuf_version: 21.12 - libtorch_version: 2.6.0 - libtorchvision_version: 0.21.0 - onnxruntime_version: 1.21.0 - cache_date: 20250402 + libtorch_version: 2.7.0 + libtorchvision_version: 0.22.0 + onnxruntime_version: 1.21.1 + cache_date: 20250423 jobs: ubuntu: @@ -81,6 +81,10 @@ jobs: torchvision-version: 0.21.0 torchaudio-version: '2.6.0+cpu' + - torch-version: 2.7.0 + torchvision-version: 0.22.0 + torchaudio-version: '2.7.0+cpu' + runs-on: pool-name: docker container: diff --git a/tools/pnnx/src/pass_level2/torch_stft.cpp b/tools/pnnx/src/pass_level2/torch_stft.cpp index fa5a59701..b0d9c1cf3 100644 --- a/tools/pnnx/src/pass_level2/torch_stft.cpp +++ b/tools/pnnx/src/pass_level2/torch_stft.cpp @@ -50,7 +50,41 @@ pnnx.Output output 1 0 out } }; +class torch_stft_0 : public torch_stft +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +11 10 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 window +prim::Constant op_0 0 1 n_fft value=%n_fft +prim::Constant op_1 0 1 hop_length value=%hop_length +prim::Constant op_2 0 1 win_length value=%win_length +prim::Constant op_3 0 1 normalized value=%normalized +prim::Constant op_4 0 1 onesided value=%onesided +prim::Constant op_5 0 1 return_complex value=%return_complex +prim::Constant op_6 0 1 align_to_window value=%align_to_window +aten::stft op_7 9 1 input n_fft hop_length win_length window normalized onesided return_complex align_to_window out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(Operator* op, const std::map& captured_params) const + { + torch_stft::write(op, captured_params); + + // keep align_to_window param only when enabled + if (captured_params.at("align_to_window").type != 1 || captured_params.at("align_to_window").b == false) + { + op->params.erase("align_to_window"); + } + } +}; + REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft, 80) +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_0, 80) class torch_stft_1 : public GraphRewriterPass { diff --git a/tools/pnnx/tests/onnx/test_F_relu.py b/tools/pnnx/tests/onnx/test_F_relu.py index e0c979935..3f75e256d 100644 --- a/tools/pnnx/tests/onnx/test_F_relu.py +++ b/tools/pnnx/tests/onnx/test_F_relu.py @@ -59,7 +59,7 @@ def test(): if not torch.allclose(a0, b0, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_convnext_tiny.py b/tools/pnnx/tests/onnx/test_convnext_tiny.py index 10e07badc..2b89fa52a 100644 --- a/tools/pnnx/tests/onnx/test_convnext_tiny.py +++ b/tools/pnnx/tests/onnx/test_convnext_tiny.py @@ -43,7 +43,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_mobilenet_v2.py b/tools/pnnx/tests/onnx/test_mobilenet_v2.py index 348adfea9..27a904120 100644 --- a/tools/pnnx/tests/onnx/test_mobilenet_v2.py +++ b/tools/pnnx/tests/onnx/test_mobilenet_v2.py @@ -39,7 +39,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_mobilenet_v3_small.py b/tools/pnnx/tests/onnx/test_mobilenet_v3_small.py index c2ec5eb3f..301215eba 100644 --- a/tools/pnnx/tests/onnx/test_mobilenet_v3_small.py +++ b/tools/pnnx/tests/onnx/test_mobilenet_v3_small.py @@ -42,7 +42,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_nn_ReLU.py b/tools/pnnx/tests/onnx/test_nn_ReLU.py index 374537031..f5699e4cd 100644 --- a/tools/pnnx/tests/onnx/test_nn_ReLU.py +++ b/tools/pnnx/tests/onnx/test_nn_ReLU.py @@ -61,7 +61,7 @@ def test(): if not torch.allclose(a0, b0, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_resnet18.py b/tools/pnnx/tests/onnx/test_resnet18.py index 2571ffe81..ce9c3b334 100644 --- a/tools/pnnx/tests/onnx/test_resnet18.py +++ b/tools/pnnx/tests/onnx/test_resnet18.py @@ -39,7 +39,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_shufflenet_v2_x1_0.py b/tools/pnnx/tests/onnx/test_shufflenet_v2_x1_0.py index 371773d69..d753ad279 100644 --- a/tools/pnnx/tests/onnx/test_shufflenet_v2_x1_0.py +++ b/tools/pnnx/tests/onnx/test_shufflenet_v2_x1_0.py @@ -39,7 +39,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_squeezenet1_1.py b/tools/pnnx/tests/onnx/test_squeezenet1_1.py index 7a1b3f0e2..2cfc2524f 100644 --- a/tools/pnnx/tests/onnx/test_squeezenet1_1.py +++ b/tools/pnnx/tests/onnx/test_squeezenet1_1.py @@ -39,7 +39,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_swin_t.py b/tools/pnnx/tests/onnx/test_swin_t.py index 471c40cb9..e598f79ee 100644 --- a/tools/pnnx/tests/onnx/test_swin_t.py +++ b/tools/pnnx/tests/onnx/test_swin_t.py @@ -43,7 +43,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx diff --git a/tools/pnnx/tests/onnx/test_vit_b_32.py b/tools/pnnx/tests/onnx/test_vit_b_32.py index 61f3dd7f1..f79e116f1 100644 --- a/tools/pnnx/tests/onnx/test_vit_b_32.py +++ b/tools/pnnx/tests/onnx/test_vit_b_32.py @@ -46,7 +46,7 @@ def test(): if not torch.allclose(a, b, 1e-4, 1e-4): return False - if version.parse(torch.__version__) < version.parse('2.7'): + if version.parse(torch.__version__) < version.parse('2.8'): return True # export dynamo onnx