* only supports Hann, Hamming and all-ones windows * inverse spectrogram does not support the length parameter * spectrogram always returns torch.view_as_real(out), as ncnn does not support complex-typed Mat yet * inverse spectrogram always accepts torch.view_as_complex(in), as ncnn does not support complex-typed Mat yet tags/20241226
| @@ -31,39 +31,51 @@ jobs: | |||
| include: | |||
| - torch-version: 1.8.1 | |||
| torchvision-version: 0.9.1 | |||
| torchaudio-version: 0.8.1 | |||
| - torch-version: 1.9.1 | |||
| torchvision-version: 0.10.1 | |||
| torchaudio-version: 0.9.1 | |||
| - torch-version: 1.10.0 | |||
| torchvision-version: 0.11.1 | |||
| torchaudio-version: '0.10.0+cpu' | |||
| - torch-version: 1.11.0 | |||
| torchvision-version: 0.12.0 | |||
| torchaudio-version: '0.11.0+cpu' | |||
| - torch-version: 1.12.0 | |||
| torchvision-version: 0.13.0 | |||
| torchaudio-version: '0.12.0+cpu' | |||
| - torch-version: 1.13.0 | |||
| torchvision-version: 0.14.0 | |||
| torchaudio-version: '0.13.0+cpu' | |||
| - torch-version: 2.0.0 | |||
| torchvision-version: 0.15.1 | |||
| torchaudio-version: '2.0.0+cpu' | |||
| - torch-version: 2.1.0 | |||
| torchvision-version: 0.16.0 | |||
| torchaudio-version: '2.1.0+cpu' | |||
| - torch-version: 2.2.1 | |||
| torchvision-version: 0.17.1 | |||
| torchaudio-version: '2.2.1+cpu' | |||
| - torch-version: 2.3.0 | |||
| torchvision-version: 0.18.0 | |||
| torchaudio-version: '2.3.0+cpu' | |||
| - torch-version: 2.4.0 | |||
| torchvision-version: 0.19.0 | |||
| torchaudio-version: '2.4.0+cpu' | |||
| - torch-version: 2.5.0 | |||
| torchvision-version: 0.20.0 | |||
| torchaudio-version: '2.5.0+cpu' | |||
| runs-on: | |||
| pool-name: docker | |||
| @@ -169,7 +181,7 @@ jobs: | |||
| - name: setup-pytorch | |||
| run: | | |||
| export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}} | |||
| pip3 install --user torch==${{matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu --index-url https://download.pytorch.org/whl/cpu | |||
| pip3 install --user torch==${{matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu torchaudio==${{matrix.torchaudio-version}} --index-url https://download.pytorch.org/whl/cpu | |||
| pip3 install --user onnx | |||
| pip3 install --user onnxscript | |||
| @@ -46,6 +46,7 @@ | |||
| * [Input](#input) | |||
| * [InstanceNorm](#instancenorm) | |||
| * [Interp](#interp) | |||
| * [InverseSpectrogram](#inversespectrogram) | |||
| * [LayerNorm](#layernorm) | |||
| * [Log](#log) | |||
| * [LRN](#lrn) | |||
| @@ -81,6 +82,7 @@ | |||
| * [Slice](#slice) | |||
| * [Softmax](#softmax) | |||
| * [Softplus](#softplus) | |||
| * [Spectrogram](#spectrogram) | |||
| * [Split](#split) | |||
| * [Swish](#swish) | |||
| * [TanH](#tanh) | |||
| @@ -1141,6 +1143,30 @@ Resize type: | |||
| - 2 = Bilinear | |||
| - 3 = Bicubic | |||
| # InverseSpectrogram | |||
| ``` | |||
| x1 = x as complex | |||
| x1 = x1 * sqrt(norm) if normalized | |||
| y = istft(x1) | |||
| y1 = unpad(y) if center | |||
| if returns == 0 return y1 as complex | |||
| if returns == 1 return y1 real | |||
| if returns == 2 return y1 imag | |||
| ``` | |||
| * one_blob_only | |||
| | param id | name | type | default | description | | |||
| | --------- | ------------- | ----- | --------- | ----------------- | | |||
| | 0 | n_fft | int | 0 | | | |||
| | 1 | returns | int | 0 | 0=complex 1=real 2=imag | |||
| | 2 | hoplen | int | n_fft / 4 | | | |||
| | 3 | winlen | int | n_fft | | | |||
| | 4 | window_type | int | 0 | 0=ones 1=hann 2=hamming | | |||
| | 5 | center | int | 1 | | | |||
| | 7 | normalized | int | 0 | 0=no 1=n_fft 2=window-l2-energy | | |||
| # LayerNorm | |||
| ``` | |||
| split x along outmost axis into part x0, x1 ... | |||
| @@ -1829,6 +1855,31 @@ y = log(exp(x) + 1) | |||
| * one_blob_only | |||
| * support_inplace | |||
| # Spectrogram | |||
| ``` | |||
| x1 = pad(x) if center | |||
| y = stft(x1) | |||
| y = y / sqrt(norm) if normalized | |||
| if power == 0 return y as real | |||
| if power == 1 return magnitude | |||
| if power == 2 return square of magnitude | |||
| ``` | |||
| * one_blob_only | |||
| | param id | name | type | default | description | | |||
| | --------- | ------------- | ----- | --------- | ----------------- | | |||
| | 0 | n_fft | int | 0 | | | |||
| | 1 | power | int | 0 | | | |||
| | 2 | hoplen | int | n_fft / 4 | | | |||
| | 3 | winlen | int | n_fft | | | |||
| | 4 | window_type | int | 0 | 0=ones 1=hann 2=hamming | | |||
| | 5 | center | int | 1 | | | |||
| | 6 | pad_type | int | 2 | 0=CONSTANT 1=REPLICATE 2=REFLECT | | |||
| | 7 | normalized | int | 0 | 0=no 1=n_fft 2=window-l2-energy | | |||
| | 8 | onesided | int | 1 | | | |||
| # Split | |||
| ``` | |||
| y0, y1 ... = x | |||
| @@ -167,6 +167,8 @@ ncnn_add_layer(Diag) | |||
| ncnn_add_layer(CELU) | |||
| ncnn_add_layer(Shrink) | |||
| ncnn_add_layer(RMSNorm) | |||
| ncnn_add_layer(Spectrogram) | |||
| ncnn_add_layer(InverseSpectrogram) | |||
| if(NCNN_VULKAN) | |||
| ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp) | |||
| @@ -0,0 +1,238 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "inversespectrogram.h" | |||
| namespace ncnn { | |||
| InverseSpectrogram::InverseSpectrogram() | |||
| { | |||
| one_blob_only = true; | |||
| support_inplace = false; | |||
| } | |||
| int InverseSpectrogram::load_param(const ParamDict& pd) | |||
| { | |||
| n_fft = pd.get(0, 0); | |||
| returns = pd.get(1, 0); | |||
| hoplen = pd.get(2, n_fft / 4); | |||
| winlen = pd.get(3, n_fft); | |||
| window_type = pd.get(4, 0); | |||
| center = pd.get(5, 1); | |||
| normalized = pd.get(7, 0); | |||
| // assert winlen <= n_fft | |||
| // generate window | |||
| window_data.create(normalized == 2 ? n_fft + 1 : n_fft); | |||
| { | |||
| float* p = window_data; | |||
| for (int i = 0; i < (n_fft - winlen) / 2; i++) | |||
| { | |||
| *p++ = 0.f; | |||
| } | |||
| if (window_type == 0) | |||
| { | |||
| // all ones | |||
| for (int i = 0; i < winlen; i++) | |||
| { | |||
| *p++ = 1.f; | |||
| } | |||
| } | |||
| if (window_type == 1) | |||
| { | |||
| // hann window | |||
| for (int i = 0; i < winlen; i++) | |||
| { | |||
| *p++ = 0.5f * (1 - cosf(2 * 3.14159265358979323846 * i / winlen)); | |||
| } | |||
| } | |||
| if (window_type == 2) | |||
| { | |||
| // hamming window | |||
| for (int i = 0; i < winlen; i++) | |||
| { | |||
| *p++ = 0.54f - 0.46f * cosf(2 * 3.14159265358979323846 * i / winlen); | |||
| } | |||
| } | |||
| for (int i = 0; i < n_fft - winlen - (n_fft - winlen) / 2; i++) | |||
| { | |||
| *p++ = 0.f; | |||
| } | |||
| // pre-calculated window norm factor | |||
| if (normalized == 2) | |||
| { | |||
| float sqsum = 0.f; | |||
| for (int i = 0; i < n_fft; i++) | |||
| { | |||
| sqsum += window_data[i] * window_data[i]; | |||
| } | |||
| window_data[n_fft] = sqrt(sqsum); | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| int InverseSpectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const | |||
| { | |||
| // https://github.com/librosa/librosa/blob/main/librosa/core/spectrum.py#L630 | |||
| // TODO custom window | |||
| // TODO output length | |||
| const int frames = bottom_blob.h; | |||
| const int freqs = bottom_blob.c; | |||
| // assert freqs == n_fft or freqs == n_fft / 2 + 1 | |||
| const int onesided = freqs == n_fft / 2 + 1 ? 1 : 0; | |||
| const int outsize = center ? (frames - 1) * hoplen + (n_fft - n_fft / 2 * 2) : (frames - 1) * hoplen + n_fft; | |||
| const size_t elemsize = bottom_blob.elemsize; | |||
| if (returns == 0) | |||
| { | |||
| top_blob.create(2, outsize, elemsize, opt.blob_allocator); | |||
| } | |||
| else | |||
| { | |||
| top_blob.create(outsize, elemsize, opt.blob_allocator); | |||
| } | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| Mat window_sumsquare(outsize + n_fft, elemsize, opt.workspace_allocator); | |||
| if (window_sumsquare.empty()) | |||
| return -100; | |||
| top_blob.fill(0.f); | |||
| window_sumsquare.fill(0.f); | |||
| for (int j = 0; j < frames; j++) | |||
| { | |||
| // collect complex | |||
| Mat sp(2, n_fft); | |||
| if (onesided == 1) | |||
| { | |||
| for (int k = 0; k < n_fft / 2 + 1; k++) | |||
| { | |||
| sp.row(k)[0] = bottom_blob.channel(k).row(j)[0]; | |||
| sp.row(k)[1] = bottom_blob.channel(k).row(j)[1]; | |||
| } | |||
| for (int k = n_fft / 2 + 1; k < n_fft; k++) | |||
| { | |||
| sp.row(k)[0] = bottom_blob.channel(n_fft - k).row(j)[0]; | |||
| sp.row(k)[1] = -bottom_blob.channel(n_fft - k).row(j)[1]; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| for (int k = 0; k < n_fft; k++) | |||
| { | |||
| sp.row(k)[0] = bottom_blob.channel(k).row(j)[0]; | |||
| sp.row(k)[1] = bottom_blob.channel(k).row(j)[1]; | |||
| } | |||
| } | |||
| if (normalized == 1) | |||
| { | |||
| float norm = sqrt(n_fft); | |||
| for (int i = 0; i < 2 * n_fft; i++) | |||
| { | |||
| sp[i] *= norm; | |||
| } | |||
| } | |||
| if (normalized == 2) | |||
| { | |||
| float norm = window_data[n_fft]; | |||
| for (int i = 0; i < 2 * n_fft; i++) | |||
| { | |||
| sp[i] *= norm; | |||
| } | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int i = 0; i < n_fft; i++) | |||
| { | |||
| // inverse dft | |||
| float re = 0.f; | |||
| float im = 0.f; | |||
| for (int k = 0; k < n_fft; k++) | |||
| { | |||
| double angle = 2 * 3.14159265358979323846 * i * k / n_fft; | |||
| re += sp.row(k)[0] * cosf(angle) - sp.row(k)[1] * sinf(angle); | |||
| im += sp.row(k)[0] * sinf(angle) + sp.row(k)[1] * cosf(angle); | |||
| } | |||
| re /= n_fft; | |||
| im /= n_fft; | |||
| // apply window | |||
| re *= window_data[i]; | |||
| im *= window_data[i]; | |||
| int output_index = j * hoplen + i; | |||
| if (center == 1) | |||
| { | |||
| output_index -= n_fft / 2; | |||
| } | |||
| if (output_index >= 0 && output_index < outsize) | |||
| { | |||
| // square window | |||
| window_sumsquare[output_index] += window_data[i] * window_data[i]; | |||
| if (returns == 0) | |||
| { | |||
| top_blob.row(output_index)[0] += re; | |||
| top_blob.row(output_index)[1] += im; | |||
| } | |||
| if (returns == 1) | |||
| { | |||
| top_blob[output_index] += re; | |||
| } | |||
| if (returns == 2) | |||
| { | |||
| top_blob[output_index] += im; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // square window norm | |||
| if (returns == 0) | |||
| { | |||
| for (int i = 0; i < outsize; i++) | |||
| { | |||
| if (window_sumsquare[i] != 0.f) | |||
| { | |||
| top_blob.row(i)[0] /= window_sumsquare[i]; | |||
| top_blob.row(i)[1] /= window_sumsquare[i]; | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| for (int i = 0; i < outsize; i++) | |||
| { | |||
| if (window_sumsquare[i] != 0.f) | |||
| top_blob[i] /= window_sumsquare[i]; | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace ncnn | |||
| @@ -0,0 +1,45 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #ifndef LAYER_INVERSESPECTROGRAM_H | |||
| #define LAYER_INVERSESPECTROGRAM_H | |||
| #include "layer.h" | |||
| namespace ncnn { | |||
| class InverseSpectrogram : public Layer | |||
| { | |||
| public: | |||
| InverseSpectrogram(); | |||
| virtual int load_param(const ParamDict& pd); | |||
| virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; | |||
| public: | |||
| int n_fft; | |||
| int returns; // 0=complex 1=real 2=imag | |||
| int hoplen; | |||
| int winlen; | |||
| int window_type; // 0=ones 1=hann 2=hamming | |||
| int center; | |||
| int normalized; // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy | |||
| Mat window_data; | |||
| }; | |||
| } // namespace ncnn | |||
| #endif // LAYER_INVERSESPECTROGRAM_H | |||
| @@ -0,0 +1,221 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "spectrogram.h" | |||
| namespace ncnn { | |||
| Spectrogram::Spectrogram() | |||
| { | |||
| one_blob_only = true; | |||
| support_inplace = false; | |||
| } | |||
| int Spectrogram::load_param(const ParamDict& pd) | |||
| { | |||
| n_fft = pd.get(0, 0); | |||
| power = pd.get(1, 0); | |||
| hoplen = pd.get(2, n_fft / 4); | |||
| winlen = pd.get(3, n_fft); | |||
| window_type = pd.get(4, 0); | |||
| center = pd.get(5, 1); | |||
| pad_type = pd.get(6, 2); | |||
| normalized = pd.get(7, 0); | |||
| onesided = pd.get(8, 1); | |||
| // assert winlen <= n_fft | |||
| // generate window | |||
| window_data.create(normalized == 2 ? n_fft + 1 : n_fft); | |||
| { | |||
| float* p = window_data; | |||
| for (int i = 0; i < (n_fft - winlen) / 2; i++) | |||
| { | |||
| *p++ = 0.f; | |||
| } | |||
| if (window_type == 0) | |||
| { | |||
| // all ones | |||
| for (int i = 0; i < winlen; i++) | |||
| { | |||
| *p++ = 1.f; | |||
| } | |||
| } | |||
| if (window_type == 1) | |||
| { | |||
| // hann window | |||
| for (int i = 0; i < winlen; i++) | |||
| { | |||
| *p++ = 0.5f * (1 - cosf(2 * 3.14159265358979323846 * i / winlen)); | |||
| } | |||
| } | |||
| if (window_type == 2) | |||
| { | |||
| // hamming window | |||
| for (int i = 0; i < winlen; i++) | |||
| { | |||
| *p++ = 0.54f - 0.46f * cosf(2 * 3.14159265358979323846 * i / winlen); | |||
| } | |||
| } | |||
| for (int i = 0; i < n_fft - winlen - (n_fft - winlen) / 2; i++) | |||
| { | |||
| *p++ = 0.f; | |||
| } | |||
| // pre-calculated window norm factor | |||
| if (normalized == 2) | |||
| { | |||
| float sqsum = 0.f; | |||
| for (int i = 0; i < n_fft; i++) | |||
| { | |||
| sqsum += window_data[i] * window_data[i]; | |||
| } | |||
| window_data[n_fft] = 1.f / sqrt(sqsum); | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| int Spectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const | |||
| { | |||
| // https://pytorch.org/audio/stable/generated/torchaudio.functional.spectrogram.html | |||
| // TODO custom window | |||
| Mat bottom_blob_bordered = bottom_blob; | |||
| if (center == 1) | |||
| { | |||
| Option opt_b = opt; | |||
| opt_b.blob_allocator = opt.workspace_allocator; | |||
| if (pad_type == 0) | |||
| copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_CONSTANT, 0.f, opt_b); | |||
| if (pad_type == 1) | |||
| copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_REPLICATE, 0.f, opt_b); | |||
| if (pad_type == 2) | |||
| copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_REFLECT, 0.f, opt_b); | |||
| } | |||
| const int size = bottom_blob_bordered.w; | |||
| // const int frames = size / hoplen + 1; | |||
| const int frames = (size - n_fft) / hoplen + 1; | |||
| const int freqs_onesided = n_fft / 2 + 1; | |||
| const int freqs = onesided ? freqs_onesided : n_fft; | |||
| const size_t elemsize = bottom_blob_bordered.elemsize; | |||
| if (power == 0) | |||
| { | |||
| top_blob.create(2, frames, freqs, elemsize, opt.blob_allocator); | |||
| } | |||
| else | |||
| { | |||
| top_blob.create(frames, freqs, elemsize, opt.blob_allocator); | |||
| } | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int i = 0; i < freqs_onesided; i++) | |||
| { | |||
| const float* ptr = bottom_blob_bordered; | |||
| float* outptr = power == 0 ? top_blob.channel(i) : top_blob.row(i); | |||
| for (int j = 0; j < frames; j++) | |||
| { | |||
| float re = 0.f; | |||
| float im = 0.f; | |||
| for (int k = 0; k < n_fft; k++) | |||
| { | |||
| float v = ptr[k]; | |||
| // apply window | |||
| v *= window_data[k]; | |||
| // dft | |||
| double angle = 2 * 3.14159265358979323846 * i * k / n_fft; | |||
| re += v * cosf(angle); // + imag * sinf(angle); | |||
| im -= v * sinf(angle); // + imag * cosf(angle); | |||
| } | |||
| if (normalized == 1) | |||
| { | |||
| float norm = 1.f / sqrt(n_fft); | |||
| re *= norm; | |||
| im *= norm; | |||
| } | |||
| if (normalized == 2) | |||
| { | |||
| float norm = window_data[n_fft]; | |||
| re *= norm; | |||
| im *= norm; | |||
| } | |||
| if (power == 0) | |||
| { | |||
| // complex as real | |||
| outptr[0] = re; | |||
| outptr[1] = im; | |||
| outptr += 2; | |||
| } | |||
| if (power == 1) | |||
| { | |||
| // magnitude | |||
| outptr[0] = sqrt(re * re + im * im); | |||
| outptr += 1; | |||
| } | |||
| if (power == 2) | |||
| { | |||
| outptr[0] = re * re + im * im; | |||
| outptr += 1; | |||
| } | |||
| ptr += hoplen; | |||
| } | |||
| } | |||
| if (!onesided) | |||
| { | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int i = freqs_onesided; i < n_fft; i++) | |||
| { | |||
| if (power == 0) | |||
| { | |||
| const float* ptr = top_blob.channel(n_fft - i); | |||
| float* outptr = top_blob.channel(i); | |||
| for (int j = 0; j < frames; j++) | |||
| { | |||
| // complex as real | |||
| outptr[0] = ptr[0]; | |||
| outptr[1] = -ptr[1]; | |||
| ptr += 2; | |||
| outptr += 2; | |||
| } | |||
| } | |||
| else // if (power == 1 || power == 2) | |||
| { | |||
| const float* ptr = top_blob.row(n_fft - i); | |||
| float* outptr = top_blob.row(i); | |||
| memcpy(outptr, ptr, frames * sizeof(float)); | |||
| } | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace ncnn | |||
| @@ -0,0 +1,47 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #ifndef LAYER_SPECTROGRAM_H | |||
| #define LAYER_SPECTROGRAM_H | |||
| #include "layer.h" | |||
| namespace ncnn { | |||
| class Spectrogram : public Layer | |||
| { | |||
| public: | |||
| Spectrogram(); | |||
| virtual int load_param(const ParamDict& pd); | |||
| virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; | |||
| public: | |||
| int n_fft; | |||
| int power; | |||
| int hoplen; | |||
| int winlen; | |||
| int window_type; // 0=ones 1=hann 2=hamming | |||
| int center; | |||
| int pad_type; // 0=CONSTANT 1=REPLICATE 2=REFLECT | |||
| int normalized; // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy | |||
| int onesided; | |||
| Mat window_data; | |||
| }; | |||
| } // namespace ncnn | |||
| #endif // LAYER_SPECTROGRAM_H | |||
| @@ -117,6 +117,7 @@ ncnn_add_layer_test(HardSwish) | |||
| ncnn_add_layer_test(InnerProduct) | |||
| ncnn_add_layer_test(InstanceNorm) | |||
| ncnn_add_layer_test(Interp) | |||
| ncnn_add_layer_test(InverseSpectrogram) | |||
| ncnn_add_layer_test(LayerNorm) | |||
| ncnn_add_layer_test(LRN) | |||
| ncnn_add_layer_test(LSTM) | |||
| @@ -154,6 +155,7 @@ ncnn_add_layer_test(Sigmoid) | |||
| ncnn_add_layer_test(Slice) | |||
| ncnn_add_layer_test(Softmax) | |||
| ncnn_add_layer_test(Softplus) | |||
| ncnn_add_layer_test(Spectrogram) | |||
| ncnn_add_layer_test(Squeeze) | |||
| ncnn_add_layer_test(Swish) | |||
| ncnn_add_layer_test(TanH) | |||
| @@ -0,0 +1,56 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "testutil.h" | |||
| static int test_inversespectrogram(int frames, int freqs, int n_fft, int returns, int hoplen, int winlen, int window_type, int center, int normalized) | |||
| { | |||
| ncnn::Mat a = RandomMat(2, frames, freqs); | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, n_fft); | |||
| pd.set(1, returns); | |||
| pd.set(2, hoplen); | |||
| pd.set(3, winlen); | |||
| pd.set(4, window_type); | |||
| pd.set(5, center); | |||
| pd.set(7, normalized); | |||
| std::vector<ncnn::Mat> weights(0); | |||
| int ret = test_layer("InverseSpectrogram", pd, weights, a); | |||
| if (ret != 0) | |||
| { | |||
| fprintf(stderr, "test_inversespectrogram failed frames=%d freqs=%d n_fft=%d returns=%d hoplen=%d winlen=%d window_type=%d center=%d normalized=%d\n", frames, freqs, n_fft, returns, hoplen, winlen, window_type, center, normalized); | |||
| } | |||
| return ret; | |||
| } | |||
| static int test_inversespectrogram_0() | |||
| { | |||
| return 0 | |||
| || test_inversespectrogram(17, 1, 1, 0, 1, 1, 0, 1, 0) | |||
| || test_inversespectrogram(39, 9, 17, 0, 7, 15, 0, 0, 1) | |||
| || test_inversespectrogram(128, 6, 10, 0, 2, 7, 1, 1, 1) | |||
| || test_inversespectrogram(255, 17, 17, 1, 14, 17, 2, 0, 0) | |||
| || test_inversespectrogram(124, 28, 55, 2, 12, 55, 1, 1, 2); | |||
| } | |||
| int main() | |||
| { | |||
| SRAND(7767517); | |||
| return test_inversespectrogram_0(); | |||
| } | |||
| @@ -0,0 +1,58 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "testutil.h" | |||
| static int test_spectrogram(int size, int n_fft, int power, int hoplen, int winlen, int window_type, int center, int pad_type, int normalized, int onesided) | |||
| { | |||
| ncnn::Mat a = RandomMat(size); | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, n_fft); | |||
| pd.set(1, power); | |||
| pd.set(2, hoplen); | |||
| pd.set(3, winlen); | |||
| pd.set(4, window_type); | |||
| pd.set(5, center); | |||
| pd.set(6, pad_type); | |||
| pd.set(7, normalized); | |||
| pd.set(8, onesided); | |||
| std::vector<ncnn::Mat> weights(0); | |||
| int ret = test_layer("Spectrogram", pd, weights, a); | |||
| if (ret != 0) | |||
| { | |||
| fprintf(stderr, "test_spectrogram failed size=%d n_fft=%d power=%d hoplen=%d winlen=%d window_type=%d center=%d pad_type=%d normalized=%d onesided=%d\n", size, n_fft, power, hoplen, winlen, window_type, center, pad_type, normalized, onesided); | |||
| } | |||
| return ret; | |||
| } | |||
| static int test_spectrogram_0() | |||
| { | |||
| return 0 | |||
| || test_spectrogram(17, 1, 0, 1, 1, 0, 1, 0, 0, 0) | |||
| || test_spectrogram(39, 17, 0, 7, 15, 0, 0, 0, 1, 0) | |||
| || test_spectrogram(128, 10, 0, 2, 7, 1, 1, 1, 1, 1) | |||
| || test_spectrogram(255, 17, 1, 14, 17, 2, 0, 0, 0, 1) | |||
| || test_spectrogram(124, 55, 2, 12, 55, 1, 1, 2, 2, 0); | |||
| } | |||
| int main() | |||
| { | |||
| SRAND(7767517); | |||
| return test_spectrogram_0(); | |||
| } | |||
| @@ -306,6 +306,9 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/nn_quantized_FloatFunctional.cpp | |||
| pass_level2/torchaudio_F_inverse_spectrogram.cpp | |||
| pass_level2/torchaudio_F_spectrogram.cpp | |||
| pass_level2/nn_GRU.cpp | |||
| pass_level2/nn_LSTM.cpp | |||
| pass_level2/nn_RNN.cpp | |||
| @@ -570,6 +573,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/torch_cumsum.cpp | |||
| pass_ncnn/torch_diag.cpp | |||
| pass_ncnn/torch_flatten.cpp | |||
| pass_ncnn/torch_istft.cpp | |||
| pass_ncnn/torch_logsumexp.cpp | |||
| pass_ncnn/torch_matmul.cpp | |||
| pass_ncnn/torch_max.cpp | |||
| @@ -582,9 +586,12 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/torch_slice_scatter.cpp | |||
| pass_ncnn/torch_squeeze.cpp | |||
| pass_ncnn/torch_sum.cpp | |||
| pass_ncnn/torch_stft.cpp | |||
| pass_ncnn/torch_t.cpp | |||
| pass_ncnn/torch_transpose.cpp | |||
| pass_ncnn/torch_unsqueeze.cpp | |||
| pass_ncnn/torchaudio_F_inverse_spectrogram.cpp | |||
| pass_ncnn/torchaudio_F_spectrogram.cpp | |||
| pass_ncnn/torchvision_DeformConv2d.cpp | |||
| ) | |||
| @@ -1458,6 +1458,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| fprintf(pyfp, "import torch.nn.functional as F\n"); | |||
| fprintf(pyfp, "try:\n"); | |||
| fprintf(pyfp, " import torchvision\n"); | |||
| fprintf(pyfp, " import torchaudio\n"); | |||
| fprintf(pyfp, "except:\n"); | |||
| fprintf(pyfp, " pass\n"); | |||
| @@ -43,6 +43,7 @@ pnnx.Output output 1 0 out | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const | |||
| { | |||
| op->params["pad_mode"] = "reflect"; | |||
| op->params["center"] = false; | |||
| } | |||
| }; | |||
| @@ -55,7 +56,7 @@ public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 37 36 | |||
| 36 35 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| @@ -79,19 +80,18 @@ prim::Constant op_11 0 1 21 value=%pad_left | |||
| prim::Constant op_12 0 1 63 value=%pad_right | |||
| prim::ListConstruct op_13 2 1 21 63 22 | |||
| prim::Constant op_14 0 1 23 value=%pad_mode | |||
| prim::Constant op_15 0 1 24 value=None | |||
| aten::pad op_16 4 1 a 22 23 24 b | |||
| prim::Constant op_17 0 1 64 value=1 | |||
| aten::size op_18 2 1 b 64 27 | |||
| prim::NumToTensor op_19 1 1 27 28 | |||
| aten::Int op_20 1 1 28 31 | |||
| prim::Constant op_21 0 1 33 value=2 | |||
| aten::size op_22 2 1 b 33 34 | |||
| prim::NumToTensor op_23 1 1 34 35 | |||
| aten::Int op_24 1 1 35 40 | |||
| prim::ListConstruct op_25 2 1 31 40 41 | |||
| aten::view op_26 2 1 b 41 c | |||
| aten::stft op_27 8 1 c n_fft hop_length win_length window normalized onesided return_complex out | |||
| F.pad op_15 3 1 a 22 23 b | |||
| prim::Constant op_16 0 1 64 value=1 | |||
| aten::size op_17 2 1 b 64 27 | |||
| prim::NumToTensor op_18 1 1 27 28 | |||
| aten::Int op_29 1 1 28 31 | |||
| prim::Constant op_20 0 1 33 value=2 | |||
| aten::size op_21 2 1 b 33 34 | |||
| prim::NumToTensor op_22 1 1 34 35 | |||
| aten::Int op_23 1 1 35 40 | |||
| prim::ListConstruct op_24 2 1 31 40 41 | |||
| aten::view op_25 2 1 b 41 c | |||
| aten::stft op_26 8 1 c n_fft hop_length win_length window normalized onesided return_complex out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| @@ -110,4 +110,88 @@ pnnx.Output output 1 0 out | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_1, 19) | |||
| class torch_stft_2 : public torch_stft_1 | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 29 28 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 normalized | |||
| pnnx.Input input_6 0 1 onesided | |||
| pnnx.Input input_7 0 1 return_complex | |||
| prim::Constant op_0 0 1 11 value=0 | |||
| aten::size op_1 2 1 input 11 12 | |||
| prim::NumToTensor op_2 1 1 12 13 | |||
| aten::Int op_3 1 1 13 18 | |||
| prim::Constant op_4 0 1 15 value=1 | |||
| prim::Constant op_5 0 1 121 value=1 | |||
| prim::ListConstruct op_6 3 1 15 121 18 19 | |||
| aten::view op_7 2 1 input 19 a | |||
| prim::Constant op_8 0 1 22 value=%pad_left | |||
| prim::Constant op_9 0 1 122 value=%pad_right | |||
| prim::ListConstruct op_10 2 1 22 122 23 | |||
| prim::Constant op_11 0 1 24 value=%pad_mode | |||
| F.pad op_12 3 1 a 23 24 b | |||
| prim::Constant op_13 0 1 28 value=2 | |||
| aten::size op_14 2 1 b 28 29 | |||
| prim::NumToTensor op_15 1 1 29 30 | |||
| aten::Int op_16 1 1 30 34 | |||
| prim::ListConstruct op_17 1 1 34 35 | |||
| aten::view op_18 2 1 b 35 c | |||
| aten::stft op_19 8 1 c n_fft hop_length win_length window normalized onesided return_complex out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_2, 19) | |||
| // Rewriter pattern for an aten::stft call that is preceded by an explicit F.pad. | |||
| // Matched subgraph: view the input as [1, 1, size(input, 0)], F.pad it with | |||
| // [%pad_left, %pad_right], view the padded tensor back to [size(b, 2)], then aten::stft. | |||
| // In contrast to torch_stft_2, the pad mode is captured from the F.pad "mode" | |||
| // attribute (the third F.pad operand is a None constant here) instead of from a | |||
| // constant operand. | |||
| // Only the pattern is overridden; everything else is inherited from torch_stft_1, | |||
| // which is defined outside this excerpt. | |||
| class torch_stft_3 : public torch_stft_1 | |||
| { | |||
| public: | |||
| // Returns the PNNX IR text of the subgraph to match. | |||
| // %pad_left, %pad_right and %pad_mode are captured parameters. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 29 28 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 normalized | |||
| pnnx.Input input_6 0 1 onesided | |||
| pnnx.Input input_7 0 1 return_complex | |||
| prim::Constant op_0 0 1 11 value=0 | |||
| aten::size op_1 2 1 input 11 12 | |||
| prim::NumToTensor op_2 1 1 12 13 | |||
| aten::Int op_3 1 1 13 18 | |||
| prim::Constant op_4 0 1 15 value=1 | |||
| prim::Constant op_5 0 1 121 value=1 | |||
| prim::ListConstruct op_6 3 1 15 121 18 19 | |||
| aten::view op_7 2 1 input 19 a | |||
| prim::Constant op_8 0 1 22 value=%pad_left | |||
| prim::Constant op_9 0 1 122 value=%pad_right | |||
| prim::ListConstruct op_10 2 1 22 122 23 | |||
| prim::Constant op_11 0 1 24 value=None | |||
| F.pad op_12 3 1 a 23 24 b mode=%pad_mode | |||
| prim::Constant op_13 0 1 28 value=2 | |||
| aten::size op_14 2 1 b 28 29 | |||
| prim::NumToTensor op_15 1 1 29 30 | |||
| aten::Int op_16 1 1 30 34 | |||
| prim::ListConstruct op_17 1 1 34 35 | |||
| aten::view op_18 2 1 b 35 c | |||
| aten::stft op_19 8 1 c n_fft hop_length win_length window normalized onesided return_complex out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| // Registers the pass globally; 19 is the macro's second argument | |||
| // (presumably the pass priority -- confirm against the macro definition). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_3, 19) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,165 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| // Folds the traced body of an inverse spectrogram into a single | |||
| // torchaudio.functional.inverse_spectrogram operator. | |||
| // Matched subgraph: reshape the input to [-1, size(0), size(1)], run aten::istft with | |||
| // length=None and return_complex=False, then reshape the waveform to [size(waveform, 1)]. | |||
| class torchaudio_F_inverse_spectrogram : public GraphRewriterPass | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %normalized is captured from the | |||
| // boolean constant that feeds aten::istft. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 29 28 | |||
| pnnx.Input input_0 0 1 spectrogram | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 center | |||
| pnnx.Input input_6 0 1 onesided | |||
| prim::Constant op_0 0 1 13 value=0 | |||
| aten::size op_1 2 1 spectrogram 13 14 | |||
| prim::NumToTensor op_2 1 1 14 15 | |||
| aten::Int op_3 1 1 15 18 | |||
| prim::Constant op_4 0 1 20 value=1 | |||
| aten::size op_5 2 1 spectrogram 20 21 | |||
| prim::NumToTensor op_6 1 1 21 22 | |||
| aten::Int op_7 1 1 22 28 | |||
| prim::Constant op_8 0 1 24 value=-1 | |||
| prim::ListConstruct op_9 3 1 24 18 28 29 | |||
| aten::reshape op_10 2 1 spectrogram 29 spectrogram.1 | |||
| prim::Constant op_11 0 1 normalized value=%normalized | |||
| prim::Constant op_12 0 1 length value=None | |||
| prim::Constant op_13 0 1 return_complex value=False | |||
| aten::istft op_14 10 1 spectrogram.1 n_fft hop_length win_length window center normalized onesided length return_complex waveform.1 | |||
| prim::Constant op_15 0 1 75 value=1 | |||
| aten::size op_16 2 1 waveform.1 75 42 | |||
| prim::NumToTensor op_17 1 1 42 43 | |||
| aten::Int op_18 1 1 43 47 | |||
| prim::ListConstruct op_19 1 1 47 48 | |||
| aten::reshape op_20 2 1 waveform.1 48 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| // Type name given to the replacement operator. | |||
| const char* type_str() const | |||
| { | |||
| return "torchaudio.functional.inverse_spectrogram"; | |||
| } | |||
| // Fills the parameters of the replacement operator. | |||
| // length is a default-constructed Parameter (the pattern only matches length=None), | |||
| // pad is fixed to 0, and the captured boolean is mapped to false / "frame_length". | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| op->params["length"] = Parameter(); | |||
| op->params["pad"] = 0; | |||
| const bool scaled_by_frame_length = captured_params.at("normalized").b; | |||
| if (!scaled_by_frame_length) | |||
| { | |||
| op->params["normalized"] = false; | |||
| return; | |||
| } | |||
| op->params["normalized"] = "frame_length"; | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram, 6) | |||
| // Variant of torchaudio_F_inverse_spectrogram for an input whose sizes along | |||
| // dimensions 0, 1 and 2 are all read: the input is reshaped to [-1, size(1), size(2)] | |||
| // before aten::istft, and the waveform is reshaped to [size(0), size(waveform, 1)]. | |||
| // type_str() and write() are inherited from the base class. | |||
| class torchaudio_F_inverse_spectrogram_0 : public torchaudio_F_inverse_spectrogram | |||
| { | |||
| // Explicit access specifier, matching every sibling pass class in this file; | |||
| // without it the override would be private in this class. | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %normalized is captured from the | |||
| // boolean constant that feeds aten::istft. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 33 32 | |||
| pnnx.Input input_0 0 1 spectrogram | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 center | |||
| pnnx.Input input_6 0 1 onesided | |||
| prim::Constant op_0 0 1 13 value=0 | |||
| aten::size op_1 2 1 spectrogram 13 14 | |||
| prim::NumToTensor op_2 1 1 14 15 | |||
| aten::Int op_3 1 1 15 18 | |||
| prim::Constant op_4 0 1 20 value=1 | |||
| aten::size op_5 2 1 spectrogram 20 21 | |||
| prim::NumToTensor op_6 1 1 21 22 | |||
| aten::Int op_7 1 1 22 25 | |||
| prim::Constant op_8 0 1 27 value=2 | |||
| aten::size op_9 2 1 spectrogram 27 28 | |||
| prim::NumToTensor op_10 1 1 28 29 | |||
| aten::Int op_11 1 1 29 35 | |||
| prim::Constant op_12 0 1 31 value=-1 | |||
| prim::ListConstruct op_13 3 1 31 25 35 36 | |||
| aten::reshape op_14 2 1 spectrogram 36 spectrogram.1 | |||
| prim::Constant op_15 0 1 normalized value=%normalized | |||
| prim::Constant op_16 0 1 length value=None | |||
| prim::Constant op_17 0 1 return_complex value=False | |||
| aten::istft op_18 10 1 spectrogram.1 n_fft hop_length win_length window center normalized onesided length return_complex waveform.1 | |||
| prim::Constant op_19 0 1 83 value=1 | |||
| aten::size op_20 2 1 waveform.1 83 49 | |||
| prim::NumToTensor op_21 1 1 49 50 | |||
| aten::Int op_22 1 1 50 55 | |||
| prim::ListConstruct op_23 2 1 18 55 56 | |||
| aten::reshape op_24 2 1 waveform.1 56 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram_0, 6) | |||
| // Absorbs an explicit window-energy scaling into an already rewritten | |||
| // torchaudio.functional.inverse_spectrogram operator. | |||
| // Matched subgraph: spectrogram * sqrt(sum(window ^ 2)) feeding an | |||
| // inverse_spectrogram operator whose normalized parameter is False. | |||
| class torchaudio_F_inverse_spectrogram_1 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %length and %pad are captured from | |||
| // the existing inverse_spectrogram operator. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 15 14 | |||
| pnnx.Input input_0 0 1 spectrogram | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 center | |||
| pnnx.Input input_6 0 1 onesided | |||
| prim::Constant op_0 0 1 13 value=2.000000e+00 | |||
| aten::pow op_1 2 1 window 13 14 | |||
| prim::Constant op_2 0 1 87 value=None | |||
| aten::sum op_3 2 1 14 87 16 | |||
| aten::sqrt op_4 1 1 16 17 | |||
| aten::mul op_5 2 1 spectrogram 17 spectrogram.1 | |||
| torchaudio.functional.inverse_spectrogram op_6 7 1 spectrogram.1 n_fft hop_length win_length window center onesided out normalized=False length=%length pad=%pad | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| // Type name given to the replacement operator. | |||
| const char* type_str() const | |||
| { | |||
| return "torchaudio.functional.inverse_spectrogram"; | |||
| } | |||
| // Carries length and pad over unchanged and records the absorbed scaling | |||
| // as normalized="window". | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| auto& attrs = op->params; | |||
| attrs["length"] = captured_params.at("length"); | |||
| attrs["pad"] = captured_params.at("pad"); | |||
| attrs["normalized"] = "window"; | |||
| } | |||
| }; | |||
| // Registers the pass globally with 7 as the macro's second argument, one higher than | |||
| // the passes that emit the operator matched above (presumably so it runs after them -- confirm). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram_1, 7) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,709 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| // Folds the traced body of a spectrogram without explicit padding into a single | |||
| // torchaudio.functional.spectrogram operator. | |||
| // Matched subgraph: reshape the waveform to [-1, size(0)], run aten::stft with | |||
| // return_complex=True, then reshape the result to [size(spec, 1), size(spec, 2)]. | |||
| class torchaudio_F_spectrogram : public GraphRewriterPass | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %normalized is captured from the | |||
| // boolean constant that feeds aten::stft. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 27 26 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| prim::Constant op_0 0 1 11 value=0 | |||
| aten::size op_1 2 1 waveform 11 12 | |||
| prim::NumToTensor op_2 1 1 12 13 | |||
| aten::Int op_3 1 1 13 18 | |||
| prim::Constant op_4 0 1 15 value=-1 | |||
| prim::ListConstruct op_5 2 1 15 18 19 | |||
| aten::reshape op_6 2 1 waveform 19 waveform.1 | |||
| prim::Constant op_7 0 1 normalized value=%normalized | |||
| prim::Constant op_8 0 1 return_complex value=True | |||
| aten::stft op_9 8 1 waveform.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 | |||
| prim::Constant op_10 0 1 29 value=1 | |||
| aten::size op_11 2 1 spec_f.1 29 30 | |||
| prim::NumToTensor op_12 1 1 30 31 | |||
| aten::Int op_13 1 1 31 34 | |||
| prim::Constant op_14 0 1 36 value=2 | |||
| aten::size op_15 2 1 spec_f.1 36 37 | |||
| prim::NumToTensor op_16 1 1 37 38 | |||
| aten::Int op_17 1 1 38 43 | |||
| prim::ListConstruct op_18 2 1 34 43 44 | |||
| aten::reshape op_19 2 1 spec_f.1 44 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| // Type name given to the replacement operator. | |||
| const char* type_str() const | |||
| { | |||
| return "torchaudio.functional.spectrogram"; | |||
| } | |||
| // Fills the parameters of the replacement operator. | |||
| // No padding op is part of this pattern, so center is written as false; pad, pad_mode | |||
| // and power get fixed values, and the captured boolean is mapped to false / "frame_length". | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| auto& attrs = op->params; | |||
| attrs["pad"] = 0; | |||
| attrs["pad_mode"] = "reflect"; | |||
| attrs["center"] = false; | |||
| attrs["power"] = Parameter(); | |||
| const bool scaled_by_frame_length = captured_params.at("normalized").b; | |||
| if (!scaled_by_frame_length) | |||
| { | |||
| attrs["normalized"] = false; | |||
| } | |||
| else | |||
| { | |||
| attrs["normalized"] = "frame_length"; | |||
| } | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram, 6) | |||
| // Variant of torchaudio_F_spectrogram that also reads size(waveform, 0): | |||
| // the waveform is reshaped to [-1, size(waveform, 1)] before aten::stft and the | |||
| // result is reshaped to [size(waveform, 0), size(spec, 1), size(spec, 2)]. | |||
| // type_str() and write() are inherited from the base class. | |||
| class torchaudio_F_spectrogram_0 : public torchaudio_F_spectrogram | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %normalized is captured from the | |||
| // boolean constant that feeds aten::stft. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 31 30 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| prim::Constant op_0 0 1 11 value=0 | |||
| aten::size op_1 2 1 waveform 11 12 | |||
| prim::NumToTensor op_2 1 1 12 13 | |||
| aten::Int op_3 1 1 13 16 | |||
| prim::Constant op_4 0 1 18 value=1 | |||
| aten::size op_5 2 1 waveform 18 19 | |||
| prim::NumToTensor op_6 1 1 19 20 | |||
| aten::Int op_7 1 1 20 25 | |||
| prim::Constant op_8 0 1 22 value=-1 | |||
| prim::ListConstruct op_9 2 1 22 25 26 | |||
| aten::reshape op_10 2 1 waveform 26 waveform.1 | |||
| prim::Constant op_11 0 1 normalized value=%normalized | |||
| prim::Constant op_12 0 1 return_complex value=True | |||
| aten::stft op_13 8 1 waveform.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 | |||
| prim::Constant op_14 0 1 72 value=1 | |||
| aten::size op_15 2 1 spec_f.1 72 36 | |||
| prim::NumToTensor op_16 1 1 36 37 | |||
| aten::Int op_17 1 1 37 40 | |||
| prim::Constant op_18 0 1 42 value=2 | |||
| aten::size op_19 2 1 spec_f.1 42 43 | |||
| prim::NumToTensor op_20 1 1 43 44 | |||
| aten::Int op_21 1 1 44 50 | |||
| prim::ListConstruct op_22 3 1 16 40 50 51 | |||
| aten::reshape op_23 2 1 spec_f.1 51 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_0, 6) | |||
| // Folds the traced body of a spectrogram that pads through aten::pad into a single | |||
| // torchaudio.functional.spectrogram operator. | |||
| // Matched subgraph: reshape the waveform to [-1, size(waveform, 1)], view it as | |||
| // [1, size(0), size(1)], aten::pad with [%pad_left, %pad_right] and mode %pad_mode, | |||
| // view the padded tensor back to [size(1), size(2)], run aten::stft with return_complex=True, | |||
| // then reshape to [size(waveform, 0), size(spec, 1), size(spec, 2)]. | |||
| class torchaudio_F_spectrogram_1 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %pad_left, %pad_right, %pad_mode and | |||
| // %normalized are captured parameters. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 58 57 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| prim::Constant op_0 0 1 18 value=1 | |||
| aten::size op_1 2 1 waveform 18 19 | |||
| prim::NumToTensor op_2 1 1 19 20 | |||
| aten::Int op_3 1 1 20 25 | |||
| prim::Constant op_4 0 1 22 value=-1 | |||
| prim::ListConstruct op_5 2 1 22 25 26 | |||
| aten::reshape op_6 2 1 waveform 26 waveform.1 | |||
| prim::Constant op_7 0 1 106 value=0 | |||
| aten::size op_8 2 1 waveform.1 106 29 | |||
| prim::NumToTensor op_9 1 1 29 30 | |||
| aten::Int op_10 1 1 30 33 | |||
| prim::Constant op_11 0 1 107 value=1 | |||
| aten::size op_12 2 1 waveform.1 107 35 | |||
| prim::NumToTensor op_13 1 1 35 36 | |||
| aten::Int op_14 1 1 36 41 | |||
| prim::Constant op_15 0 1 108 value=1 | |||
| prim::ListConstruct op_16 3 1 108 33 41 42 | |||
| aten::view op_17 2 1 waveform.1 42 input0.1 | |||
| prim::Constant op_18 0 1 45 value=%pad_left | |||
| prim::Constant op_19 0 1 109 value=%pad_right | |||
| prim::ListConstruct op_20 2 1 45 109 46 | |||
| prim::Constant op_21 0 1 47 value=%pad_mode | |||
| prim::Constant op_22 0 1 110 value=None | |||
| aten::pad op_23 4 1 input0.1 46 47 110 input1.1 | |||
| prim::Constant op_24 0 1 111 value=1 | |||
| aten::size op_25 2 1 input1.1 111 51 | |||
| prim::NumToTensor op_26 1 1 51 52 | |||
| aten::Int op_27 1 1 52 55 | |||
| prim::Constant op_28 0 1 57 value=2 | |||
| aten::size op_29 2 1 input1.1 57 58 | |||
| prim::NumToTensor op_30 1 1 58 59 | |||
| aten::Int op_31 1 1 59 64 | |||
| prim::ListConstruct op_32 2 1 55 64 65 | |||
| aten::view op_33 2 1 input1.1 65 input2.1 | |||
| prim::Constant op_34 0 1 normalized value=%normalized | |||
| prim::Constant op_35 0 1 return_complex value=True | |||
| aten::stft op_36 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 | |||
| prim::Constant op_37 0 1 11 value=0 | |||
| aten::size op_38 2 1 waveform 11 12 | |||
| prim::NumToTensor op_39 1 1 12 13 | |||
| aten::Int op_40 1 1 13 16 | |||
| prim::Constant op_41 0 1 116 value=1 | |||
| aten::size op_42 2 1 spec_f.1 116 75 | |||
| prim::NumToTensor op_43 1 1 75 76 | |||
| aten::Int op_44 1 1 76 79 | |||
| prim::Constant op_45 0 1 117 value=2 | |||
| aten::size op_46 2 1 spec_f.1 117 81 | |||
| prim::NumToTensor op_47 1 1 81 82 | |||
| aten::Int op_48 1 1 82 88 | |||
| prim::ListConstruct op_49 3 1 16 79 88 89 | |||
| aten::reshape op_50 2 1 spec_f.1 89 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| // Type name given to the replacement operator. | |||
| const char* type_str() const | |||
| { | |||
| return "torchaudio.functional.spectrogram"; | |||
| } | |||
| // Fills the parameters of the replacement operator. | |||
| // The matched padding is recorded as center=true with the captured pad_mode; pad and power | |||
| // get fixed values, and the captured boolean is mapped to false / "frame_length". | |||
| // The captured %pad_left / %pad_right values are not written to the operator. | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| auto& attrs = op->params; | |||
| attrs["pad"] = 0; | |||
| attrs["pad_mode"] = captured_params.at("pad_mode"); | |||
| attrs["center"] = true; | |||
| attrs["power"] = Parameter(); | |||
| const bool scaled_by_frame_length = captured_params.at("normalized").b; | |||
| if (!scaled_by_frame_length) | |||
| { | |||
| attrs["normalized"] = false; | |||
| return; | |||
| } | |||
| attrs["normalized"] = "frame_length"; | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1, 6) | |||
| // Variant of torchaudio_F_spectrogram_1 that does not read size(waveform, 1): | |||
| // the waveform is reshaped to [-1, size(waveform, 0)], padded through aten::pad exactly | |||
| // as in the base pattern, and the aten::stft result is reshaped to | |||
| // [size(spec, 1), size(spec, 2)]. type_str() and write() are inherited. | |||
| class torchaudio_F_spectrogram_1_1 : public torchaudio_F_spectrogram_1 | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %pad_left, %pad_right, %pad_mode and | |||
| // %normalized are captured parameters. | |||
| // The count line is "54 53": the graph below has 54 operator lines | |||
| // (6 inputs + 47 ops + 1 output) and 53 operands, like every sibling pattern whose | |||
| // header equals its line count. The op names skip op_41..op_49; they are left as is. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 54 53 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| prim::Constant op_0 0 1 11 value=0 | |||
| aten::size op_1 2 1 waveform 11 12 | |||
| prim::NumToTensor op_2 1 1 12 13 | |||
| aten::Int op_3 1 1 13 18 | |||
| prim::Constant op_4 0 1 15 value=-1 | |||
| prim::ListConstruct op_5 2 1 15 18 19 | |||
| aten::reshape op_6 2 1 waveform 19 waveform.1 | |||
| prim::Constant op_7 0 1 108 value=0 | |||
| aten::size op_8 2 1 waveform.1 108 22 | |||
| prim::NumToTensor op_9 1 1 22 23 | |||
| aten::Int op_10 1 1 23 26 | |||
| prim::Constant op_11 0 1 28 value=1 | |||
| aten::size op_12 2 1 waveform.1 28 29 | |||
| prim::NumToTensor op_13 1 1 29 30 | |||
| aten::Int op_14 1 1 30 35 | |||
| prim::Constant op_15 0 1 109 value=1 | |||
| prim::ListConstruct op_16 3 1 109 26 35 36 | |||
| aten::view op_17 2 1 waveform.1 36 input0.1 | |||
| prim::Constant op_18 0 1 39 value=%pad_left | |||
| prim::Constant op_19 0 1 110 value=%pad_right | |||
| prim::ListConstruct op_20 2 1 39 110 40 | |||
| prim::Constant op_21 0 1 41 value=%pad_mode | |||
| prim::Constant op_22 0 1 111 value=None | |||
| aten::pad op_23 4 1 input0.1 40 41 111 input1.1 | |||
| prim::Constant op_24 0 1 112 value=1 | |||
| aten::size op_25 2 1 input1.1 112 45 | |||
| prim::NumToTensor op_26 1 1 45 46 | |||
| aten::Int op_27 1 1 46 49 | |||
| prim::Constant op_28 0 1 51 value=2 | |||
| aten::size op_29 2 1 input1.1 51 52 | |||
| prim::NumToTensor op_30 1 1 52 53 | |||
| aten::Int op_31 1 1 53 58 | |||
| prim::ListConstruct op_32 2 1 49 58 59 | |||
| aten::view op_33 2 1 input1.1 59 input2.1 | |||
| prim::Constant op_34 0 1 normalized value=%normalized | |||
| prim::Constant op_35 0 1 return_complex value=True | |||
| aten::stft op_36 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 | |||
| prim::Constant op_37 0 1 117 value=1 | |||
| aten::size op_38 2 1 spec_f.1 117 69 | |||
| prim::NumToTensor op_39 1 1 69 70 | |||
| aten::Int op_40 1 1 70 73 | |||
| prim::Constant op_50 0 1 118 value=2 | |||
| aten::size op_51 2 1 spec_f.1 118 75 | |||
| prim::NumToTensor op_52 1 1 75 76 | |||
| aten::Int op_53 1 1 76 81 | |||
| prim::ListConstruct op_54 2 1 73 81 82 | |||
| aten::reshape op_55 2 1 spec_f.1 82 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_1, 6) | |||
| // Folds the traced body of a spectrogram that pads through aten::reflection_pad1d into a | |||
| // single torchaudio.functional.spectrogram operator. | |||
| // Matched subgraph: reshape the waveform to [-1, size(waveform, 0)], view it as | |||
| // [1, size(0), size(1)], aten::reflection_pad1d with [%pad_left, %pad_right], view the | |||
| // padded tensor back to [size(1), size(2)], run aten::stft with return_complex=True, | |||
| // then reshape to [size(spec, 1), size(spec, 2)]. | |||
| class torchaudio_F_spectrogram_1_2 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %pad_left, %pad_right and %normalized | |||
| // are captured parameters. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 52 51 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| prim::Constant op_0 0 1 211 value=0 | |||
| aten::size op_1 2 1 waveform 211 107 | |||
| prim::NumToTensor op_2 1 1 107 108 | |||
| aten::Int op_3 1 1 108 112 | |||
| prim::Constant op_4 0 1 212 value=-1 | |||
| prim::ListConstruct op_5 2 1 212 112 113 | |||
| aten::reshape op_6 2 1 waveform 113 input3.1 | |||
| prim::Constant op_7 0 1 213 value=0 | |||
| aten::size op_8 2 1 input3.1 213 116 | |||
| prim::NumToTensor op_9 1 1 116 117 | |||
| aten::Int op_10 1 1 117 120 | |||
| prim::Constant op_11 0 1 214 value=1 | |||
| aten::size op_12 2 1 input3.1 214 122 | |||
| prim::NumToTensor op_13 1 1 122 123 | |||
| aten::Int op_14 1 1 123 128 | |||
| prim::Constant op_15 0 1 215 value=1 | |||
| prim::ListConstruct op_16 3 1 215 120 128 129 | |||
| aten::view op_17 2 1 input3.1 129 input4.1 | |||
| prim::Constant op_18 0 1 216 value=%pad_left | |||
| prim::Constant op_19 0 1 217 value=%pad_right | |||
| prim::ListConstruct op_20 2 1 216 217 132 | |||
| aten::reflection_pad1d op_21 2 1 input4.1 132 input5.1 | |||
| prim::Constant op_22 0 1 218 value=1 | |||
| aten::size op_23 2 1 input5.1 218 135 | |||
| prim::NumToTensor op_24 1 1 135 136 | |||
| aten::Int op_25 1 1 136 139 | |||
| prim::Constant op_26 0 1 219 value=2 | |||
| aten::size op_27 2 1 input5.1 219 141 | |||
| prim::NumToTensor op_28 1 1 141 142 | |||
| aten::Int op_29 1 1 142 147 | |||
| prim::ListConstruct op_30 2 1 139 147 148 | |||
| aten::view op_31 2 1 input5.1 148 input6.1 | |||
| prim::Constant op_32 0 1 normalized value=%normalized | |||
| prim::Constant op_33 0 1 return_complex value=True | |||
| aten::stft op_34 8 1 input6.1 n_fft hop_length win_length window normalized onesided return_complex spec_f2.1 | |||
| prim::Constant op_35 0 1 226 value=1 | |||
| aten::size op_36 2 1 spec_f2.1 226 157 | |||
| prim::NumToTensor op_37 1 1 157 158 | |||
| aten::Int op_38 1 1 158 161 | |||
| prim::Constant op_39 0 1 227 value=2 | |||
| aten::size op_40 2 1 spec_f2.1 227 163 | |||
| prim::NumToTensor op_41 1 1 163 164 | |||
| aten::Int op_42 1 1 164 169 | |||
| prim::ListConstruct op_43 2 1 161 169 170 | |||
| aten::reshape op_44 2 1 spec_f2.1 170 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| // Type name given to the replacement operator. | |||
| const char* type_str() const | |||
| { | |||
| return "torchaudio.functional.spectrogram"; | |||
| } | |||
| // Fills the parameters of the replacement operator. | |||
| // The matched reflection padding is recorded as center=true with pad_mode="reflect"; | |||
| // pad and power get fixed values, and the captured boolean is mapped to | |||
| // false / "frame_length". | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| auto& attrs = op->params; | |||
| attrs["pad"] = 0; | |||
| attrs["pad_mode"] = "reflect"; | |||
| attrs["center"] = true; | |||
| attrs["power"] = Parameter(); | |||
| const bool scaled_by_frame_length = captured_params.at("normalized").b; | |||
| if (!scaled_by_frame_length) | |||
| { | |||
| attrs["normalized"] = false; | |||
| } | |||
| else | |||
| { | |||
| attrs["normalized"] = "frame_length"; | |||
| } | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_2, 6) | |||
| // Variant of torchaudio_F_spectrogram_1_2 that also reads size(waveform, 0): | |||
| // the waveform is reshaped to [-1, size(waveform, 1)], padded through | |||
| // aten::reflection_pad1d, and the aten::stft result is reshaped to | |||
| // [size(waveform, 0), size(spec, 1), size(spec, 2)]. | |||
| // type_str() and write() are inherited from the base class. | |||
| class torchaudio_F_spectrogram_1_3 : public torchaudio_F_spectrogram_1_2 | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %pad_left, %pad_right and %normalized | |||
| // are captured parameters. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 56 55 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| prim::Constant op_0 0 1 11 value=0 | |||
| aten::size op_1 2 1 waveform 11 12 | |||
| prim::NumToTensor op_2 1 1 12 13 | |||
| aten::Int op_3 1 1 13 16 | |||
| prim::Constant op_4 0 1 18 value=1 | |||
| aten::size op_5 2 1 waveform 18 19 | |||
| prim::NumToTensor op_6 1 1 19 20 | |||
| aten::Int op_7 1 1 20 25 | |||
| prim::Constant op_8 0 1 22 value=-1 | |||
| prim::ListConstruct op_9 2 1 22 25 26 | |||
| aten::reshape op_10 2 1 waveform 26 input.1 | |||
| prim::Constant op_11 0 1 326 value=0 | |||
| aten::size op_12 2 1 input.1 326 29 | |||
| prim::NumToTensor op_13 1 1 29 30 | |||
| aten::Int op_14 1 1 30 33 | |||
| prim::Constant op_15 0 1 327 value=1 | |||
| aten::size op_16 2 1 input.1 327 35 | |||
| prim::NumToTensor op_17 1 1 35 36 | |||
| aten::Int op_18 1 1 36 41 | |||
| prim::Constant op_19 0 1 328 value=1 | |||
| prim::ListConstruct op_20 3 1 328 33 41 42 | |||
| aten::view op_21 2 1 input.1 42 input0.1 | |||
| prim::Constant op_22 0 1 45 value=%pad_left | |||
| prim::Constant op_23 0 1 329 value=%pad_right | |||
| prim::ListConstruct op_24 2 1 45 329 46 | |||
| aten::reflection_pad1d op_25 2 1 input0.1 46 input1.1 | |||
| prim::Constant op_26 0 1 330 value=1 | |||
| aten::size op_27 2 1 input1.1 330 49 | |||
| prim::NumToTensor op_28 1 1 49 50 | |||
| aten::Int op_29 1 1 50 53 | |||
| prim::Constant op_30 0 1 55 value=2 | |||
| aten::size op_31 2 1 input1.1 55 56 | |||
| prim::NumToTensor op_32 1 1 56 57 | |||
| aten::Int op_33 1 1 57 62 | |||
| prim::ListConstruct op_34 2 1 53 62 63 | |||
| aten::view op_35 2 1 input1.1 63 input2.1 | |||
| prim::Constant op_36 0 1 normalized value=%normalized | |||
| prim::Constant op_37 0 1 return_complex value=True | |||
| aten::stft op_38 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 | |||
| prim::Constant op_39 0 1 334 value=1 | |||
| aten::size op_40 2 1 spec_f.1 334 74 | |||
| prim::NumToTensor op_41 1 1 74 75 | |||
| aten::Int op_42 1 1 75 78 | |||
| prim::Constant op_43 0 1 335 value=2 | |||
| aten::size op_44 2 1 spec_f.1 335 80 | |||
| prim::NumToTensor op_45 1 1 80 81 | |||
| aten::Int op_46 1 1 81 87 | |||
| prim::ListConstruct op_47 3 1 16 78 87 88 | |||
| aten::reshape op_48 2 1 spec_f.1 88 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_3, 6) | |||
| // Folds the traced body of a spectrogram that pads through aten::constant_pad_nd into a | |||
| // single torchaudio.functional.spectrogram operator. | |||
| // Matched subgraph: reshape the waveform to [-1, size(waveform, 0)], view it as | |||
| // [1, size(0), size(1)], aten::constant_pad_nd with [%pad_left, %pad_right] and a | |||
| // fill value of 0.0, view the padded tensor back to [size(1), size(2)], run aten::stft | |||
| // with return_complex=True, then reshape to [size(spec, 1), size(spec, 2)]. | |||
| class torchaudio_F_spectrogram_1_4 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %pad_left, %pad_right and %normalized | |||
| // are captured parameters. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 53 52 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| prim::Constant op_0 0 1 211 value=0 | |||
| aten::size op_1 2 1 waveform 211 107 | |||
| prim::NumToTensor op_2 1 1 107 108 | |||
| aten::Int op_3 1 1 108 112 | |||
| prim::Constant op_4 0 1 212 value=-1 | |||
| prim::ListConstruct op_5 2 1 212 112 113 | |||
| aten::reshape op_6 2 1 waveform 113 input3.1 | |||
| prim::Constant op_7 0 1 213 value=0 | |||
| aten::size op_8 2 1 input3.1 213 116 | |||
| prim::NumToTensor op_9 1 1 116 117 | |||
| aten::Int op_10 1 1 117 120 | |||
| prim::Constant op_11 0 1 214 value=1 | |||
| aten::size op_12 2 1 input3.1 214 122 | |||
| prim::NumToTensor op_13 1 1 122 123 | |||
| aten::Int op_14 1 1 123 128 | |||
| prim::Constant op_15 0 1 215 value=1 | |||
| prim::ListConstruct op_16 3 1 215 120 128 129 | |||
| aten::view op_17 2 1 input3.1 129 input4.1 | |||
| prim::Constant op_18 0 1 216 value=%pad_left | |||
| prim::Constant op_19 0 1 217 value=%pad_right | |||
| prim::ListConstruct op_20 2 1 216 217 132 | |||
| prim::Constant op_21 0 1 46 value=0.000000e+00 | |||
| aten::constant_pad_nd op_22 3 1 input4.1 132 46 input5.1 | |||
| prim::Constant op_23 0 1 218 value=1 | |||
| aten::size op_24 2 1 input5.1 218 135 | |||
| prim::NumToTensor op_25 1 1 135 136 | |||
| aten::Int op_26 1 1 136 139 | |||
| prim::Constant op_27 0 1 219 value=2 | |||
| aten::size op_28 2 1 input5.1 219 141 | |||
| prim::NumToTensor op_29 1 1 141 142 | |||
| aten::Int op_30 1 1 142 147 | |||
| prim::ListConstruct op_31 2 1 139 147 148 | |||
| aten::view op_32 2 1 input5.1 148 input6.1 | |||
| prim::Constant op_33 0 1 normalized value=%normalized | |||
| prim::Constant op_34 0 1 return_complex value=True | |||
| aten::stft op_35 8 1 input6.1 n_fft hop_length win_length window normalized onesided return_complex spec_f2.1 | |||
| prim::Constant op_36 0 1 226 value=1 | |||
| aten::size op_37 2 1 spec_f2.1 226 157 | |||
| prim::NumToTensor op_38 1 1 157 158 | |||
| aten::Int op_39 1 1 158 161 | |||
| prim::Constant op_40 0 1 227 value=2 | |||
| aten::size op_41 2 1 spec_f2.1 227 163 | |||
| prim::NumToTensor op_42 1 1 163 164 | |||
| aten::Int op_43 1 1 164 169 | |||
| prim::ListConstruct op_44 2 1 161 169 170 | |||
| aten::reshape op_45 2 1 spec_f2.1 170 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| // Type name given to the replacement operator. | |||
| const char* type_str() const | |||
| { | |||
| return "torchaudio.functional.spectrogram"; | |||
| } | |||
| // Fills the parameters of the replacement operator. | |||
| // The matched zero padding is recorded as center=true with pad_mode="constant"; | |||
| // pad and power get fixed values, and the captured boolean is mapped to | |||
| // false / "frame_length". | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| auto& attrs = op->params; | |||
| attrs["pad"] = 0; | |||
| attrs["pad_mode"] = "constant"; | |||
| attrs["center"] = true; | |||
| attrs["power"] = Parameter(); | |||
| const bool scaled_by_frame_length = captured_params.at("normalized").b; | |||
| if (!scaled_by_frame_length) | |||
| { | |||
| attrs["normalized"] = false; | |||
| return; | |||
| } | |||
| attrs["normalized"] = "frame_length"; | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_4, 6) | |||
| // Variant of torchaudio_F_spectrogram_1_4 that also reads size(waveform, 0): | |||
| // the waveform is reshaped to [-1, size(waveform, 1)], padded through | |||
| // aten::constant_pad_nd with a fill value of 0.0, and the aten::stft result is | |||
| // reshaped to [size(waveform, 0), size(spec, 1), size(spec, 2)]. | |||
| // type_str() and write() are inherited from the base class. | |||
| class torchaudio_F_spectrogram_1_5 : public torchaudio_F_spectrogram_1_4 | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %pad_left, %pad_right and %normalized | |||
| // are captured parameters. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 57 56 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| prim::Constant op_0 0 1 11 value=0 | |||
| aten::size op_1 2 1 waveform 11 12 | |||
| prim::NumToTensor op_2 1 1 12 13 | |||
| aten::Int op_3 1 1 13 16 | |||
| prim::Constant op_4 0 1 18 value=1 | |||
| aten::size op_5 2 1 waveform 18 19 | |||
| prim::NumToTensor op_6 1 1 19 20 | |||
| aten::Int op_7 1 1 20 25 | |||
| prim::Constant op_8 0 1 22 value=-1 | |||
| prim::ListConstruct op_9 2 1 22 25 26 | |||
| aten::reshape op_10 2 1 waveform 26 input.1 | |||
| prim::Constant op_11 0 1 326 value=0 | |||
| aten::size op_12 2 1 input.1 326 29 | |||
| prim::NumToTensor op_13 1 1 29 30 | |||
| aten::Int op_14 1 1 30 33 | |||
| prim::Constant op_15 0 1 327 value=1 | |||
| aten::size op_16 2 1 input.1 327 35 | |||
| prim::NumToTensor op_17 1 1 35 36 | |||
| aten::Int op_18 1 1 36 41 | |||
| prim::Constant op_19 0 1 328 value=1 | |||
| prim::ListConstruct op_20 3 1 328 33 41 42 | |||
| aten::view op_21 2 1 input.1 42 input0.1 | |||
| prim::Constant op_22 0 1 45 value=%pad_left | |||
| prim::Constant op_23 0 1 329 value=%pad_right | |||
| prim::ListConstruct op_24 2 1 45 329 46 | |||
| prim::Constant op_25 0 1 47 value=0.000000e+00 | |||
| aten::constant_pad_nd op_26 3 1 input0.1 46 47 input1.1 | |||
| prim::Constant op_27 0 1 330 value=1 | |||
| aten::size op_28 2 1 input1.1 330 49 | |||
| prim::NumToTensor op_29 1 1 49 50 | |||
| aten::Int op_30 1 1 50 53 | |||
| prim::Constant op_31 0 1 55 value=2 | |||
| aten::size op_32 2 1 input1.1 55 56 | |||
| prim::NumToTensor op_33 1 1 56 57 | |||
| aten::Int op_34 1 1 57 62 | |||
| prim::ListConstruct op_35 2 1 53 62 63 | |||
| aten::view op_36 2 1 input1.1 63 input2.1 | |||
| prim::Constant op_37 0 1 normalized value=%normalized | |||
| prim::Constant op_38 0 1 return_complex value=True | |||
| aten::stft op_39 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 | |||
| prim::Constant op_40 0 1 334 value=1 | |||
| aten::size op_41 2 1 spec_f.1 334 74 | |||
| prim::NumToTensor op_42 1 1 74 75 | |||
| aten::Int op_43 1 1 75 78 | |||
| prim::Constant op_44 0 1 335 value=2 | |||
| aten::size op_45 2 1 spec_f.1 335 80 | |||
| prim::NumToTensor op_46 1 1 80 81 | |||
| aten::Int op_47 1 1 81 87 | |||
| prim::ListConstruct op_48 3 1 16 78 87 88 | |||
| aten::reshape op_49 2 1 spec_f.1 88 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| // Registers the pass globally; 6 is the macro's second argument | |||
| // (presumably the pass priority -- confirm in pass_level2.h). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_5, 6) | |||
| // Absorbs an explicit window-energy normalisation into an already rewritten | |||
| // torchaudio.functional.spectrogram operator. | |||
| // Matched subgraph: a spectrogram operator with power=None and normalized=False | |||
| // whose output is divided by sqrt(sum(window ^ 2)). | |||
| class torchaudio_F_spectrogram_2 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| // PNNX IR text of the subgraph to match; %center, %pad and %pad_mode are | |||
| // captured from the existing spectrogram operator. | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 14 13 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=None normalized=False center=%center pad=%pad pad_mode=%pad_mode | |||
| prim::Constant op_1 0 1 92 value=2.000000e+00 | |||
| aten::pow op_2 2 1 window 92 93 | |||
| prim::Constant op_3 0 1 127 value=None | |||
| aten::sum op_4 2 1 93 127 95 | |||
| aten::sqrt op_5 1 1 95 96 | |||
| aten::div op_6 2 1 spec 96 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| // Type name given to the replacement operator. | |||
| const char* type_str() const | |||
| { | |||
| return "torchaudio.functional.spectrogram"; | |||
| } | |||
| // Carries pad, pad_mode and center over unchanged, keeps power as a | |||
| // default-constructed Parameter (the pattern only matches power=None) and records | |||
| // the absorbed division as normalized="window". | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| auto& attrs = op->params; | |||
| attrs["pad"] = captured_params.at("pad"); | |||
| attrs["pad_mode"] = captured_params.at("pad_mode"); | |||
| attrs["center"] = captured_params.at("center"); | |||
| attrs["power"] = Parameter(); | |||
| attrs["normalized"] = "window"; | |||
| } | |||
| }; | |||
| // Registers the pass globally with 7 as the macro's second argument, one higher than | |||
| // the passes that emit the operator matched above (presumably so it runs after them -- confirm). | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_2, 7) | |||
| class torchaudio_F_spectrogram_3 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 9 8 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=None normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode | |||
| aten::abs op_1 1 1 spec out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torchaudio.functional.spectrogram"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| op->params["pad"] = captured_params.at("pad"); | |||
| op->params["pad_mode"] = captured_params.at("pad_mode"); | |||
| op->params["center"] = captured_params.at("center"); | |||
| op->params["normalized"] = captured_params.at("normalized"); | |||
| op->params["power"] = 1; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_3, 8) | |||
| class torchaudio_F_spectrogram_4 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 10 9 | |||
| pnnx.Input input_0 0 1 waveform | |||
| pnnx.Input input_1 0 1 n_fft | |||
| pnnx.Input input_2 0 1 hop_length | |||
| pnnx.Input input_3 0 1 win_length | |||
| pnnx.Input input_4 0 1 window | |||
| pnnx.Input input_5 0 1 onesided | |||
| torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=1 normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode | |||
| prim::Constant op_1 0 1 391 value=2 | |||
| aten::pow op_2 2 1 spec 391 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torchaudio.functional.spectrogram"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| op->params["pad"] = captured_params.at("pad"); | |||
| op->params["pad_mode"] = captured_params.at("pad_mode"); | |||
| op->params["center"] = captured_params.at("center"); | |||
| op->params["normalized"] = captured_params.at("normalized"); | |||
| op->params["power"] = 2; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_4, 9) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,203 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_istft : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| torch.view_as_complex op_0 1 1 input a | |||
| torch.istft op_1 1 1 a out center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=False win_length=%win_length window=None | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "InverseSpectrogram"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "istft"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| op->params["0"] = captured_params.at("n_fft"); | |||
| op->params["1"] = 1; // returns | |||
| op->params["2"] = captured_params.at("hop_length"); | |||
| op->params["3"] = captured_params.at("win_length"); | |||
| op->params["4"] = 0; // all ones | |||
| op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; | |||
| op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft, 20) | |||
| class torch_istft_1 : public torch_istft | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| torch.view_as_complex op_0 1 1 input a | |||
| torch.istft op_1 1 1 a b center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length window=None | |||
| torch.view_as_real op_2 1 1 b out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| torch_istft::write(op, captured_params); | |||
| op->params["1"] = 0; // returns | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_1, 20) | |||
// Approximate float comparison used to recognise window shapes.
//
// Returns true when a and b are exactly equal, when their absolute difference
// is at most epsilon, or when it is below epsilon scaled by the larger of the
// two magnitudes.
static bool NearlyEqual(float a, float b, float epsilon)
{
    if (a == b)
        return true;

    const float diff = (float)fabs(a - b);
    if (diff <= epsilon)
        return true;

    // relative error
    return diff < epsilon * std::max(fabs(a), fabs(b));
}

// Classify a window by comparing every sample against closed-form windows.
//
// Returns 0 for an all-ones window, 1 for hann and 2 for hamming (both
// evaluated with a period of window_data.size()), or -1 when no candidate
// fits or when the window is empty.
static int detect_window_type(const std::vector<float>& window_data)
{
    const int winlen = (int)window_data.size();

    // an empty window would vacuously match every candidate
    if (winlen == 0)
        return -1;

    // spelled out because M_PI is not part of standard C++ and is missing on
    // some toolchains unless extra macros are defined
    const double two_pi = 2 * 3.14159265358979323846;

    bool is_one = true;
    bool is_hann = true;
    bool is_hamming = true;
    for (int i = 0; i < winlen; i++)
    {
        const float sample = window_data[i];
        const double c = cos(two_pi * i / winlen);

        if (is_one && !NearlyEqual(sample, 1.f, 0.001))
            is_one = false;

        if (is_hann && !NearlyEqual(sample, 0.5f * (1 - c), 0.001))
            is_hann = false;

        if (is_hamming && !NearlyEqual(sample, 0.54f - 0.46f * c, 0.001))
            is_hamming = false;

        // every candidate is ruled out, the remaining samples cannot change that
        if (!is_one && !is_hann && !is_hamming)
            return -1;
    }

    if (is_one)
        return 0;
    if (is_hann)
        return 1;
    if (is_hamming)
        return 2;

    return -1;
}
| class torch_istft_2 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| torch.view_as_complex op_0 1 1 input a | |||
| pnnx.Attribute op_1 0 1 window @data | |||
| torch.istft op_2 2 1 a window out center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=False win_length=%win_length | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "InverseSpectrogram"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "istft"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::vector<float> window_data = captured_attrs.at("op_1.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| return window_type != -1; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::vector<float> window_data = captured_attrs.at("op_1.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| op->params["0"] = captured_params.at("n_fft"); | |||
| op->params["1"] = 1; // returns | |||
| op->params["2"] = captured_params.at("hop_length"); | |||
| op->params["3"] = captured_params.at("win_length"); | |||
| op->params["4"] = window_type; | |||
| op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; | |||
| op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_2, 20) | |||
| class torch_istft_3 : public torch_istft_2 | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 6 5 | |||
| pnnx.Input input 0 1 input | |||
| torch.view_as_complex op_0 1 1 input a | |||
| pnnx.Attribute op_1 0 1 window @data | |||
| torch.istft op_2 2 1 a window b center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length | |||
| torch.view_as_real op_3 1 1 b out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| torch_istft_2::write(op, captured_params, captured_attrs); | |||
| op->params["1"] = 0; // returns | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_3, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,176 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_stft : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| torch.stft op_0 1 1 input a center=%center pad_mode=%pad_mode hop_length=%hop_length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length window=None | |||
| torch.view_as_real op_1 1 1 a out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Spectrogram"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "stft"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const std::string& pad_mode = captured_params.at("pad_mode").s; | |||
| int pad_type = 2; | |||
| if (pad_mode == "constant") | |||
| pad_type = 0; | |||
| if (pad_mode == "replicate") | |||
| pad_type = 1; | |||
| if (pad_mode == "reflect") | |||
| pad_type = 2; | |||
| const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; | |||
| op->params["0"] = captured_params.at("n_fft"); | |||
| op->params["1"] = 0; // power | |||
| op->params["2"] = captured_params.at("hop_length"); | |||
| op->params["3"] = captured_params.at("win_length"); | |||
| op->params["4"] = 0; // all ones | |||
| op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; | |||
| op->params["6"] = pad_type; | |||
| op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0; | |||
| op->params["8"] = onesided; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_stft, 20) | |||
// Approximate float comparison used to recognise window shapes.
//
// Returns true when a and b are exactly equal, when their absolute difference
// is at most epsilon, or when it is below epsilon scaled by the larger of the
// two magnitudes.
static bool NearlyEqual(float a, float b, float epsilon)
{
    if (a == b)
        return true;

    const float diff = (float)fabs(a - b);
    if (diff <= epsilon)
        return true;

    // relative error
    return diff < epsilon * std::max(fabs(a), fabs(b));
}

// Classify a window by comparing every sample against closed-form windows.
//
// Returns 0 for an all-ones window, 1 for hann and 2 for hamming (both
// evaluated with a period of window_data.size()), or -1 when no candidate
// fits or when the window is empty.
static int detect_window_type(const std::vector<float>& window_data)
{
    const int winlen = (int)window_data.size();

    // an empty window would vacuously match every candidate
    if (winlen == 0)
        return -1;

    // spelled out because M_PI is not part of standard C++ and is missing on
    // some toolchains unless extra macros are defined
    const double two_pi = 2 * 3.14159265358979323846;

    bool is_one = true;
    bool is_hann = true;
    bool is_hamming = true;
    for (int i = 0; i < winlen; i++)
    {
        const float sample = window_data[i];
        const double c = cos(two_pi * i / winlen);

        if (is_one && !NearlyEqual(sample, 1.f, 0.001))
            is_one = false;

        if (is_hann && !NearlyEqual(sample, 0.5f * (1 - c), 0.001))
            is_hann = false;

        if (is_hamming && !NearlyEqual(sample, 0.54f - 0.46f * c, 0.001))
            is_hamming = false;

        // every candidate is ruled out, the remaining samples cannot change that
        if (!is_one && !is_hann && !is_hamming)
            return -1;
    }

    if (is_one)
        return 0;
    if (is_hann)
        return 1;
    if (is_hamming)
        return 2;

    return -1;
}
| class torch_stft_1 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_0 0 1 window @data | |||
| torch.stft op_1 2 1 input window a center=%center pad_mode=%pad_mode hop_length=%hop_length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length | |||
| torch.view_as_real op_2 1 1 a out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Spectrogram"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "stft"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| return window_type != -1; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| const std::string& pad_mode = captured_params.at("pad_mode").s; | |||
| int pad_type = 2; | |||
| if (pad_mode == "constant") | |||
| pad_type = 0; | |||
| if (pad_mode == "replicate") | |||
| pad_type = 1; | |||
| if (pad_mode == "reflect") | |||
| pad_type = 2; | |||
| const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; | |||
| op->params["0"] = captured_params.at("n_fft"); | |||
| op->params["1"] = 0; // power | |||
| op->params["2"] = captured_params.at("hop_length"); | |||
| op->params["3"] = captured_params.at("win_length"); | |||
| op->params["4"] = window_type; | |||
| op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; | |||
| op->params["6"] = pad_type; | |||
| op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0; | |||
| op->params["8"] = onesided; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_stft_1, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,127 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
// Approximate float comparison used to recognise window shapes.
//
// Returns true when a and b are exactly equal, when their absolute difference
// is at most epsilon, or when it is below epsilon scaled by the larger of the
// two magnitudes.
static bool NearlyEqual(float a, float b, float epsilon)
{
    if (a == b)
        return true;

    const float diff = (float)fabs(a - b);
    if (diff <= epsilon)
        return true;

    // relative error
    return diff < epsilon * std::max(fabs(a), fabs(b));
}

// Classify a window by comparing every sample against closed-form windows.
//
// Returns 0 for an all-ones window, 1 for hann and 2 for hamming (both
// evaluated with a period of window_data.size()), or -1 when no candidate
// fits or when the window is empty.
static int detect_window_type(const std::vector<float>& window_data)
{
    const int winlen = (int)window_data.size();

    // an empty window would vacuously match every candidate
    if (winlen == 0)
        return -1;

    // spelled out because M_PI is not part of standard C++ and is missing on
    // some toolchains unless extra macros are defined
    const double two_pi = 2 * 3.14159265358979323846;

    bool is_one = true;
    bool is_hann = true;
    bool is_hamming = true;
    for (int i = 0; i < winlen; i++)
    {
        const float sample = window_data[i];
        const double c = cos(two_pi * i / winlen);

        if (is_one && !NearlyEqual(sample, 1.f, 0.001))
            is_one = false;

        if (is_hann && !NearlyEqual(sample, 0.5f * (1 - c), 0.001))
            is_hann = false;

        if (is_hamming && !NearlyEqual(sample, 0.54f - 0.46f * c, 0.001))
            is_hamming = false;

        // every candidate is ruled out, the remaining samples cannot change that
        if (!is_one && !is_hann && !is_hamming)
            return -1;
    }

    if (is_one)
        return 0;
    if (is_hann)
        return 1;
    if (is_hamming)
        return 2;

    return -1;
}
| class torchaudio_F_inverse_spectrogram : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_0 0 1 window @data | |||
| torch.view_as_complex op_1 1 1 input a | |||
| torchaudio.functional.inverse_spectrogram op_2 2 1 a window out center=%center hop_length=%hop_length length=None n_fft=%n_fft normalized=%normalized onesided=%onesided pad=0 win_length=%win_length | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "InverseSpectrogram"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "inverse_spectrogram"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| return window_type != -1; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| int normalized = 0; | |||
| if (captured_params.at("normalized").type == 1) | |||
| { | |||
| normalized = captured_params.at("normalized").b ? 2 : 0; | |||
| } | |||
| if (captured_params.at("normalized").type == 4) | |||
| { | |||
| if (captured_params.at("normalized").s == "frame_length") | |||
| normalized = 1; | |||
| if (captured_params.at("normalized").s == "window") | |||
| normalized = 2; | |||
| } | |||
| op->params["0"] = captured_params.at("n_fft"); | |||
| op->params["1"] = 1; // returns | |||
| op->params["2"] = captured_params.at("hop_length"); | |||
| op->params["3"] = captured_params.at("win_length"); | |||
| op->params["4"] = window_type; | |||
| op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; | |||
| op->params["7"] = normalized; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,233 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
// Approximate float comparison used to recognise window shapes and to test
// floating-point parameter values.
//
// Returns true when a and b are exactly equal, when their absolute difference
// is at most epsilon, or when it is below epsilon scaled by the larger of the
// two magnitudes.
static bool NearlyEqual(float a, float b, float epsilon)
{
    if (a == b)
        return true;

    const float diff = (float)fabs(a - b);
    if (diff <= epsilon)
        return true;

    // relative error
    return diff < epsilon * std::max(fabs(a), fabs(b));
}

// Classify a window by comparing every sample against closed-form windows.
//
// Returns 0 for an all-ones window, 1 for hann and 2 for hamming (both
// evaluated with a period of window_data.size()), or -1 when no candidate
// fits or when the window is empty.
static int detect_window_type(const std::vector<float>& window_data)
{
    const int winlen = (int)window_data.size();

    // an empty window would vacuously match every candidate
    if (winlen == 0)
        return -1;

    // spelled out because M_PI is not part of standard C++ and is missing on
    // some toolchains unless extra macros are defined
    const double two_pi = 2 * 3.14159265358979323846;

    bool is_one = true;
    bool is_hann = true;
    bool is_hamming = true;
    for (int i = 0; i < winlen; i++)
    {
        const float sample = window_data[i];
        const double c = cos(two_pi * i / winlen);

        if (is_one && !NearlyEqual(sample, 1.f, 0.001))
            is_one = false;

        if (is_hann && !NearlyEqual(sample, 0.5f * (1 - c), 0.001))
            is_hann = false;

        if (is_hamming && !NearlyEqual(sample, 0.54f - 0.46f * c, 0.001))
            is_hamming = false;

        // every candidate is ruled out, the remaining samples cannot change that
        if (!is_one && !is_hann && !is_hamming)
            return -1;
    }

    if (is_one)
        return 0;
    if (is_hann)
        return 1;
    if (is_hamming)
        return 2;

    return -1;
}
| class torchaudio_F_spectrogram : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_0 0 1 window @data | |||
| torchaudio.functional.spectrogram op_1 2 1 input window a n_fft=%n_fft hop_length=%hop_length win_length=%win_length onesided=%onesided power=%power normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode | |||
| torch.view_as_real op_2 1 1 a out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Spectrogram"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "spectrogram"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| if (captured_params.at("power").type != 0) | |||
| return false; | |||
| const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| return window_type != -1; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| const std::string& pad_mode = captured_params.at("pad_mode").s; | |||
| int pad_type = 2; | |||
| if (pad_mode == "constant") | |||
| pad_type = 0; | |||
| if (pad_mode == "replicate") | |||
| pad_type = 1; | |||
| if (pad_mode == "reflect") | |||
| pad_type = 2; | |||
| const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; | |||
| int normalized = 0; | |||
| if (captured_params.at("normalized").type == 1) | |||
| { | |||
| normalized = captured_params.at("normalized").b ? 2 : 0; | |||
| } | |||
| if (captured_params.at("normalized").type == 4) | |||
| { | |||
| if (captured_params.at("normalized").s == "frame_length") | |||
| normalized = 1; | |||
| if (captured_params.at("normalized").s == "window") | |||
| normalized = 2; | |||
| } | |||
| op->params["0"] = captured_params.at("n_fft"); | |||
| op->params["1"] = 0; // power | |||
| op->params["2"] = captured_params.at("hop_length"); | |||
| op->params["3"] = captured_params.at("win_length"); | |||
| op->params["4"] = window_type; | |||
| op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; | |||
| op->params["6"] = pad_type; | |||
| op->params["7"] = normalized; | |||
| op->params["8"] = onesided; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram, 20) | |||
| class torchaudio_F_spectrogram_1 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_0 0 1 window @data | |||
| torchaudio.functional.spectrogram op_1 2 1 input window out n_fft=%n_fft hop_length=%hop_length win_length=%win_length onesided=%onesided power=%power normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Spectrogram"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "spectrogram"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| if (captured_params.at("power").type == 0) | |||
| return false; | |||
| const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| return window_type != -1; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data(); | |||
| const int window_type = detect_window_type(window_data); | |||
| const std::string& pad_mode = captured_params.at("pad_mode").s; | |||
| int pad_type = 2; | |||
| if (pad_mode == "constant") | |||
| pad_type = 0; | |||
| if (pad_mode == "replicate") | |||
| pad_type = 1; | |||
| if (pad_mode == "reflect") | |||
| pad_type = 2; | |||
| const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; | |||
| int normalized = 0; | |||
| if (captured_params.at("normalized").type == 1) | |||
| { | |||
| normalized = captured_params.at("normalized").b ? 2 : 0; | |||
| } | |||
| if (captured_params.at("normalized").type == 4) | |||
| { | |||
| if (captured_params.at("normalized").s == "frame_length") | |||
| normalized = 1; | |||
| if (captured_params.at("normalized").s == "window") | |||
| normalized = 2; | |||
| } | |||
| int power = 0; | |||
| if (captured_params.at("power").type == 2) | |||
| { | |||
| power = captured_params.at("power").i; | |||
| if (power != 1 && power != 2) | |||
| fprintf(stderr, "unsupported spectrogram power %d\n", power); | |||
| } | |||
| if (captured_params.at("power").type == 3) | |||
| { | |||
| if (NearlyEqual(captured_params.at("power").f, 1.0, 0.0001)) | |||
| power = 1; | |||
| else if (NearlyEqual(captured_params.at("power").f, 2.0, 0.0001)) | |||
| power = 2; | |||
| else | |||
| fprintf(stderr, "unsupported spectrogram power %f\n", captured_params.at("power").f); | |||
| } | |||
| op->params["0"] = captured_params.at("n_fft"); | |||
| op->params["1"] = power; | |||
| op->params["2"] = captured_params.at("hop_length"); | |||
| op->params["3"] = captured_params.at("win_length"); | |||
| op->params["4"] = window_type; | |||
| op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; | |||
| op->params["6"] = pad_type; | |||
| op->params["7"] = normalized; | |||
| op->params["8"] = onesided; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -360,6 +360,11 @@ if(TorchVision_FOUND) | |||
| pnnx_add_test(torchvision_RoIAlign) | |||
| endif() | |||
| pnnx_add_test(torchaudio_F_inverse_spectrogram) | |||
| pnnx_add_test(torchaudio_F_spectrogram) | |||
| pnnx_add_test(torchaudio_InverseSpectrogram) | |||
| pnnx_add_test(torchaudio_Spectrogram) | |||
| add_subdirectory(ncnn) | |||
| if(onnxruntime_FOUND) | |||
| @@ -175,6 +175,9 @@ pnnx_ncnn_add_test(torch_transpose) | |||
| pnnx_ncnn_add_test(torch_unbind) | |||
| pnnx_ncnn_add_test(torch_unsqueeze) | |||
| pnnx_ncnn_add_test(torch_istft) | |||
| pnnx_ncnn_add_test(torch_stft) | |||
| pnnx_ncnn_add_test(torch_abs) | |||
| pnnx_ncnn_add_test(torch_acos) | |||
| pnnx_ncnn_add_test(torch_asin) | |||
| @@ -217,3 +220,8 @@ pnnx_ncnn_add_test(ncnn_numpy_binaryop_broadcast) | |||
| if(TorchVision_FOUND) | |||
| pnnx_ncnn_add_test(torchvision_DeformConv2d) | |||
| endif() | |||
| pnnx_ncnn_add_test(torchaudio_F_inverse_spectrogram) | |||
| pnnx_ncnn_add_test(torchaudio_F_spectrogram) | |||
| pnnx_ncnn_add_test(torchaudio_InverseSpectrogram) | |||
| pnnx_ncnn_add_test(torchaudio_Spectrogram) | |||
| @@ -0,0 +1,68 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Run torch.istft with four parameter combinations.

    Every input is a real tensor whose last dimension holds the real and
    imaginary parts; it is viewed as complex before the inverse transform.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z, w):
        spec_x = torch.view_as_complex(x)
        spec_y = torch.view_as_complex(y)
        spec_z = torch.view_as_complex(z)
        spec_w = torch.view_as_complex(w)

        # hann window shorter than n_fft, centered and normalized
        out0 = torch.istft(spec_x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False)
        # no window, not centered
        out1 = torch.istft(spec_y, n_fft=128, center=False, onesided=True, return_complex=False)
        # hamming window with an explicit hop length
        out2 = torch.istft(spec_z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False)
        # two-sided input with complex output, returned as a real view
        out3 = torch.view_as_real(torch.istft(spec_w, n_fft=512, center=False, onesided=False, return_complex=True))

        return out0, out1, out2, out3
def test():
    """Compare eager ``Model`` outputs with the ncnn inference results.

    The model is traced and saved as TorchScript, converted by the pnnx
    executable, and then the ``test_torch_istft_ncnn`` module (expected to be
    written by that conversion) is imported and run.

    Returns:
        bool: True when every output pair matches within rtol=atol=1e-3.
        False when the converter reports failure, the number of outputs
        differs, or any pair is out of tolerance.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(33, 161, 2)
    y = torch.rand(65, 77, 2)
    z = torch.rand(257, 8, 2)
    w = torch.rand(512, 4, 2)

    a = net(x, y, z, w)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w))
    mod.save("test_torch_istft.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../../src/pnnx test_torch_istft.pt inputshape=[33,161,2],[65,77,2],[257,8,2],[512,4,2]") != 0:
        return False

    # ncnn inference
    import test_torch_istft_ncnn
    b = test_torch_istft_ncnn.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-3, 1e-3):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,65 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Run ``torch.stft`` under four different parameter combinations.

    All transforms are requested as complex tensors and then converted with
    ``torch.view_as_real`` (the file header notes that ncnn has no complex
    mat type yet).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # hann window shorter than n_fft, centered, normalized
        spec0 = torch.stft(
            x,
            n_fft=64,
            window=torch.hann_window(44),
            win_length=44,
            center=True,
            normalized=True,
            return_complex=True,
        )

        # no explicit window, not centered
        spec1 = torch.stft(
            x,
            n_fft=128,
            center=False,
            onesided=True,
            return_complex=True,
        )

        # hamming window, explicit hop length, constant padding
        spec2 = torch.stft(
            y,
            n_fft=512,
            window=torch.hamming_window(256),
            win_length=256,
            hop_length=128,
            center=True,
            pad_mode='constant',
            onesided=True,
            return_complex=True,
        )

        # two-sided spectrum
        spec3 = torch.stft(
            y,
            n_fft=512,
            center=True,
            onesided=False,
            return_complex=True,
        )

        # expose every complex result as a real tensor with a trailing pair
        real0 = torch.view_as_real(spec0)
        real1 = torch.view_as_real(spec1)
        real2 = torch.view_as_real(spec2)
        real3 = torch.view_as_real(spec3)

        return real0, real1, real2, real3
def test():
    """Compare eager ``Model`` outputs with the ncnn inference results.

    The model is traced and saved as TorchScript, converted by the pnnx
    executable, and then the ``test_torch_stft_ncnn`` module (expected to be
    written by that conversion) is imported and run.

    Returns:
        bool: True when every output pair matches within rtol=atol=1e-3.
        False when the converter reports failure, the number of outputs
        differs, or any pair is out of tolerance.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(2560)
    y = torch.rand(1000)

    a = net(x, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, y))
    mod.save("test_torch_stft.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../../src/pnnx test_torch_stft.pt inputshape=[2560],[1000]") != 0:
        return False

    # ncnn inference
    import test_torch_stft_ncnn
    b = test_torch_stft_ncnn.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-3, 1e-3):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,72 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchaudio | |||
| from packaging import version | |||
class Model(nn.Module):
    """Run ``torchaudio.functional.inverse_spectrogram`` with four setups.

    Each input is a real tensor reinterpreted as complex through
    ``torch.view_as_complex`` before being passed to the inverse transform.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z, w):
        spec_x = torch.view_as_complex(x)
        spec_y = torch.view_as_complex(y)
        spec_z = torch.view_as_complex(z)
        spec_w = torch.view_as_complex(w)

        # hann window shorter than n_fft, normalized by window
        wave0 = torchaudio.functional.inverse_spectrogram(
            spec_x,
            n_fft=64,
            window=torch.hann_window(44),
            win_length=44,
            hop_length=16,
            pad=0,
            center=True,
            normalized='window',
            length=None,
        )

        # full-length hann window with a small hop
        wave1 = torchaudio.functional.inverse_spectrogram(
            spec_y,
            n_fft=128,
            window=torch.hann_window(128),
            win_length=128,
            hop_length=3,
            pad=0,
            center=True,
            onesided=True,
            normalized=False,
            length=None,
        )

        # hamming window, normalized by frame length
        wave2 = torchaudio.functional.inverse_spectrogram(
            spec_z,
            n_fft=512,
            window=torch.hamming_window(256),
            win_length=256,
            hop_length=128,
            pad=0,
            center=True,
            onesided=True,
            normalized='frame_length',
            length=None,
        )

        # larger n_fft with a hamming window half its size
        wave3 = torchaudio.functional.inverse_spectrogram(
            spec_w,
            n_fft=1024,
            window=torch.hamming_window(512),
            win_length=512,
            hop_length=128,
            pad=0,
            center=True,
            onesided=True,
            normalized=False,
            length=None,
        )

        return wave0, wave1, wave2, wave3
def test():
    """Compare eager ``Model`` outputs with the ncnn inference results.

    The check is skipped (reported as success) for torchaudio versions older
    than 0.10.0. Otherwise the model is traced and saved as TorchScript,
    converted by the pnnx executable, and the
    ``test_torchaudio_F_inverse_spectrogram_ncnn`` module (expected to be
    written by that conversion) is imported and run.

    Returns:
        bool: True when skipped or when every output pair matches within
        rtol=atol=1e-3. False when the converter reports failure, the number
        of outputs differs, or any pair is out of tolerance.
    """
    if version.parse(torchaudio.__version__) < version.parse('0.10.0'):
        return True

    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(33, 161, 2)
    y = torch.rand(65, 77, 2)
    z = torch.rand(257, 8, 2)
    w = torch.rand(513, 4, 2)

    a = net(x, y, z, w)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w))
    mod.save("test_torchaudio_F_inverse_spectrogram.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../../src/pnnx test_torchaudio_F_inverse_spectrogram.pt inputshape=[33,161,2],[65,77,2],[257,8,2],[513,4,2]") != 0:
        return False

    # ncnn inference
    import test_torchaudio_F_inverse_spectrogram_ncnn
    b = test_torchaudio_F_inverse_spectrogram_ncnn.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-3, 1e-3):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,63 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchaudio | |||
class Model(nn.Module):
    """Run ``torchaudio.functional.spectrogram`` with four setups.

    The ``power=None`` call yields a complex tensor, which is converted with
    ``torch.view_as_real`` before being returned (the file header notes that
    ncnn has no complex mat type yet).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # magnitude spectrogram, hann window shorter than n_fft
        spec0 = torchaudio.functional.spectrogram(
            x,
            n_fft=64,
            window=torch.hann_window(44),
            win_length=44,
            hop_length=16,
            pad=0,
            center=True,
            normalized='window',
            power=1,
        )

        # complex spectrogram (power=None), not centered
        spec1 = torchaudio.functional.spectrogram(
            x,
            n_fft=128,
            window=torch.hann_window(128),
            win_length=128,
            hop_length=3,
            pad=0,
            center=False,
            onesided=True,
            normalized=False,
            power=None,
        )

        # power spectrogram with constant padding, frame-length normalization
        spec2 = torchaudio.functional.spectrogram(
            y,
            n_fft=512,
            window=torch.hamming_window(256),
            win_length=256,
            hop_length=128,
            pad=0,
            center=True,
            pad_mode='constant',
            onesided=True,
            normalized='frame_length',
            power=2,
        )

        # two-sided power spectrogram with extra signal padding
        spec3 = torchaudio.functional.spectrogram(
            y,
            n_fft=512,
            window=torch.hamming_window(512),
            win_length=512,
            hop_length=128,
            pad=32,
            center=True,
            onesided=False,
            normalized=False,
            power=2,
        )

        spec1 = torch.view_as_real(spec1)

        return spec0, spec1, spec2, spec3
def test():
    """Compare eager ``Model`` outputs with the ncnn inference results.

    The model is traced and saved as TorchScript, converted by the pnnx
    executable, and then the ``test_torchaudio_F_spectrogram_ncnn`` module
    (expected to be written by that conversion) is imported and run.

    Returns:
        bool: True when every output pair matches within rtol=atol=1e-3.
        False when the converter reports failure, the number of outputs
        differs, or any pair is out of tolerance.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(2560)
    y = torch.rand(1000)

    a = net(x, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, y))
    mod.save("test_torchaudio_F_spectrogram.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../../src/pnnx test_torchaudio_F_spectrogram.pt inputshape=[2560],[1000]") != 0:
        return False

    # ncnn inference
    import test_torchaudio_F_spectrogram_ncnn
    b = test_torchaudio_F_spectrogram_ncnn.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-3, 1e-3):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,77 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchaudio | |||
| from packaging import version | |||
class Model(nn.Module):
    """Hold four ``torchaudio.transforms.InverseSpectrogram`` modules.

    ``forward`` reinterprets each real input as complex with
    ``torch.view_as_complex`` and feeds it to the matching transform.
    """

    def __init__(self):
        super().__init__()

        inverse = torchaudio.transforms.InverseSpectrogram

        # hann window shorter than n_fft, normalized by window
        self.s0 = inverse(
            n_fft=64,
            window_fn=torch.hann_window,
            win_length=44,
            hop_length=16,
            pad=0,
            center=True,
            normalized='window',
        )
        # full-length hann window with a small hop
        self.s1 = inverse(
            n_fft=128,
            window_fn=torch.hann_window,
            win_length=128,
            hop_length=3,
            pad=0,
            center=True,
            onesided=True,
            normalized=False,
        )
        # hamming window, normalized by frame length
        self.s2 = inverse(
            n_fft=512,
            window_fn=torch.hamming_window,
            win_length=256,
            hop_length=128,
            pad=0,
            center=True,
            onesided=True,
            normalized='frame_length',
        )
        # larger n_fft with a hamming window half its size
        self.s3 = inverse(
            n_fft=1024,
            window_fn=torch.hamming_window,
            win_length=512,
            hop_length=128,
            pad=0,
            center=True,
            onesided=True,
            normalized=False,
        )

    def forward(self, x, y, z, w):
        # convert all inputs first, then apply the transforms in order
        spec_x = torch.view_as_complex(x)
        spec_y = torch.view_as_complex(y)
        spec_z = torch.view_as_complex(z)
        spec_w = torch.view_as_complex(w)

        wave0 = self.s0(spec_x)
        wave1 = self.s1(spec_y)
        wave2 = self.s2(spec_z)
        wave3 = self.s3(spec_w)

        return wave0, wave1, wave2, wave3
def test():
    """Compare eager ``Model`` outputs with the ncnn inference results.

    The check is skipped (reported as success) for torchaudio versions older
    than 0.10.0. Otherwise the model is traced and saved as TorchScript,
    converted by the pnnx executable, and the
    ``test_torchaudio_InverseSpectrogram_ncnn`` module (expected to be
    written by that conversion) is imported and run.

    Returns:
        bool: True when skipped or when every output pair matches within
        rtol=atol=1e-3. False when the converter reports failure, the number
        of outputs differs, or any pair is out of tolerance.
    """
    if version.parse(torchaudio.__version__) < version.parse('0.10.0'):
        return True

    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(33, 161, 2)
    y = torch.rand(65, 77, 2)
    z = torch.rand(257, 8, 2)
    w = torch.rand(513, 4, 2)

    a = net(x, y, z, w)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w))
    mod.save("test_torchaudio_InverseSpectrogram.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../../src/pnnx test_torchaudio_InverseSpectrogram.pt inputshape=[33,161,2],[65,77,2],[257,8,2],[513,4,2]") != 0:
        return False

    # ncnn inference
    import test_torchaudio_InverseSpectrogram_ncnn
    b = test_torchaudio_InverseSpectrogram_ncnn.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-3, 1e-3):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,68 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchaudio | |||
class Model(nn.Module):
    """Hold four ``torchaudio.transforms.Spectrogram`` modules.

    The ``power=None`` transform yields a complex tensor, which ``forward``
    converts with ``torch.view_as_real`` before returning (the file header
    notes that ncnn has no complex mat type yet).
    """

    def __init__(self):
        super().__init__()

        spectrogram = torchaudio.transforms.Spectrogram

        # magnitude spectrogram, hann window shorter than n_fft
        self.s0 = spectrogram(
            n_fft=64,
            window_fn=torch.hann_window,
            win_length=44,
            hop_length=16,
            pad=0,
            center=True,
            normalized='window',
            power=1,
        )
        # complex spectrogram (power=None), not centered
        self.s1 = spectrogram(
            n_fft=128,
            window_fn=torch.hann_window,
            win_length=128,
            hop_length=3,
            pad=0,
            center=False,
            onesided=True,
            normalized=False,
            power=None,
        )
        # power spectrogram with constant padding, frame-length normalization
        self.s2 = spectrogram(
            n_fft=512,
            window_fn=torch.hamming_window,
            win_length=256,
            hop_length=128,
            pad=0,
            center=True,
            pad_mode='constant',
            onesided=True,
            normalized='frame_length',
            power=2,
        )
        # two-sided power spectrogram with extra signal padding
        self.s3 = spectrogram(
            n_fft=512,
            window_fn=torch.hamming_window,
            win_length=512,
            hop_length=128,
            pad=32,
            center=True,
            onesided=False,
            normalized=False,
            power=2,
        )

    def forward(self, x, y):
        spec0 = self.s0(x)
        spec1 = self.s1(x)
        spec2 = self.s2(y)
        spec3 = self.s3(y)

        spec1 = torch.view_as_real(spec1)

        return spec0, spec1, spec2, spec3
def test():
    """Compare eager ``Model`` outputs with the ncnn inference results.

    The model is traced and saved as TorchScript, converted by the pnnx
    executable, and then the ``test_torchaudio_Spectrogram_ncnn`` module
    (expected to be written by that conversion) is imported and run.

    Returns:
        bool: True when every output pair matches within rtol=atol=1e-3.
        False when the converter reports failure, the number of outputs
        differs, or any pair is out of tolerance.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(2560)
    y = torch.rand(1000)

    a = net(x, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, y))
    mod.save("test_torchaudio_Spectrogram.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../../src/pnnx test_torchaudio_Spectrogram.pt inputshape=[2560],[1000]") != 0:
        return False

    # ncnn inference
    import test_torchaudio_Spectrogram_ncnn
    b = test_torchaudio_Spectrogram_ncnn.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-3, 1e-3):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)
| @@ -21,9 +21,9 @@ class Model(nn.Module): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z, w): | |||
| out0 = torch.istft(x, n_fft=64, center=True, normalized=True, return_complex=False) | |||
| out0 = torch.istft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False) | |||
| out1 = torch.istft(y, n_fft=128, center=False, onesided=True, return_complex=False) | |||
| out2 = torch.istft(z, n_fft=512, center=True, onesided=True, return_complex=False) | |||
| out2 = torch.istft(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False) | |||
| out3 = torch.istft(w, n_fft=512, center=False, onesided=False, return_complex=True) | |||
| return out0, out1, out2, out3 | |||
| @@ -52,7 +52,7 @@ def test(): | |||
| b = test_torch_istft_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| @@ -21,10 +21,10 @@ class Model(nn.Module): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y): | |||
| out0 = torch.stft(x, n_fft=64, center=True, pad_mode='reflect', normalized=True, return_complex=True) | |||
| out0 = torch.stft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=True) | |||
| out1 = torch.stft(x, n_fft=128, center=False, onesided=True, return_complex=True) | |||
| out2 = torch.stft(y, n_fft=512, center=True, pad_mode='constant', onesided=True, return_complex=True) | |||
| out3 = torch.stft(y, n_fft=512, center=False, onesided=False, return_complex=True) | |||
| out2 = torch.stft(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, pad_mode='constant', onesided=True, return_complex=True) | |||
| out3 = torch.stft(y, n_fft=512, center=True, onesided=False, return_complex=True) | |||
| return out0, out1, out2, out3 | |||
| def test(): | |||
| @@ -50,7 +50,7 @@ def test(): | |||
| b = test_torch_stft_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| @@ -0,0 +1,68 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchaudio | |||
| from packaging import version | |||
class Model(nn.Module):
    """Run ``torchaudio.functional.inverse_spectrogram`` with four setups.

    The inputs are passed to the inverse transform unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z, w):
        # hann window shorter than n_fft, normalized by window
        wave0 = torchaudio.functional.inverse_spectrogram(
            x,
            n_fft=64,
            window=torch.hann_window(44),
            win_length=44,
            hop_length=16,
            pad=0,
            center=True,
            normalized='window',
            length=None,
        )

        # full-length hann window with a small hop
        wave1 = torchaudio.functional.inverse_spectrogram(
            y,
            n_fft=128,
            window=torch.hann_window(128),
            win_length=128,
            hop_length=3,
            pad=0,
            center=True,
            onesided=True,
            normalized=False,
            length=None,
        )

        # hamming window, normalized by frame length
        wave2 = torchaudio.functional.inverse_spectrogram(
            z,
            n_fft=512,
            window=torch.hamming_window(256),
            win_length=256,
            hop_length=128,
            pad=0,
            center=True,
            onesided=True,
            normalized='frame_length',
            length=None,
        )

        # two-sided input with a full-length hamming window
        wave3 = torchaudio.functional.inverse_spectrogram(
            w,
            n_fft=512,
            window=torch.hamming_window(512),
            win_length=512,
            hop_length=128,
            pad=0,
            center=True,
            onesided=False,
            normalized=False,
            length=None,
        )

        return wave0, wave1, wave2, wave3
def test():
    """Compare eager ``Model`` outputs with the pnnx inference results.

    The check is skipped (reported as success) for torchaudio versions older
    than 0.10.0. Otherwise the model is traced and saved as TorchScript,
    converted by the pnnx executable, and the
    ``test_torchaudio_F_inverse_spectrogram_pnnx`` module (expected to be
    written by that conversion) is imported and run.

    Returns:
        bool: True when skipped or when every output pair matches within
        rtol=atol=1e-4. False when the converter reports failure, the number
        of outputs differs, or any pair is out of tolerance.
    """
    if version.parse(torchaudio.__version__) < version.parse('0.10.0'):
        return True

    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 33, 161, dtype=torch.complex64)
    y = torch.rand(1, 65, 77, dtype=torch.complex64)
    z = torch.rand(257, 8, dtype=torch.complex64)
    w = torch.rand(512, 4, dtype=torch.complex64)

    a = net(x, y, z, w)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w))
    mod.save("test_torchaudio_F_inverse_spectrogram.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../src/pnnx test_torchaudio_F_inverse_spectrogram.pt inputshape=[3,33,161]c64,[1,65,77]c64,[257,8]c64,[512,4]c64") != 0:
        return False

    # pnnx inference
    import test_torchaudio_F_inverse_spectrogram_pnnx
    b = test_torchaudio_F_inverse_spectrogram_pnnx.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-4, 1e-4):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,62 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchaudio | |||
class Model(nn.Module):
    """Run ``torchaudio.functional.spectrogram`` with four setups.

    All four results are returned as produced, including the ``power=None``
    one.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # magnitude spectrogram, hann window shorter than n_fft
        spec0 = torchaudio.functional.spectrogram(
            x,
            n_fft=64,
            window=torch.hann_window(44),
            win_length=44,
            hop_length=16,
            pad=0,
            center=True,
            normalized='window',
            power=1,
        )

        # power=None, not centered
        spec1 = torchaudio.functional.spectrogram(
            x,
            n_fft=128,
            window=torch.hann_window(128),
            win_length=128,
            hop_length=3,
            pad=0,
            center=False,
            onesided=True,
            normalized=False,
            power=None,
        )

        # power spectrogram with constant padding, frame-length normalization
        spec2 = torchaudio.functional.spectrogram(
            y,
            n_fft=512,
            window=torch.hamming_window(256),
            win_length=256,
            hop_length=128,
            pad=0,
            center=True,
            pad_mode='constant',
            onesided=True,
            normalized='frame_length',
            power=2,
        )

        # two-sided power spectrogram with extra signal padding
        spec3 = torchaudio.functional.spectrogram(
            y,
            n_fft=512,
            window=torch.hamming_window(512),
            win_length=512,
            hop_length=128,
            pad=32,
            center=True,
            onesided=False,
            normalized=False,
            power=2,
        )

        return spec0, spec1, spec2, spec3
def test():
    """Compare eager ``Model`` outputs with the pnnx inference results.

    The model is traced and saved as TorchScript, converted by the pnnx
    executable, and then the ``test_torchaudio_F_spectrogram_pnnx`` module
    (expected to be written by that conversion) is imported and run.

    Returns:
        bool: True when every output pair matches within rtol=atol=1e-4.
        False when the converter reports failure, the number of outputs
        differs, or any pair is out of tolerance.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 2560)
    y = torch.rand(1000)

    a = net(x, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, y))
    mod.save("test_torchaudio_F_spectrogram.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../src/pnnx test_torchaudio_F_spectrogram.pt inputshape=[3,2560],[1000]") != 0:
        return False

    # pnnx inference
    import test_torchaudio_F_spectrogram_pnnx
    b = test_torchaudio_F_spectrogram_pnnx.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-4, 1e-4):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,73 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchaudio | |||
| from packaging import version | |||
class Model(nn.Module):
    """Hold four ``torchaudio.transforms.InverseSpectrogram`` modules.

    ``forward`` applies one transform to each of its four inputs.
    """

    def __init__(self):
        super().__init__()

        inverse = torchaudio.transforms.InverseSpectrogram

        # hann window shorter than n_fft, normalized by window
        self.s0 = inverse(
            n_fft=64,
            window_fn=torch.hann_window,
            win_length=44,
            hop_length=16,
            pad=0,
            center=True,
            normalized='window',
        )
        # full-length hann window with a small hop
        self.s1 = inverse(
            n_fft=128,
            window_fn=torch.hann_window,
            win_length=128,
            hop_length=3,
            pad=0,
            center=True,
            onesided=True,
            normalized=False,
        )
        # hamming window, normalized by frame length
        self.s2 = inverse(
            n_fft=512,
            window_fn=torch.hamming_window,
            win_length=256,
            hop_length=128,
            pad=0,
            center=True,
            onesided=True,
            normalized='frame_length',
        )
        # two-sided input with a full-length hamming window
        self.s3 = inverse(
            n_fft=512,
            window_fn=torch.hamming_window,
            win_length=512,
            hop_length=128,
            pad=0,
            center=True,
            onesided=False,
            normalized=False,
        )

    def forward(self, x, y, z, w):
        wave0 = self.s0(x)
        wave1 = self.s1(y)
        wave2 = self.s2(z)
        wave3 = self.s3(w)

        return wave0, wave1, wave2, wave3
def test():
    """Compare eager ``Model`` outputs with the pnnx inference results.

    The check is skipped (reported as success) for torchaudio versions older
    than 0.10.0. Otherwise the model is traced and saved as TorchScript,
    converted by the pnnx executable, and the
    ``test_torchaudio_InverseSpectrogram_pnnx`` module (expected to be
    written by that conversion) is imported and run.

    Returns:
        bool: True when skipped or when every output pair matches within
        rtol=atol=1e-4. False when the converter reports failure, the number
        of outputs differs, or any pair is out of tolerance.
    """
    if version.parse(torchaudio.__version__) < version.parse('0.10.0'):
        return True

    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 33, 161, dtype=torch.complex64)
    y = torch.rand(1, 65, 77, dtype=torch.complex64)
    z = torch.rand(257, 8, dtype=torch.complex64)
    w = torch.rand(512, 4, dtype=torch.complex64)

    a = net(x, y, z, w)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w))
    mod.save("test_torchaudio_InverseSpectrogram.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../src/pnnx test_torchaudio_InverseSpectrogram.pt inputshape=[3,33,161]c64,[1,65,77]c64,[257,8]c64,[512,4]c64") != 0:
        return False

    # pnnx inference
    import test_torchaudio_InverseSpectrogram_pnnx
    b = test_torchaudio_InverseSpectrogram_pnnx.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-4, 1e-4):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,67 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchaudio | |||
class Model(nn.Module):
    """Hold four ``torchaudio.transforms.Spectrogram`` modules.

    ``forward`` applies the first two transforms to ``x`` and the last two to
    ``y`` and returns the results as produced.
    """

    def __init__(self):
        super().__init__()

        spectrogram = torchaudio.transforms.Spectrogram

        # magnitude spectrogram, hann window shorter than n_fft
        self.s0 = spectrogram(
            n_fft=64,
            window_fn=torch.hann_window,
            win_length=44,
            hop_length=16,
            pad=0,
            center=True,
            normalized='window',
            power=1,
        )
        # power=None, not centered
        self.s1 = spectrogram(
            n_fft=128,
            window_fn=torch.hann_window,
            win_length=128,
            hop_length=3,
            pad=0,
            center=False,
            onesided=True,
            normalized=False,
            power=None,
        )
        # power spectrogram with constant padding, frame-length normalization
        self.s2 = spectrogram(
            n_fft=512,
            window_fn=torch.hamming_window,
            win_length=256,
            hop_length=128,
            pad=0,
            center=True,
            pad_mode='constant',
            onesided=True,
            normalized='frame_length',
            power=2,
        )
        # two-sided power spectrogram with extra signal padding
        self.s3 = spectrogram(
            n_fft=512,
            window_fn=torch.hamming_window,
            win_length=512,
            hop_length=128,
            pad=32,
            center=True,
            onesided=False,
            normalized=False,
            power=2,
        )

    def forward(self, x, y):
        spec0 = self.s0(x)
        spec1 = self.s1(x)
        spec2 = self.s2(y)
        spec3 = self.s3(y)

        return spec0, spec1, spec2, spec3
def test():
    """Compare eager ``Model`` outputs with the pnnx inference results.

    The model is traced and saved as TorchScript, converted by the pnnx
    executable, and then the ``test_torchaudio_Spectrogram_pnnx`` module
    (expected to be written by that conversion) is imported and run.

    Returns:
        bool: True when every output pair matches within rtol=atol=1e-4.
        False when the converter reports failure, the number of outputs
        differs, or any pair is out of tolerance.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 2560)
    y = torch.rand(1000)

    a = net(x, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, y))
    mod.save("test_torchaudio_Spectrogram.pt")

    # torchscript to pnnx
    import os
    # A non-zero status means the conversion failed; stop here rather than
    # importing a missing or stale generated module.
    if os.system("../src/pnnx test_torchaudio_Spectrogram.pt inputshape=[3,2560],[1000]") != 0:
        return False

    # pnnx inference
    import test_torchaudio_Spectrogram_pnnx
    b = test_torchaudio_Spectrogram_pnnx.test_inference()

    # zip() stops at the shorter sequence, so check the counts explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-4, 1e-4):
            return False
    return True


if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    raise SystemExit(0 if test() else 1)