Browse Source

spectrogram and inverse spectrogram (#5779)

* only supports hann, hamming and all-one window
* inverse spectrogram does not support length parameter
* spectrogram always returns torch.view_as_real(out) as ncnn does not support complex typed mat yet
* inverse spectrogram always accepts torch.view_as_complex(in) as ncnn does not support complex typed mat yet
tags/20241226
nihui GitHub 1 year ago
parent
commit
0734b657d9
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
33 changed files with 3155 additions and 22 deletions
  1. +13
    -1
      .ci/pnnx.yml
  2. +51
    -0
      docs/developer-guide/operators.md
  3. +2
    -0
      src/CMakeLists.txt
  4. +238
    -0
      src/layer/inversespectrogram.cpp
  5. +45
    -0
      src/layer/inversespectrogram.h
  6. +221
    -0
      src/layer/spectrogram.cpp
  7. +47
    -0
      src/layer/spectrogram.h
  8. +2
    -0
      tests/CMakeLists.txt
  9. +56
    -0
      tests/test_inversespectrogram.cpp
  10. +58
    -0
      tests/test_spectrogram.cpp
  11. +7
    -0
      tools/pnnx/src/CMakeLists.txt
  12. +1
    -0
      tools/pnnx/src/ir.cpp
  13. +98
    -14
      tools/pnnx/src/pass_level2/torch_stft.cpp
  14. +165
    -0
      tools/pnnx/src/pass_level2/torchaudio_F_inverse_spectrogram.cpp
  15. +709
    -0
      tools/pnnx/src/pass_level2/torchaudio_F_spectrogram.cpp
  16. +203
    -0
      tools/pnnx/src/pass_ncnn/torch_istft.cpp
  17. +176
    -0
      tools/pnnx/src/pass_ncnn/torch_stft.cpp
  18. +127
    -0
      tools/pnnx/src/pass_ncnn/torchaudio_F_inverse_spectrogram.cpp
  19. +233
    -0
      tools/pnnx/src/pass_ncnn/torchaudio_F_spectrogram.cpp
  20. +5
    -0
      tools/pnnx/tests/CMakeLists.txt
  21. +8
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  22. +68
    -0
      tools/pnnx/tests/ncnn/test_torch_istft.py
  23. +65
    -0
      tools/pnnx/tests/ncnn/test_torch_stft.py
  24. +72
    -0
      tools/pnnx/tests/ncnn/test_torchaudio_F_inverse_spectrogram.py
  25. +63
    -0
      tools/pnnx/tests/ncnn/test_torchaudio_F_spectrogram.py
  26. +77
    -0
      tools/pnnx/tests/ncnn/test_torchaudio_InverseSpectrogram.py
  27. +68
    -0
      tools/pnnx/tests/ncnn/test_torchaudio_Spectrogram.py
  28. +3
    -3
      tools/pnnx/tests/test_torch_istft.py
  29. +4
    -4
      tools/pnnx/tests/test_torch_stft.py
  30. +68
    -0
      tools/pnnx/tests/test_torchaudio_F_inverse_spectrogram.py
  31. +62
    -0
      tools/pnnx/tests/test_torchaudio_F_spectrogram.py
  32. +73
    -0
      tools/pnnx/tests/test_torchaudio_InverseSpectrogram.py
  33. +67
    -0
      tools/pnnx/tests/test_torchaudio_Spectrogram.py

+ 13
- 1
.ci/pnnx.yml View File

@@ -31,39 +31,51 @@ jobs:
include:
- torch-version: 1.8.1
torchvision-version: 0.9.1
torchaudio-version: 0.8.1

- torch-version: 1.9.1
torchvision-version: 0.10.1
torchaudio-version: 0.9.1

- torch-version: 1.10.0
torchvision-version: 0.11.1
torchaudio-version: '0.10.0+cpu'

- torch-version: 1.11.0
torchvision-version: 0.12.0
torchaudio-version: '0.11.0+cpu'

- torch-version: 1.12.0
torchvision-version: 0.13.0
torchaudio-version: '0.12.0+cpu'

- torch-version: 1.13.0
torchvision-version: 0.14.0
torchaudio-version: '0.13.0+cpu'

- torch-version: 2.0.0
torchvision-version: 0.15.1
torchaudio-version: '2.0.0+cpu'

- torch-version: 2.1.0
torchvision-version: 0.16.0
torchaudio-version: '2.1.0+cpu'

- torch-version: 2.2.1
torchvision-version: 0.17.1
torchaudio-version: '2.2.1+cpu'

- torch-version: 2.3.0
torchvision-version: 0.18.0
torchaudio-version: '2.3.0+cpu'

- torch-version: 2.4.0
torchvision-version: 0.19.0
torchaudio-version: '2.4.0+cpu'

- torch-version: 2.5.0
torchvision-version: 0.20.0
torchaudio-version: '2.5.0+cpu'

runs-on:
pool-name: docker
@@ -169,7 +181,7 @@ jobs:
- name: setup-pytorch
run: |
export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}}
pip3 install --user torch==${{matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu --index-url https://download.pytorch.org/whl/cpu
pip3 install --user torch==${{matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu torchaudio==${{matrix.torchaudio-version}} --index-url https://download.pytorch.org/whl/cpu
pip3 install --user onnx
pip3 install --user onnxscript



+ 51
- 0
docs/developer-guide/operators.md View File

@@ -46,6 +46,7 @@
* [Input](#input)
* [InstanceNorm](#instancenorm)
* [Interp](#interp)
* [InverseSpectrogram](#inversespectrogram)
* [LayerNorm](#layernorm)
* [Log](#log)
* [LRN](#lrn)
@@ -81,6 +82,7 @@
* [Slice](#slice)
* [Softmax](#softmax)
* [Softplus](#softplus)
* [Spectrogram](#spectrogram)
* [Split](#split)
* [Swish](#swish)
* [TanH](#tanh)
@@ -1141,6 +1143,30 @@ Resize type:
- 2 = Bilinear
- 3 = Bicubic

# InverseSpectrogram
```
x1 = x as complex
x1 = x1 * sqrt(norm) if normalized
y = istft(x1)
y1 = unpad(y) if center

if returns == 0 return y1 as complex
if returns == 1 return y1 real
if returns == 2 return y1 imag
```

* one_blob_only

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | n_fft | int | 0 | |
| 1 | returns | int | 1 | |
| 2 | hoplen | int | n_fft / 4 | |
| 3 | winlen | int | n_fft | |
| 4 | window_type | int | 0 | 0=ones 1=hann 2=hamming |
| 5 | center | int | 1 | |
| 7 | normalized | int | 0 | 0=no 1=n_fft 2=window-l2-energy |

# LayerNorm
```
split x along outmost axis into part x0, x1 ...
@@ -1829,6 +1855,31 @@ y = log(exp(x) + 1)
* one_blob_only
* support_inplace

# Spectrogram
```
x1 = pad(x) if center
y = stft(x1)
y = y / sqrt(norm) if normalized

if power == 0 return y as real
if power == 1 return magnitude
if power == 2 return square of magnitude
```

* one_blob_only

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | n_fft | int | 0 | |
| 1 | power | int | 0 | |
| 2 | hoplen | int | n_fft / 4 | |
| 3 | winlen | int | n_fft | |
| 4 | window_type | int | 0 | 0=ones 1=hann 2=hamming |
| 5 | center | int | 1 | |
| 6 | pad_type | int | 2 | 0=CONSTANT 1=REPLICATE 2=REFLECT |
| 7 | normalized | int | 0 | 0=no 1=n_fft 2=window-l2-energy |
| 8 | onesided | int | 1 | |

# Split
```
y0, y1 ... = x


+ 2
- 0
src/CMakeLists.txt View File

@@ -167,6 +167,8 @@ ncnn_add_layer(Diag)
ncnn_add_layer(CELU)
ncnn_add_layer(Shrink)
ncnn_add_layer(RMSNorm)
ncnn_add_layer(Spectrogram)
ncnn_add_layer(InverseSpectrogram)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)


+ 238
- 0
src/layer/inversespectrogram.cpp View File

@@ -0,0 +1,238 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "inversespectrogram.h"

namespace ncnn {

InverseSpectrogram::InverseSpectrogram()
{
one_blob_only = true;
support_inplace = false;
}

int InverseSpectrogram::load_param(const ParamDict& pd)
{
n_fft = pd.get(0, 0);
returns = pd.get(1, 0);
hoplen = pd.get(2, n_fft / 4);
winlen = pd.get(3, n_fft);
window_type = pd.get(4, 0);
center = pd.get(5, 1);
normalized = pd.get(7, 0);

// assert winlen <= n_fft
// generate window
window_data.create(normalized == 2 ? n_fft + 1 : n_fft);
{
float* p = window_data;
for (int i = 0; i < (n_fft - winlen) / 2; i++)
{
*p++ = 0.f;
}
if (window_type == 0)
{
// all ones
for (int i = 0; i < winlen; i++)
{
*p++ = 1.f;
}
}
if (window_type == 1)
{
// hann window
for (int i = 0; i < winlen; i++)
{
*p++ = 0.5f * (1 - cosf(2 * 3.14159265358979323846 * i / winlen));
}
}
if (window_type == 2)
{
// hamming window
for (int i = 0; i < winlen; i++)
{
*p++ = 0.54f - 0.46f * cosf(2 * 3.14159265358979323846 * i / winlen);
}
}
for (int i = 0; i < n_fft - winlen - (n_fft - winlen) / 2; i++)
{
*p++ = 0.f;
}

// pre-calculated window norm factor
if (normalized == 2)
{
float sqsum = 0.f;
for (int i = 0; i < n_fft; i++)
{
sqsum += window_data[i] * window_data[i];
}
window_data[n_fft] = sqrt(sqsum);
}
}

return 0;
}

int InverseSpectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
// https://github.com/librosa/librosa/blob/main/librosa/core/spectrum.py#L630

// TODO custom window
// TODO output length

const int frames = bottom_blob.h;
const int freqs = bottom_blob.c;
// assert freqs == n_fft or freqs == n_fft / 2 + 1

const int onesided = freqs == n_fft / 2 + 1 ? 1 : 0;

const int outsize = center ? (frames - 1) * hoplen + (n_fft - n_fft / 2 * 2) : (frames - 1) * hoplen + n_fft;

const size_t elemsize = bottom_blob.elemsize;

if (returns == 0)
{
top_blob.create(2, outsize, elemsize, opt.blob_allocator);
}
else
{
top_blob.create(outsize, elemsize, opt.blob_allocator);
}
if (top_blob.empty())
return -100;

Mat window_sumsquare(outsize + n_fft, elemsize, opt.workspace_allocator);
if (window_sumsquare.empty())
return -100;

top_blob.fill(0.f);
window_sumsquare.fill(0.f);

for (int j = 0; j < frames; j++)
{
// collect complex
Mat sp(2, n_fft);
if (onesided == 1)
{
for (int k = 0; k < n_fft / 2 + 1; k++)
{
sp.row(k)[0] = bottom_blob.channel(k).row(j)[0];
sp.row(k)[1] = bottom_blob.channel(k).row(j)[1];
}
for (int k = n_fft / 2 + 1; k < n_fft; k++)
{
sp.row(k)[0] = bottom_blob.channel(n_fft - k).row(j)[0];
sp.row(k)[1] = -bottom_blob.channel(n_fft - k).row(j)[1];
}
}
else
{
for (int k = 0; k < n_fft; k++)
{
sp.row(k)[0] = bottom_blob.channel(k).row(j)[0];
sp.row(k)[1] = bottom_blob.channel(k).row(j)[1];
}
}

if (normalized == 1)
{
float norm = sqrt(n_fft);
for (int i = 0; i < 2 * n_fft; i++)
{
sp[i] *= norm;
}
}
if (normalized == 2)
{
float norm = window_data[n_fft];
for (int i = 0; i < 2 * n_fft; i++)
{
sp[i] *= norm;
}
}

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < n_fft; i++)
{
// inverse dft
float re = 0.f;
float im = 0.f;
for (int k = 0; k < n_fft; k++)
{
double angle = 2 * 3.14159265358979323846 * i * k / n_fft;

re += sp.row(k)[0] * cosf(angle) - sp.row(k)[1] * sinf(angle);
im += sp.row(k)[0] * sinf(angle) + sp.row(k)[1] * cosf(angle);
}

re /= n_fft;
im /= n_fft;

// apply window
re *= window_data[i];
im *= window_data[i];

int output_index = j * hoplen + i;
if (center == 1)
{
output_index -= n_fft / 2;
}
if (output_index >= 0 && output_index < outsize)
{
// square window
window_sumsquare[output_index] += window_data[i] * window_data[i];

if (returns == 0)
{
top_blob.row(output_index)[0] += re;
top_blob.row(output_index)[1] += im;
}
if (returns == 1)
{
top_blob[output_index] += re;
}
if (returns == 2)
{
top_blob[output_index] += im;
}
}
}
}

// square window norm
if (returns == 0)
{
for (int i = 0; i < outsize; i++)
{
if (window_sumsquare[i] != 0.f)
{
top_blob.row(i)[0] /= window_sumsquare[i];
top_blob.row(i)[1] /= window_sumsquare[i];
}
}
}
else
{
for (int i = 0; i < outsize; i++)
{
if (window_sumsquare[i] != 0.f)
top_blob[i] /= window_sumsquare[i];
}
}

return 0;
}

} // namespace ncnn

+ 45
- 0
src/layer/inversespectrogram.h View File

@@ -0,0 +1,45 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_INVERSESPECTROGRAM_H
#define LAYER_INVERSESPECTROGRAM_H

#include "layer.h"

namespace ncnn {

class InverseSpectrogram : public Layer
{
public:
InverseSpectrogram();

virtual int load_param(const ParamDict& pd);

virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;

public:
int n_fft;
int returns; // 0=complex 1=real 2=imag
int hoplen;
int winlen;
int window_type; // 0=ones 1=hann 2=hamming
int center;
int normalized; // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy

Mat window_data;
};

} // namespace ncnn

#endif // LAYER_INVERSESPECTROGRAM_H

+ 221
- 0
src/layer/spectrogram.cpp View File

@@ -0,0 +1,221 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "spectrogram.h"

namespace ncnn {

Spectrogram::Spectrogram()
{
one_blob_only = true;
support_inplace = false;
}

int Spectrogram::load_param(const ParamDict& pd)
{
n_fft = pd.get(0, 0);
power = pd.get(1, 0);
hoplen = pd.get(2, n_fft / 4);
winlen = pd.get(3, n_fft);
window_type = pd.get(4, 0);
center = pd.get(5, 1);
pad_type = pd.get(6, 2);
normalized = pd.get(7, 0);
onesided = pd.get(8, 1);

// assert winlen <= n_fft
// generate window
window_data.create(normalized == 2 ? n_fft + 1 : n_fft);
{
float* p = window_data;
for (int i = 0; i < (n_fft - winlen) / 2; i++)
{
*p++ = 0.f;
}
if (window_type == 0)
{
// all ones
for (int i = 0; i < winlen; i++)
{
*p++ = 1.f;
}
}
if (window_type == 1)
{
// hann window
for (int i = 0; i < winlen; i++)
{
*p++ = 0.5f * (1 - cosf(2 * 3.14159265358979323846 * i / winlen));
}
}
if (window_type == 2)
{
// hamming window
for (int i = 0; i < winlen; i++)
{
*p++ = 0.54f - 0.46f * cosf(2 * 3.14159265358979323846 * i / winlen);
}
}
for (int i = 0; i < n_fft - winlen - (n_fft - winlen) / 2; i++)
{
*p++ = 0.f;
}

// pre-calculated window norm factor
if (normalized == 2)
{
float sqsum = 0.f;
for (int i = 0; i < n_fft; i++)
{
sqsum += window_data[i] * window_data[i];
}
window_data[n_fft] = 1.f / sqrt(sqsum);
}
}

return 0;
}

int Spectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
// https://pytorch.org/audio/stable/generated/torchaudio.functional.spectrogram.html

// TODO custom window

Mat bottom_blob_bordered = bottom_blob;
if (center == 1)
{
Option opt_b = opt;
opt_b.blob_allocator = opt.workspace_allocator;
if (pad_type == 0)
copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_CONSTANT, 0.f, opt_b);
if (pad_type == 1)
copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_REPLICATE, 0.f, opt_b);
if (pad_type == 2)
copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_REFLECT, 0.f, opt_b);
}

const int size = bottom_blob_bordered.w;

// const int frames = size / hoplen + 1;
const int frames = (size - n_fft) / hoplen + 1;
const int freqs_onesided = n_fft / 2 + 1;
const int freqs = onesided ? freqs_onesided : n_fft;

const size_t elemsize = bottom_blob_bordered.elemsize;

if (power == 0)
{
top_blob.create(2, frames, freqs, elemsize, opt.blob_allocator);
}
else
{
top_blob.create(frames, freqs, elemsize, opt.blob_allocator);
}
if (top_blob.empty())
return -100;

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < freqs_onesided; i++)
{
const float* ptr = bottom_blob_bordered;
float* outptr = power == 0 ? top_blob.channel(i) : top_blob.row(i);

for (int j = 0; j < frames; j++)
{
float re = 0.f;
float im = 0.f;
for (int k = 0; k < n_fft; k++)
{
float v = ptr[k];

// apply window
v *= window_data[k];

// dft
double angle = 2 * 3.14159265358979323846 * i * k / n_fft;

re += v * cosf(angle); // + imag * sinf(angle);
im -= v * sinf(angle); // + imag * cosf(angle);
}

if (normalized == 1)
{
float norm = 1.f / sqrt(n_fft);
re *= norm;
im *= norm;
}
if (normalized == 2)
{
float norm = window_data[n_fft];
re *= norm;
im *= norm;
}

if (power == 0)
{
// complex as real
outptr[0] = re;
outptr[1] = im;
outptr += 2;
}
if (power == 1)
{
// magnitude
outptr[0] = sqrt(re * re + im * im);
outptr += 1;
}
if (power == 2)
{
outptr[0] = re * re + im * im;
outptr += 1;
}

ptr += hoplen;
}
}

if (!onesided)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int i = freqs_onesided; i < n_fft; i++)
{
if (power == 0)
{
const float* ptr = top_blob.channel(n_fft - i);
float* outptr = top_blob.channel(i);

for (int j = 0; j < frames; j++)
{
// complex as real
outptr[0] = ptr[0];
outptr[1] = -ptr[1];
ptr += 2;
outptr += 2;
}
}
else // if (power == 1 || power == 2)
{
const float* ptr = top_blob.row(n_fft - i);
float* outptr = top_blob.row(i);

memcpy(outptr, ptr, frames * sizeof(float));
}
}
}

return 0;
}

} // namespace ncnn

+ 47
- 0
src/layer/spectrogram.h View File

@@ -0,0 +1,47 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_SPECTROGRAM_H
#define LAYER_SPECTROGRAM_H

#include "layer.h"

namespace ncnn {

class Spectrogram : public Layer
{
public:
Spectrogram();

virtual int load_param(const ParamDict& pd);

virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;

public:
int n_fft;
int power;
int hoplen;
int winlen;
int window_type; // 0=ones 1=hann 2=hamming
int center;
int pad_type; // 0=CONSTANT 1=REPLICATE 2=REFLECT
int normalized; // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy
int onesided;

Mat window_data;
};

} // namespace ncnn

#endif // LAYER_SPECTROGRAM_H

+ 2
- 0
tests/CMakeLists.txt View File

@@ -117,6 +117,7 @@ ncnn_add_layer_test(HardSwish)
ncnn_add_layer_test(InnerProduct)
ncnn_add_layer_test(InstanceNorm)
ncnn_add_layer_test(Interp)
ncnn_add_layer_test(InverseSpectrogram)
ncnn_add_layer_test(LayerNorm)
ncnn_add_layer_test(LRN)
ncnn_add_layer_test(LSTM)
@@ -154,6 +155,7 @@ ncnn_add_layer_test(Sigmoid)
ncnn_add_layer_test(Slice)
ncnn_add_layer_test(Softmax)
ncnn_add_layer_test(Softplus)
ncnn_add_layer_test(Spectrogram)
ncnn_add_layer_test(Squeeze)
ncnn_add_layer_test(Swish)
ncnn_add_layer_test(TanH)


+ 56
- 0
tests/test_inversespectrogram.cpp View File

@@ -0,0 +1,56 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "testutil.h"

static int test_inversespectrogram(int frames, int freqs, int n_fft, int returns, int hoplen, int winlen, int window_type, int center, int normalized)
{
ncnn::Mat a = RandomMat(2, frames, freqs);

ncnn::ParamDict pd;
pd.set(0, n_fft);
pd.set(1, returns);
pd.set(2, hoplen);
pd.set(3, winlen);
pd.set(4, window_type);
pd.set(5, center);
pd.set(7, normalized);

std::vector<ncnn::Mat> weights(0);

int ret = test_layer("InverseSpectrogram", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_inversespectrogram failed frames=%d freqs=%d n_fft=%d returns=%d hoplen=%d winlen=%d window_type=%d center=%d normalized=%d\n", frames, freqs, n_fft, returns, hoplen, winlen, window_type, center, normalized);
}

return ret;
}

static int test_inversespectrogram_0()
{
return 0
|| test_inversespectrogram(17, 1, 1, 0, 1, 1, 0, 1, 0)
|| test_inversespectrogram(39, 9, 17, 0, 7, 15, 0, 0, 1)
|| test_inversespectrogram(128, 6, 10, 0, 2, 7, 1, 1, 1)
|| test_inversespectrogram(255, 17, 17, 1, 14, 17, 2, 0, 0)
|| test_inversespectrogram(124, 28, 55, 2, 12, 55, 1, 1, 2);
}

int main()
{
SRAND(7767517);

return test_inversespectrogram_0();
}

+ 58
- 0
tests/test_spectrogram.cpp View File

@@ -0,0 +1,58 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "testutil.h"

static int test_spectrogram(int size, int n_fft, int power, int hoplen, int winlen, int window_type, int center, int pad_type, int normalized, int onesided)
{
ncnn::Mat a = RandomMat(size);

ncnn::ParamDict pd;
pd.set(0, n_fft);
pd.set(1, power);
pd.set(2, hoplen);
pd.set(3, winlen);
pd.set(4, window_type);
pd.set(5, center);
pd.set(6, pad_type);
pd.set(7, normalized);
pd.set(8, onesided);

std::vector<ncnn::Mat> weights(0);

int ret = test_layer("Spectrogram", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_spectrogram failed size=%d n_fft=%d power=%d hoplen=%d winlen=%d window_type=%d center=%d pad_type=%d normalized=%d onesided=%d\n", size, n_fft, power, hoplen, winlen, window_type, center, pad_type, normalized, onesided);
}

return ret;
}

static int test_spectrogram_0()
{
return 0
|| test_spectrogram(17, 1, 0, 1, 1, 0, 1, 0, 0, 0)
|| test_spectrogram(39, 17, 0, 7, 15, 0, 0, 0, 1, 0)
|| test_spectrogram(128, 10, 0, 2, 7, 1, 1, 1, 1, 1)
|| test_spectrogram(255, 17, 1, 14, 17, 2, 0, 0, 0, 1)
|| test_spectrogram(124, 55, 2, 12, 55, 1, 1, 2, 2, 0);
}

int main()
{
SRAND(7767517);

return test_spectrogram_0();
}

+ 7
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -306,6 +306,9 @@ set(pnnx_pass_level2_SRCS

pass_level2/nn_quantized_FloatFunctional.cpp

pass_level2/torchaudio_F_inverse_spectrogram.cpp
pass_level2/torchaudio_F_spectrogram.cpp

pass_level2/nn_GRU.cpp
pass_level2/nn_LSTM.cpp
pass_level2/nn_RNN.cpp
@@ -570,6 +573,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/torch_cumsum.cpp
pass_ncnn/torch_diag.cpp
pass_ncnn/torch_flatten.cpp
pass_ncnn/torch_istft.cpp
pass_ncnn/torch_logsumexp.cpp
pass_ncnn/torch_matmul.cpp
pass_ncnn/torch_max.cpp
@@ -582,9 +586,12 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/torch_slice_scatter.cpp
pass_ncnn/torch_squeeze.cpp
pass_ncnn/torch_sum.cpp
pass_ncnn/torch_stft.cpp
pass_ncnn/torch_t.cpp
pass_ncnn/torch_transpose.cpp
pass_ncnn/torch_unsqueeze.cpp
pass_ncnn/torchaudio_F_inverse_spectrogram.cpp
pass_ncnn/torchaudio_F_spectrogram.cpp
pass_ncnn/torchvision_DeformConv2d.cpp
)



+ 1
- 0
tools/pnnx/src/ir.cpp View File

@@ -1458,6 +1458,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
fprintf(pyfp, "import torch.nn.functional as F\n");
fprintf(pyfp, "try:\n");
fprintf(pyfp, " import torchvision\n");
fprintf(pyfp, " import torchaudio\n");
fprintf(pyfp, "except:\n");
fprintf(pyfp, " pass\n");



+ 98
- 14
tools/pnnx/src/pass_level2/torch_stft.cpp View File

@@ -43,6 +43,7 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
{
op->params["pad_mode"] = "reflect";
op->params["center"] = false;
}
};
@@ -55,7 +56,7 @@ public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
37 36
36 35
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
@@ -79,19 +80,18 @@ prim::Constant op_11 0 1 21 value=%pad_left
prim::Constant op_12 0 1 63 value=%pad_right
prim::ListConstruct op_13 2 1 21 63 22
prim::Constant op_14 0 1 23 value=%pad_mode
prim::Constant op_15 0 1 24 value=None
aten::pad op_16 4 1 a 22 23 24 b
prim::Constant op_17 0 1 64 value=1
aten::size op_18 2 1 b 64 27
prim::NumToTensor op_19 1 1 27 28
aten::Int op_20 1 1 28 31
prim::Constant op_21 0 1 33 value=2
aten::size op_22 2 1 b 33 34
prim::NumToTensor op_23 1 1 34 35
aten::Int op_24 1 1 35 40
prim::ListConstruct op_25 2 1 31 40 41
aten::view op_26 2 1 b 41 c
aten::stft op_27 8 1 c n_fft hop_length win_length window normalized onesided return_complex out
F.pad op_15 3 1 a 22 23 b
prim::Constant op_16 0 1 64 value=1
aten::size op_17 2 1 b 64 27
prim::NumToTensor op_18 1 1 27 28
aten::Int op_29 1 1 28 31
prim::Constant op_20 0 1 33 value=2
aten::size op_21 2 1 b 33 34
prim::NumToTensor op_22 1 1 34 35
aten::Int op_23 1 1 35 40
prim::ListConstruct op_24 2 1 31 40 41
aten::view op_25 2 1 b 41 c
aten::stft op_26 8 1 c n_fft hop_length win_length window normalized onesided return_complex out
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -110,4 +110,88 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_1, 19)

class torch_stft_2 : public torch_stft_1
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
29 28
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 normalized
pnnx.Input input_6 0 1 onesided
pnnx.Input input_7 0 1 return_complex
prim::Constant op_0 0 1 11 value=0
aten::size op_1 2 1 input 11 12
prim::NumToTensor op_2 1 1 12 13
aten::Int op_3 1 1 13 18
prim::Constant op_4 0 1 15 value=1
prim::Constant op_5 0 1 121 value=1
prim::ListConstruct op_6 3 1 15 121 18 19
aten::view op_7 2 1 input 19 a
prim::Constant op_8 0 1 22 value=%pad_left
prim::Constant op_9 0 1 122 value=%pad_right
prim::ListConstruct op_10 2 1 22 122 23
prim::Constant op_11 0 1 24 value=%pad_mode
F.pad op_12 3 1 a 23 24 b
prim::Constant op_13 0 1 28 value=2
aten::size op_14 2 1 b 28 29
prim::NumToTensor op_15 1 1 29 30
aten::Int op_16 1 1 30 34
prim::ListConstruct op_17 1 1 34 35
aten::view op_18 2 1 b 35 c
aten::stft op_19 8 1 c n_fft hop_length win_length window normalized onesided return_complex out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_2, 19)

class torch_stft_3 : public torch_stft_1
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
29 28
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 normalized
pnnx.Input input_6 0 1 onesided
pnnx.Input input_7 0 1 return_complex
prim::Constant op_0 0 1 11 value=0
aten::size op_1 2 1 input 11 12
prim::NumToTensor op_2 1 1 12 13
aten::Int op_3 1 1 13 18
prim::Constant op_4 0 1 15 value=1
prim::Constant op_5 0 1 121 value=1
prim::ListConstruct op_6 3 1 15 121 18 19
aten::view op_7 2 1 input 19 a
prim::Constant op_8 0 1 22 value=%pad_left
prim::Constant op_9 0 1 122 value=%pad_right
prim::ListConstruct op_10 2 1 22 122 23
prim::Constant op_11 0 1 24 value=None
F.pad op_12 3 1 a 23 24 b mode=%pad_mode
prim::Constant op_13 0 1 28 value=2
aten::size op_14 2 1 b 28 29
prim::NumToTensor op_15 1 1 29 30
aten::Int op_16 1 1 30 34
prim::ListConstruct op_17 1 1 34 35
aten::view op_18 2 1 b 35 c
aten::stft op_19 8 1 c n_fft hop_length win_length window normalized onesided return_complex out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_3, 19)

} // namespace pnnx

+ 165
- 0
tools/pnnx/src/pass_level2/torchaudio_F_inverse_spectrogram.cpp View File

@@ -0,0 +1,165 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torchaudio_F_inverse_spectrogram : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
29 28
pnnx.Input input_0 0 1 spectrogram
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 center
pnnx.Input input_6 0 1 onesided
prim::Constant op_0 0 1 13 value=0
aten::size op_1 2 1 spectrogram 13 14
prim::NumToTensor op_2 1 1 14 15
aten::Int op_3 1 1 15 18
prim::Constant op_4 0 1 20 value=1
aten::size op_5 2 1 spectrogram 20 21
prim::NumToTensor op_6 1 1 21 22
aten::Int op_7 1 1 22 28
prim::Constant op_8 0 1 24 value=-1
prim::ListConstruct op_9 3 1 24 18 28 29
aten::reshape op_10 2 1 spectrogram 29 spectrogram.1
prim::Constant op_11 0 1 normalized value=%normalized
prim::Constant op_12 0 1 length value=None
prim::Constant op_13 0 1 return_complex value=False
aten::istft op_14 10 1 spectrogram.1 n_fft hop_length win_length window center normalized onesided length return_complex waveform.1
prim::Constant op_15 0 1 75 value=1
aten::size op_16 2 1 waveform.1 75 42
prim::NumToTensor op_17 1 1 42 43
aten::Int op_18 1 1 43 47
prim::ListConstruct op_19 1 1 47 48
aten::reshape op_20 2 1 waveform.1 48 out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torchaudio.functional.inverse_spectrogram";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["length"] = Parameter();
op->params["pad"] = 0;
if (captured_params.at("normalized").b)
{
op->params["normalized"] = "frame_length";
}
else
{
op->params["normalized"] = false;
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram, 6)

class torchaudio_F_inverse_spectrogram_0 : public torchaudio_F_inverse_spectrogram
{
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
33 32
pnnx.Input input_0 0 1 spectrogram
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 center
pnnx.Input input_6 0 1 onesided
prim::Constant op_0 0 1 13 value=0
aten::size op_1 2 1 spectrogram 13 14
prim::NumToTensor op_2 1 1 14 15
aten::Int op_3 1 1 15 18
prim::Constant op_4 0 1 20 value=1
aten::size op_5 2 1 spectrogram 20 21
prim::NumToTensor op_6 1 1 21 22
aten::Int op_7 1 1 22 25
prim::Constant op_8 0 1 27 value=2
aten::size op_9 2 1 spectrogram 27 28
prim::NumToTensor op_10 1 1 28 29
aten::Int op_11 1 1 29 35
prim::Constant op_12 0 1 31 value=-1
prim::ListConstruct op_13 3 1 31 25 35 36
aten::reshape op_14 2 1 spectrogram 36 spectrogram.1
prim::Constant op_15 0 1 normalized value=%normalized
prim::Constant op_16 0 1 length value=None
prim::Constant op_17 0 1 return_complex value=False
aten::istft op_18 10 1 spectrogram.1 n_fft hop_length win_length window center normalized onesided length return_complex waveform.1
prim::Constant op_19 0 1 83 value=1
aten::size op_20 2 1 waveform.1 83 49
prim::NumToTensor op_21 1 1 49 50
aten::Int op_22 1 1 50 55
prim::ListConstruct op_23 2 1 18 55 56
aten::reshape op_24 2 1 waveform.1 56 out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram_0, 6)

class torchaudio_F_inverse_spectrogram_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
15 14
pnnx.Input input_0 0 1 spectrogram
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 center
pnnx.Input input_6 0 1 onesided
prim::Constant op_0 0 1 13 value=2.000000e+00
aten::pow op_1 2 1 window 13 14
prim::Constant op_2 0 1 87 value=None
aten::sum op_3 2 1 14 87 16
aten::sqrt op_4 1 1 16 17
aten::mul op_5 2 1 spectrogram 17 spectrogram.1
torchaudio.functional.inverse_spectrogram op_6 7 1 spectrogram.1 n_fft hop_length win_length window center onesided out normalized=False length=%length pad=%pad
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torchaudio.functional.inverse_spectrogram";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["length"] = captured_params.at("length");
op->params["pad"] = captured_params.at("pad");
op->params["normalized"] = "window";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram_1, 7)

} // namespace pnnx

+ 709
- 0
tools/pnnx/src/pass_level2/torchaudio_F_spectrogram.cpp View File

@@ -0,0 +1,709 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torchaudio_F_spectrogram : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
27 26
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
prim::Constant op_0 0 1 11 value=0
aten::size op_1 2 1 waveform 11 12
prim::NumToTensor op_2 1 1 12 13
aten::Int op_3 1 1 13 18
prim::Constant op_4 0 1 15 value=-1
prim::ListConstruct op_5 2 1 15 18 19
aten::reshape op_6 2 1 waveform 19 waveform.1
prim::Constant op_7 0 1 normalized value=%normalized
prim::Constant op_8 0 1 return_complex value=True
aten::stft op_9 8 1 waveform.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1
prim::Constant op_10 0 1 29 value=1
aten::size op_11 2 1 spec_f.1 29 30
prim::NumToTensor op_12 1 1 30 31
aten::Int op_13 1 1 31 34
prim::Constant op_14 0 1 36 value=2
aten::size op_15 2 1 spec_f.1 36 37
prim::NumToTensor op_16 1 1 37 38
aten::Int op_17 1 1 38 43
prim::ListConstruct op_18 2 1 34 43 44
aten::reshape op_19 2 1 spec_f.1 44 out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torchaudio.functional.spectrogram";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["pad"] = 0;
op->params["pad_mode"] = "reflect";
op->params["center"] = false;
op->params["power"] = Parameter();
if (captured_params.at("normalized").b)
{
op->params["normalized"] = "frame_length";
}
else
{
op->params["normalized"] = false;
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram, 6)

class torchaudio_F_spectrogram_0 : public torchaudio_F_spectrogram
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
31 30
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
prim::Constant op_0 0 1 11 value=0
aten::size op_1 2 1 waveform 11 12
prim::NumToTensor op_2 1 1 12 13
aten::Int op_3 1 1 13 16
prim::Constant op_4 0 1 18 value=1
aten::size op_5 2 1 waveform 18 19
prim::NumToTensor op_6 1 1 19 20
aten::Int op_7 1 1 20 25
prim::Constant op_8 0 1 22 value=-1
prim::ListConstruct op_9 2 1 22 25 26
aten::reshape op_10 2 1 waveform 26 waveform.1
prim::Constant op_11 0 1 normalized value=%normalized
prim::Constant op_12 0 1 return_complex value=True
aten::stft op_13 8 1 waveform.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1
prim::Constant op_14 0 1 72 value=1
aten::size op_15 2 1 spec_f.1 72 36
prim::NumToTensor op_16 1 1 36 37
aten::Int op_17 1 1 37 40
prim::Constant op_18 0 1 42 value=2
aten::size op_19 2 1 spec_f.1 42 43
prim::NumToTensor op_20 1 1 43 44
aten::Int op_21 1 1 44 50
prim::ListConstruct op_22 3 1 16 40 50 51
aten::reshape op_23 2 1 spec_f.1 51 out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_0, 6)

class torchaudio_F_spectrogram_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
58 57
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
prim::Constant op_0 0 1 18 value=1
aten::size op_1 2 1 waveform 18 19
prim::NumToTensor op_2 1 1 19 20
aten::Int op_3 1 1 20 25
prim::Constant op_4 0 1 22 value=-1
prim::ListConstruct op_5 2 1 22 25 26
aten::reshape op_6 2 1 waveform 26 waveform.1
prim::Constant op_7 0 1 106 value=0
aten::size op_8 2 1 waveform.1 106 29
prim::NumToTensor op_9 1 1 29 30
aten::Int op_10 1 1 30 33
prim::Constant op_11 0 1 107 value=1
aten::size op_12 2 1 waveform.1 107 35
prim::NumToTensor op_13 1 1 35 36
aten::Int op_14 1 1 36 41
prim::Constant op_15 0 1 108 value=1
prim::ListConstruct op_16 3 1 108 33 41 42
aten::view op_17 2 1 waveform.1 42 input0.1
prim::Constant op_18 0 1 45 value=%pad_left
prim::Constant op_19 0 1 109 value=%pad_right
prim::ListConstruct op_20 2 1 45 109 46
prim::Constant op_21 0 1 47 value=%pad_mode
prim::Constant op_22 0 1 110 value=None
aten::pad op_23 4 1 input0.1 46 47 110 input1.1
prim::Constant op_24 0 1 111 value=1
aten::size op_25 2 1 input1.1 111 51
prim::NumToTensor op_26 1 1 51 52
aten::Int op_27 1 1 52 55
prim::Constant op_28 0 1 57 value=2
aten::size op_29 2 1 input1.1 57 58
prim::NumToTensor op_30 1 1 58 59
aten::Int op_31 1 1 59 64
prim::ListConstruct op_32 2 1 55 64 65
aten::view op_33 2 1 input1.1 65 input2.1
prim::Constant op_34 0 1 normalized value=%normalized
prim::Constant op_35 0 1 return_complex value=True
aten::stft op_36 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1
prim::Constant op_37 0 1 11 value=0
aten::size op_38 2 1 waveform 11 12
prim::NumToTensor op_39 1 1 12 13
aten::Int op_40 1 1 13 16
prim::Constant op_41 0 1 116 value=1
aten::size op_42 2 1 spec_f.1 116 75
prim::NumToTensor op_43 1 1 75 76
aten::Int op_44 1 1 76 79
prim::Constant op_45 0 1 117 value=2
aten::size op_46 2 1 spec_f.1 117 81
prim::NumToTensor op_47 1 1 81 82
aten::Int op_48 1 1 82 88
prim::ListConstruct op_49 3 1 16 79 88 89
aten::reshape op_50 2 1 spec_f.1 89 out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torchaudio.functional.spectrogram";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["pad"] = 0;
op->params["pad_mode"] = captured_params.at("pad_mode");
op->params["center"] = true;
op->params["power"] = Parameter();
if (captured_params.at("normalized").b)
{
op->params["normalized"] = "frame_length";
}
else
{
op->params["normalized"] = false;
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1, 6)

class torchaudio_F_spectrogram_1_1 : public torchaudio_F_spectrogram_1
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
63 62
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
prim::Constant op_0 0 1 11 value=0
aten::size op_1 2 1 waveform 11 12
prim::NumToTensor op_2 1 1 12 13
aten::Int op_3 1 1 13 18
prim::Constant op_4 0 1 15 value=-1
prim::ListConstruct op_5 2 1 15 18 19
aten::reshape op_6 2 1 waveform 19 waveform.1
prim::Constant op_7 0 1 108 value=0
aten::size op_8 2 1 waveform.1 108 22
prim::NumToTensor op_9 1 1 22 23
aten::Int op_10 1 1 23 26
prim::Constant op_11 0 1 28 value=1
aten::size op_12 2 1 waveform.1 28 29
prim::NumToTensor op_13 1 1 29 30
aten::Int op_14 1 1 30 35
prim::Constant op_15 0 1 109 value=1
prim::ListConstruct op_16 3 1 109 26 35 36
aten::view op_17 2 1 waveform.1 36 input0.1
prim::Constant op_18 0 1 39 value=%pad_left
prim::Constant op_19 0 1 110 value=%pad_right
prim::ListConstruct op_20 2 1 39 110 40
prim::Constant op_21 0 1 41 value=%pad_mode
prim::Constant op_22 0 1 111 value=None
aten::pad op_23 4 1 input0.1 40 41 111 input1.1
prim::Constant op_24 0 1 112 value=1
aten::size op_25 2 1 input1.1 112 45
prim::NumToTensor op_26 1 1 45 46
aten::Int op_27 1 1 46 49
prim::Constant op_28 0 1 51 value=2
aten::size op_29 2 1 input1.1 51 52
prim::NumToTensor op_30 1 1 52 53
aten::Int op_31 1 1 53 58
prim::ListConstruct op_32 2 1 49 58 59
aten::view op_33 2 1 input1.1 59 input2.1
prim::Constant op_34 0 1 normalized value=%normalized
prim::Constant op_35 0 1 return_complex value=True
aten::stft op_36 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1
prim::Constant op_37 0 1 117 value=1
aten::size op_38 2 1 spec_f.1 117 69
prim::NumToTensor op_39 1 1 69 70
aten::Int op_40 1 1 70 73
prim::Constant op_50 0 1 118 value=2
aten::size op_51 2 1 spec_f.1 118 75
prim::NumToTensor op_52 1 1 75 76
aten::Int op_53 1 1 76 81
prim::ListConstruct op_54 2 1 73 81 82
aten::reshape op_55 2 1 spec_f.1 82 out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_1, 6)

class torchaudio_F_spectrogram_1_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
52 51
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
prim::Constant op_0 0 1 211 value=0
aten::size op_1 2 1 waveform 211 107
prim::NumToTensor op_2 1 1 107 108
aten::Int op_3 1 1 108 112
prim::Constant op_4 0 1 212 value=-1
prim::ListConstruct op_5 2 1 212 112 113
aten::reshape op_6 2 1 waveform 113 input3.1
prim::Constant op_7 0 1 213 value=0
aten::size op_8 2 1 input3.1 213 116
prim::NumToTensor op_9 1 1 116 117
aten::Int op_10 1 1 117 120
prim::Constant op_11 0 1 214 value=1
aten::size op_12 2 1 input3.1 214 122
prim::NumToTensor op_13 1 1 122 123
aten::Int op_14 1 1 123 128
prim::Constant op_15 0 1 215 value=1
prim::ListConstruct op_16 3 1 215 120 128 129
aten::view op_17 2 1 input3.1 129 input4.1
prim::Constant op_18 0 1 216 value=%pad_left
prim::Constant op_19 0 1 217 value=%pad_right
prim::ListConstruct op_20 2 1 216 217 132
aten::reflection_pad1d op_21 2 1 input4.1 132 input5.1
prim::Constant op_22 0 1 218 value=1
aten::size op_23 2 1 input5.1 218 135
prim::NumToTensor op_24 1 1 135 136
aten::Int op_25 1 1 136 139
prim::Constant op_26 0 1 219 value=2
aten::size op_27 2 1 input5.1 219 141
prim::NumToTensor op_28 1 1 141 142
aten::Int op_29 1 1 142 147
prim::ListConstruct op_30 2 1 139 147 148
aten::view op_31 2 1 input5.1 148 input6.1
prim::Constant op_32 0 1 normalized value=%normalized
prim::Constant op_33 0 1 return_complex value=True
aten::stft op_34 8 1 input6.1 n_fft hop_length win_length window normalized onesided return_complex spec_f2.1
prim::Constant op_35 0 1 226 value=1
aten::size op_36 2 1 spec_f2.1 226 157
prim::NumToTensor op_37 1 1 157 158
aten::Int op_38 1 1 158 161
prim::Constant op_39 0 1 227 value=2
aten::size op_40 2 1 spec_f2.1 227 163
prim::NumToTensor op_41 1 1 163 164
aten::Int op_42 1 1 164 169
prim::ListConstruct op_43 2 1 161 169 170
aten::reshape op_44 2 1 spec_f2.1 170 out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torchaudio.functional.spectrogram";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["pad"] = 0;
op->params["pad_mode"] = "reflect";
op->params["center"] = true;
op->params["power"] = Parameter();
if (captured_params.at("normalized").b)
{
op->params["normalized"] = "frame_length";
}
else
{
op->params["normalized"] = false;
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_2, 6)

class torchaudio_F_spectrogram_1_3 : public torchaudio_F_spectrogram_1_2
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
56 55
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
prim::Constant op_0 0 1 11 value=0
aten::size op_1 2 1 waveform 11 12
prim::NumToTensor op_2 1 1 12 13
aten::Int op_3 1 1 13 16
prim::Constant op_4 0 1 18 value=1
aten::size op_5 2 1 waveform 18 19
prim::NumToTensor op_6 1 1 19 20
aten::Int op_7 1 1 20 25
prim::Constant op_8 0 1 22 value=-1
prim::ListConstruct op_9 2 1 22 25 26
aten::reshape op_10 2 1 waveform 26 input.1
prim::Constant op_11 0 1 326 value=0
aten::size op_12 2 1 input.1 326 29
prim::NumToTensor op_13 1 1 29 30
aten::Int op_14 1 1 30 33
prim::Constant op_15 0 1 327 value=1
aten::size op_16 2 1 input.1 327 35
prim::NumToTensor op_17 1 1 35 36
aten::Int op_18 1 1 36 41
prim::Constant op_19 0 1 328 value=1
prim::ListConstruct op_20 3 1 328 33 41 42
aten::view op_21 2 1 input.1 42 input0.1
prim::Constant op_22 0 1 45 value=%pad_left
prim::Constant op_23 0 1 329 value=%pad_right
prim::ListConstruct op_24 2 1 45 329 46
aten::reflection_pad1d op_25 2 1 input0.1 46 input1.1
prim::Constant op_26 0 1 330 value=1
aten::size op_27 2 1 input1.1 330 49
prim::NumToTensor op_28 1 1 49 50
aten::Int op_29 1 1 50 53
prim::Constant op_30 0 1 55 value=2
aten::size op_31 2 1 input1.1 55 56
prim::NumToTensor op_32 1 1 56 57
aten::Int op_33 1 1 57 62
prim::ListConstruct op_34 2 1 53 62 63
aten::view op_35 2 1 input1.1 63 input2.1
prim::Constant op_36 0 1 normalized value=%normalized
prim::Constant op_37 0 1 return_complex value=True
aten::stft op_38 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1
prim::Constant op_39 0 1 334 value=1
aten::size op_40 2 1 spec_f.1 334 74
prim::NumToTensor op_41 1 1 74 75
aten::Int op_42 1 1 75 78
prim::Constant op_43 0 1 335 value=2
aten::size op_44 2 1 spec_f.1 335 80
prim::NumToTensor op_45 1 1 80 81
aten::Int op_46 1 1 81 87
prim::ListConstruct op_47 3 1 16 78 87 88
aten::reshape op_48 2 1 spec_f.1 88 out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_3, 6)

class torchaudio_F_spectrogram_1_4 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
53 52
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
prim::Constant op_0 0 1 211 value=0
aten::size op_1 2 1 waveform 211 107
prim::NumToTensor op_2 1 1 107 108
aten::Int op_3 1 1 108 112
prim::Constant op_4 0 1 212 value=-1
prim::ListConstruct op_5 2 1 212 112 113
aten::reshape op_6 2 1 waveform 113 input3.1
prim::Constant op_7 0 1 213 value=0
aten::size op_8 2 1 input3.1 213 116
prim::NumToTensor op_9 1 1 116 117
aten::Int op_10 1 1 117 120
prim::Constant op_11 0 1 214 value=1
aten::size op_12 2 1 input3.1 214 122
prim::NumToTensor op_13 1 1 122 123
aten::Int op_14 1 1 123 128
prim::Constant op_15 0 1 215 value=1
prim::ListConstruct op_16 3 1 215 120 128 129
aten::view op_17 2 1 input3.1 129 input4.1
prim::Constant op_18 0 1 216 value=%pad_left
prim::Constant op_19 0 1 217 value=%pad_right
prim::ListConstruct op_20 2 1 216 217 132
prim::Constant op_21 0 1 46 value=0.000000e+00
aten::constant_pad_nd op_22 3 1 input4.1 132 46 input5.1
prim::Constant op_23 0 1 218 value=1
aten::size op_24 2 1 input5.1 218 135
prim::NumToTensor op_25 1 1 135 136
aten::Int op_26 1 1 136 139
prim::Constant op_27 0 1 219 value=2
aten::size op_28 2 1 input5.1 219 141
prim::NumToTensor op_29 1 1 141 142
aten::Int op_30 1 1 142 147
prim::ListConstruct op_31 2 1 139 147 148
aten::view op_32 2 1 input5.1 148 input6.1
prim::Constant op_33 0 1 normalized value=%normalized
prim::Constant op_34 0 1 return_complex value=True
aten::stft op_35 8 1 input6.1 n_fft hop_length win_length window normalized onesided return_complex spec_f2.1
prim::Constant op_36 0 1 226 value=1
aten::size op_37 2 1 spec_f2.1 226 157
prim::NumToTensor op_38 1 1 157 158
aten::Int op_39 1 1 158 161
prim::Constant op_40 0 1 227 value=2
aten::size op_41 2 1 spec_f2.1 227 163
prim::NumToTensor op_42 1 1 163 164
aten::Int op_43 1 1 164 169
prim::ListConstruct op_44 2 1 161 169 170
aten::reshape op_45 2 1 spec_f2.1 170 out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torchaudio.functional.spectrogram";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["pad"] = 0;
op->params["pad_mode"] = "constant";
op->params["center"] = true;
op->params["power"] = Parameter();
if (captured_params.at("normalized").b)
{
op->params["normalized"] = "frame_length";
}
else
{
op->params["normalized"] = false;
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_4, 6)

class torchaudio_F_spectrogram_1_5 : public torchaudio_F_spectrogram_1_4
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
57 56
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
prim::Constant op_0 0 1 11 value=0
aten::size op_1 2 1 waveform 11 12
prim::NumToTensor op_2 1 1 12 13
aten::Int op_3 1 1 13 16
prim::Constant op_4 0 1 18 value=1
aten::size op_5 2 1 waveform 18 19
prim::NumToTensor op_6 1 1 19 20
aten::Int op_7 1 1 20 25
prim::Constant op_8 0 1 22 value=-1
prim::ListConstruct op_9 2 1 22 25 26
aten::reshape op_10 2 1 waveform 26 input.1
prim::Constant op_11 0 1 326 value=0
aten::size op_12 2 1 input.1 326 29
prim::NumToTensor op_13 1 1 29 30
aten::Int op_14 1 1 30 33
prim::Constant op_15 0 1 327 value=1
aten::size op_16 2 1 input.1 327 35
prim::NumToTensor op_17 1 1 35 36
aten::Int op_18 1 1 36 41
prim::Constant op_19 0 1 328 value=1
prim::ListConstruct op_20 3 1 328 33 41 42
aten::view op_21 2 1 input.1 42 input0.1
prim::Constant op_22 0 1 45 value=%pad_left
prim::Constant op_23 0 1 329 value=%pad_right
prim::ListConstruct op_24 2 1 45 329 46
prim::Constant op_25 0 1 47 value=0.000000e+00
aten::constant_pad_nd op_26 3 1 input0.1 46 47 input1.1
prim::Constant op_27 0 1 330 value=1
aten::size op_28 2 1 input1.1 330 49
prim::NumToTensor op_29 1 1 49 50
aten::Int op_30 1 1 50 53
prim::Constant op_31 0 1 55 value=2
aten::size op_32 2 1 input1.1 55 56
prim::NumToTensor op_33 1 1 56 57
aten::Int op_34 1 1 57 62
prim::ListConstruct op_35 2 1 53 62 63
aten::view op_36 2 1 input1.1 63 input2.1
prim::Constant op_37 0 1 normalized value=%normalized
prim::Constant op_38 0 1 return_complex value=True
aten::stft op_39 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1
prim::Constant op_40 0 1 334 value=1
aten::size op_41 2 1 spec_f.1 334 74
prim::NumToTensor op_42 1 1 74 75
aten::Int op_43 1 1 75 78
prim::Constant op_44 0 1 335 value=2
aten::size op_45 2 1 spec_f.1 335 80
prim::NumToTensor op_46 1 1 80 81
aten::Int op_47 1 1 81 87
prim::ListConstruct op_48 3 1 16 78 87 88
aten::reshape op_49 2 1 spec_f.1 88 out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_5, 6)

class torchaudio_F_spectrogram_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
14 13
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=None normalized=False center=%center pad=%pad pad_mode=%pad_mode
prim::Constant op_1 0 1 92 value=2.000000e+00
aten::pow op_2 2 1 window 92 93
prim::Constant op_3 0 1 127 value=None
aten::sum op_4 2 1 93 127 95
aten::sqrt op_5 1 1 95 96
aten::div op_6 2 1 spec 96 out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torchaudio.functional.spectrogram";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["pad"] = captured_params.at("pad");
op->params["pad_mode"] = captured_params.at("pad_mode");
op->params["center"] = captured_params.at("center");
op->params["power"] = Parameter();
op->params["normalized"] = "window";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_2, 7)

class torchaudio_F_spectrogram_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
9 8
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=None normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode
aten::abs op_1 1 1 spec out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torchaudio.functional.spectrogram";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["pad"] = captured_params.at("pad");
op->params["pad_mode"] = captured_params.at("pad_mode");
op->params["center"] = captured_params.at("center");
op->params["normalized"] = captured_params.at("normalized");
op->params["power"] = 1;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_3, 8)

class torchaudio_F_spectrogram_4 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
10 9
pnnx.Input input_0 0 1 waveform
pnnx.Input input_1 0 1 n_fft
pnnx.Input input_2 0 1 hop_length
pnnx.Input input_3 0 1 win_length
pnnx.Input input_4 0 1 window
pnnx.Input input_5 0 1 onesided
torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=1 normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode
prim::Constant op_1 0 1 391 value=2
aten::pow op_2 2 1 spec 391 out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torchaudio.functional.spectrogram";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["pad"] = captured_params.at("pad");
op->params["pad_mode"] = captured_params.at("pad_mode");
op->params["center"] = captured_params.at("center");
op->params["normalized"] = captured_params.at("normalized");
op->params["power"] = 2;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_4, 9)

} // namespace pnnx

+ 203
- 0
tools/pnnx/src/pass_ncnn/torch_istft.cpp View File

@@ -0,0 +1,203 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_istft : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
torch.view_as_complex op_0 1 1 input a
torch.istft op_1 1 1 a out center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=False win_length=%win_length window=None
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "InverseSpectrogram";
}

const char* name_str() const
{
return "istft";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["0"] = captured_params.at("n_fft");
op->params["1"] = 1; // returns
op->params["2"] = captured_params.at("hop_length");
op->params["3"] = captured_params.at("win_length");
op->params["4"] = 0; // all ones
op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0;
op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft, 20)

class torch_istft_1 : public torch_istft
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
torch.view_as_complex op_0 1 1 input a
torch.istft op_1 1 1 a b center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length window=None
torch.view_as_real op_2 1 1 b out
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
torch_istft::write(op, captured_params);

op->params["1"] = 0; // returns
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_1, 20)

static bool NearlyEqual(float a, float b, float epsilon)
{
if (a == b)
return true;

float diff = (float)fabs(a - b);
if (diff <= epsilon)
return true;

// relative error
return diff < epsilon * std::max(fabs(a), fabs(b));
}

static int detect_window_type(const std::vector<float>& window_data)
{
const int winlen = (int)window_data.size();

bool is_one = true;
bool is_hann = true;
bool is_hamming = true;
for (int i = 0; i < winlen; i++)
{
if (!NearlyEqual(window_data[i], 1.f, 0.001))
is_one = false;

if (!NearlyEqual(window_data[i], 0.5f * (1 - cos(2 * M_PI * i / winlen)), 0.001))
is_hann = false;

if (!NearlyEqual(window_data[i], 0.54f - 0.46f * cos(2 * M_PI * i / winlen), 0.001))
is_hamming = false;
}

if (is_one)
return 0;
if (is_hann)
return 1;
if (is_hamming)
return 2;

return -1;
}

class torch_istft_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
torch.view_as_complex op_0 1 1 input a
pnnx.Attribute op_1 0 1 window @data
torch.istft op_2 2 1 a window out center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=False win_length=%win_length
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "InverseSpectrogram";
}

const char* name_str() const
{
return "istft";
}

bool match(const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<float> window_data = captured_attrs.at("op_1.data").get_float32_data();
const int window_type = detect_window_type(window_data);
return window_type != -1;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<float> window_data = captured_attrs.at("op_1.data").get_float32_data();
const int window_type = detect_window_type(window_data);

op->params["0"] = captured_params.at("n_fft");
op->params["1"] = 1; // returns
op->params["2"] = captured_params.at("hop_length");
op->params["3"] = captured_params.at("win_length");
op->params["4"] = window_type;
op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0;
op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_2, 20)

class torch_istft_3 : public torch_istft_2
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
torch.view_as_complex op_0 1 1 input a
pnnx.Attribute op_1 0 1 window @data
torch.istft op_2 2 1 a window b center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length
torch.view_as_real op_3 1 1 b out
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
torch_istft_2::write(op, captured_params, captured_attrs);

op->params["1"] = 0; // returns
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_3, 20)

} // namespace ncnn

} // namespace pnnx

+ 176
- 0
tools/pnnx/src/pass_ncnn/torch_stft.cpp View File

@@ -0,0 +1,176 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_stft : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
torch.stft op_0 1 1 input a center=%center pad_mode=%pad_mode hop_length=%hop_length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length window=None
torch.view_as_real op_1 1 1 a out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Spectrogram";
}

const char* name_str() const
{
return "stft";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const std::string& pad_mode = captured_params.at("pad_mode").s;
int pad_type = 2;
if (pad_mode == "constant")
pad_type = 0;
if (pad_mode == "replicate")
pad_type = 1;
if (pad_mode == "reflect")
pad_type = 2;
const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1;

op->params["0"] = captured_params.at("n_fft");
op->params["1"] = 0; // power
op->params["2"] = captured_params.at("hop_length");
op->params["3"] = captured_params.at("win_length");
op->params["4"] = 0; // all ones
op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0;
op->params["6"] = pad_type;
op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0;
op->params["8"] = onesided;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_stft, 20)

static bool NearlyEqual(float a, float b, float epsilon)
{
if (a == b)
return true;

float diff = (float)fabs(a - b);
if (diff <= epsilon)
return true;

// relative error
return diff < epsilon * std::max(fabs(a), fabs(b));
}

static int detect_window_type(const std::vector<float>& window_data)
{
const int winlen = (int)window_data.size();

bool is_one = true;
bool is_hann = true;
bool is_hamming = true;
for (int i = 0; i < winlen; i++)
{
if (!NearlyEqual(window_data[i], 1.f, 0.001))
is_one = false;

if (!NearlyEqual(window_data[i], 0.5f * (1 - cos(2 * M_PI * i / winlen)), 0.001))
is_hann = false;

if (!NearlyEqual(window_data[i], 0.54f - 0.46f * cos(2 * M_PI * i / winlen), 0.001))
is_hamming = false;
}

if (is_one)
return 0;
if (is_hann)
return 1;
if (is_hamming)
return 2;

return -1;
}

class torch_stft_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_0 0 1 window @data
torch.stft op_1 2 1 input window a center=%center pad_mode=%pad_mode hop_length=%hop_length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length
torch.view_as_real op_2 1 1 a out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Spectrogram";
}

const char* name_str() const
{
return "stft";
}

bool match(const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data();
const int window_type = detect_window_type(window_data);
return window_type != -1;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data();
const int window_type = detect_window_type(window_data);

const std::string& pad_mode = captured_params.at("pad_mode").s;
int pad_type = 2;
if (pad_mode == "constant")
pad_type = 0;
if (pad_mode == "replicate")
pad_type = 1;
if (pad_mode == "reflect")
pad_type = 2;
const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1;

op->params["0"] = captured_params.at("n_fft");
op->params["1"] = 0; // power
op->params["2"] = captured_params.at("hop_length");
op->params["3"] = captured_params.at("win_length");
op->params["4"] = window_type;
op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0;
op->params["6"] = pad_type;
op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0;
op->params["8"] = onesided;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_stft_1, 20)

} // namespace ncnn

} // namespace pnnx

+ 127
- 0
tools/pnnx/src/pass_ncnn/torchaudio_F_inverse_spectrogram.cpp View File

@@ -0,0 +1,127 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

static bool NearlyEqual(float a, float b, float epsilon)
{
if (a == b)
return true;

float diff = (float)fabs(a - b);
if (diff <= epsilon)
return true;

// relative error
return diff < epsilon * std::max(fabs(a), fabs(b));
}

static int detect_window_type(const std::vector<float>& window_data)
{
const int winlen = (int)window_data.size();

bool is_one = true;
bool is_hann = true;
bool is_hamming = true;
for (int i = 0; i < winlen; i++)
{
if (!NearlyEqual(window_data[i], 1.f, 0.001))
is_one = false;

if (!NearlyEqual(window_data[i], 0.5f * (1 - cos(2 * M_PI * i / winlen)), 0.001))
is_hann = false;

if (!NearlyEqual(window_data[i], 0.54f - 0.46f * cos(2 * M_PI * i / winlen), 0.001))
is_hamming = false;
}

if (is_one)
return 0;
if (is_hann)
return 1;
if (is_hamming)
return 2;

return -1;
}

class torchaudio_F_inverse_spectrogram : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_0 0 1 window @data
torch.view_as_complex op_1 1 1 input a
torchaudio.functional.inverse_spectrogram op_2 2 1 a window out center=%center hop_length=%hop_length length=None n_fft=%n_fft normalized=%normalized onesided=%onesided pad=0 win_length=%win_length
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "InverseSpectrogram";
}

const char* name_str() const
{
return "inverse_spectrogram";
}

bool match(const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data();
const int window_type = detect_window_type(window_data);
return window_type != -1;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data();
const int window_type = detect_window_type(window_data);

int normalized = 0;
if (captured_params.at("normalized").type == 1)
{
normalized = captured_params.at("normalized").b ? 2 : 0;
}
if (captured_params.at("normalized").type == 4)
{
if (captured_params.at("normalized").s == "frame_length")
normalized = 1;
if (captured_params.at("normalized").s == "window")
normalized = 2;
}

op->params["0"] = captured_params.at("n_fft");
op->params["1"] = 1; // returns
op->params["2"] = captured_params.at("hop_length");
op->params["3"] = captured_params.at("win_length");
op->params["4"] = window_type;
op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0;
op->params["7"] = normalized;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram, 20)

} // namespace ncnn

} // namespace pnnx

+ 233
- 0
tools/pnnx/src/pass_ncnn/torchaudio_F_spectrogram.cpp View File

@@ -0,0 +1,233 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

static bool NearlyEqual(float a, float b, float epsilon)
{
if (a == b)
return true;

float diff = (float)fabs(a - b);
if (diff <= epsilon)
return true;

// relative error
return diff < epsilon * std::max(fabs(a), fabs(b));
}

static int detect_window_type(const std::vector<float>& window_data)
{
const int winlen = (int)window_data.size();

bool is_one = true;
bool is_hann = true;
bool is_hamming = true;
for (int i = 0; i < winlen; i++)
{
if (!NearlyEqual(window_data[i], 1.f, 0.001))
is_one = false;

if (!NearlyEqual(window_data[i], 0.5f * (1 - cos(2 * M_PI * i / winlen)), 0.001))
is_hann = false;

if (!NearlyEqual(window_data[i], 0.54f - 0.46f * cos(2 * M_PI * i / winlen), 0.001))
is_hamming = false;
}

if (is_one)
return 0;
if (is_hann)
return 1;
if (is_hamming)
return 2;

return -1;
}

class torchaudio_F_spectrogram : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_0 0 1 window @data
torchaudio.functional.spectrogram op_1 2 1 input window a n_fft=%n_fft hop_length=%hop_length win_length=%win_length onesided=%onesided power=%power normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode
torch.view_as_real op_2 1 1 a out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Spectrogram";
}

const char* name_str() const
{
return "spectrogram";
}

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
if (captured_params.at("power").type != 0)
return false;

const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data();
const int window_type = detect_window_type(window_data);
return window_type != -1;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data();
const int window_type = detect_window_type(window_data);

const std::string& pad_mode = captured_params.at("pad_mode").s;
int pad_type = 2;
if (pad_mode == "constant")
pad_type = 0;
if (pad_mode == "replicate")
pad_type = 1;
if (pad_mode == "reflect")
pad_type = 2;
const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1;
int normalized = 0;
if (captured_params.at("normalized").type == 1)
{
normalized = captured_params.at("normalized").b ? 2 : 0;
}
if (captured_params.at("normalized").type == 4)
{
if (captured_params.at("normalized").s == "frame_length")
normalized = 1;
if (captured_params.at("normalized").s == "window")
normalized = 2;
}

op->params["0"] = captured_params.at("n_fft");
op->params["1"] = 0; // power
op->params["2"] = captured_params.at("hop_length");
op->params["3"] = captured_params.at("win_length");
op->params["4"] = window_type;
op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0;
op->params["6"] = pad_type;
op->params["7"] = normalized;
op->params["8"] = onesided;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram, 20)

class torchaudio_F_spectrogram_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_0 0 1 window @data
torchaudio.functional.spectrogram op_1 2 1 input window out n_fft=%n_fft hop_length=%hop_length win_length=%win_length onesided=%onesided power=%power normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Spectrogram";
}

const char* name_str() const
{
return "spectrogram";
}

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
if (captured_params.at("power").type == 0)
return false;

const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data();
const int window_type = detect_window_type(window_data);
return window_type != -1;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<float> window_data = captured_attrs.at("op_0.data").get_float32_data();
const int window_type = detect_window_type(window_data);

const std::string& pad_mode = captured_params.at("pad_mode").s;
int pad_type = 2;
if (pad_mode == "constant")
pad_type = 0;
if (pad_mode == "replicate")
pad_type = 1;
if (pad_mode == "reflect")
pad_type = 2;
const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1;
int normalized = 0;
if (captured_params.at("normalized").type == 1)
{
normalized = captured_params.at("normalized").b ? 2 : 0;
}
if (captured_params.at("normalized").type == 4)
{
if (captured_params.at("normalized").s == "frame_length")
normalized = 1;
if (captured_params.at("normalized").s == "window")
normalized = 2;
}

int power = 0;
if (captured_params.at("power").type == 2)
{
power = captured_params.at("power").i;
if (power != 1 && power != 2)
fprintf(stderr, "unsupported spectrogram power %d\n", power);
}
if (captured_params.at("power").type == 3)
{
if (NearlyEqual(captured_params.at("power").f, 1.0, 0.0001))
power = 1;
else if (NearlyEqual(captured_params.at("power").f, 2.0, 0.0001))
power = 2;
else
fprintf(stderr, "unsupported spectrogram power %f\n", captured_params.at("power").f);
}

op->params["0"] = captured_params.at("n_fft");
op->params["1"] = power;
op->params["2"] = captured_params.at("hop_length");
op->params["3"] = captured_params.at("win_length");
op->params["4"] = window_type;
op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0;
op->params["6"] = pad_type;
op->params["7"] = normalized;
op->params["8"] = onesided;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1, 20)

} // namespace ncnn

} // namespace pnnx

+ 5
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -360,6 +360,11 @@ if(TorchVision_FOUND)
pnnx_add_test(torchvision_RoIAlign)
endif()

pnnx_add_test(torchaudio_F_inverse_spectrogram)
pnnx_add_test(torchaudio_F_spectrogram)
pnnx_add_test(torchaudio_InverseSpectrogram)
pnnx_add_test(torchaudio_Spectrogram)

add_subdirectory(ncnn)

if(onnxruntime_FOUND)


+ 8
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -175,6 +175,9 @@ pnnx_ncnn_add_test(torch_transpose)
pnnx_ncnn_add_test(torch_unbind)
pnnx_ncnn_add_test(torch_unsqueeze)

pnnx_ncnn_add_test(torch_istft)
pnnx_ncnn_add_test(torch_stft)

pnnx_ncnn_add_test(torch_abs)
pnnx_ncnn_add_test(torch_acos)
pnnx_ncnn_add_test(torch_asin)
@@ -217,3 +220,8 @@ pnnx_ncnn_add_test(ncnn_numpy_binaryop_broadcast)
if(TorchVision_FOUND)
pnnx_ncnn_add_test(torchvision_DeformConv2d)
endif()

pnnx_ncnn_add_test(torchaudio_F_inverse_spectrogram)
pnnx_ncnn_add_test(torchaudio_F_spectrogram)
pnnx_ncnn_add_test(torchaudio_InverseSpectrogram)
pnnx_ncnn_add_test(torchaudio_Spectrogram)

+ 68
- 0
tools/pnnx/tests/ncnn/test_torch_istft.py View File

@@ -0,0 +1,68 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
x = torch.view_as_complex(x)
y = torch.view_as_complex(y)
z = torch.view_as_complex(z)
w = torch.view_as_complex(w)
out0 = torch.istft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False)
out1 = torch.istft(y, n_fft=128, center=False, onesided=True, return_complex=False)
out2 = torch.istft(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False)
out3 = torch.istft(w, n_fft=512, center=False, onesided=False, return_complex=True)
out3 = torch.view_as_real(out3)
return out0, out1, out2, out3

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(33, 161, 2)
y = torch.rand(65, 77, 2)
z = torch.rand(257, 8, 2)
w = torch.rand(512, 4, 2)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torch_istft.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_istft.pt inputshape=[33,161,2],[65,77,2],[257,8,2],[512,4,2]")

# ncnn inference
import test_torch_istft_ncnn
b = test_torch_istft_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-3, 1e-3):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 65
- 0
tools/pnnx/tests/ncnn/test_torch_stft.py View File

@@ -0,0 +1,65 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
out0 = torch.stft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=True)
out1 = torch.stft(x, n_fft=128, center=False, onesided=True, return_complex=True)
out2 = torch.stft(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, pad_mode='constant', onesided=True, return_complex=True)
out3 = torch.stft(y, n_fft=512, center=True, onesided=False, return_complex=True)
out0 = torch.view_as_real(out0)
out1 = torch.view_as_real(out1)
out2 = torch.view_as_real(out2)
out3 = torch.view_as_real(out3)
return out0, out1, out2, out3

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(2560)
y = torch.rand(1000)

a = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_torch_stft.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_stft.pt inputshape=[2560],[1000]")

# ncnn inference
import test_torch_stft_ncnn
b = test_torch_stft_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-3, 1e-3):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 72
- 0
tools/pnnx/tests/ncnn/test_torchaudio_F_inverse_spectrogram.py View File

@@ -0,0 +1,72 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from packaging import version

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
x = torch.view_as_complex(x)
y = torch.view_as_complex(y)
z = torch.view_as_complex(z)
w = torch.view_as_complex(w)
out0 = torchaudio.functional.inverse_spectrogram(x, n_fft=64, window=torch.hann_window(44), win_length=44, hop_length=16, pad=0, center=True, normalized='window', length=None)
out1 = torchaudio.functional.inverse_spectrogram(y, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=True, onesided=True, normalized=False, length=None)
out2 = torchaudio.functional.inverse_spectrogram(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, onesided=True, normalized='frame_length', length=None)
out3 = torchaudio.functional.inverse_spectrogram(w, n_fft=1024, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=0, center=True, onesided=True, normalized=False, length=None)
return out0, out1, out2, out3

def test():
if version.parse(torchaudio.__version__) < version.parse('0.10.0'):
return True

net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(33, 161, 2)
y = torch.rand(65, 77, 2)
z = torch.rand(257, 8, 2)
w = torch.rand(513, 4, 2)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torchaudio_F_inverse_spectrogram.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torchaudio_F_inverse_spectrogram.pt inputshape=[33,161,2],[65,77,2],[257,8,2],[513,4,2]")

# ncnn inference
import test_torchaudio_F_inverse_spectrogram_ncnn
b = test_torchaudio_F_inverse_spectrogram_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-3, 1e-3):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 63
- 0
tools/pnnx/tests/ncnn/test_torchaudio_F_spectrogram.py View File

@@ -0,0 +1,63 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
out0 = torchaudio.functional.spectrogram(x, n_fft=64, window=torch.hann_window(44), win_length=44, hop_length=16, pad=0, center=True, normalized='window', power=1)
out1 = torchaudio.functional.spectrogram(x, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None)
out2 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2)
out3 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2)
out1 = torch.view_as_real(out1)
return out0, out1, out2, out3

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(2560)
y = torch.rand(1000)

a = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_torchaudio_F_spectrogram.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torchaudio_F_spectrogram.pt inputshape=[2560],[1000]")

# ncnn inference
import test_torchaudio_F_spectrogram_ncnn
b = test_torchaudio_F_spectrogram_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-3, 1e-3):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 77
- 0
tools/pnnx/tests/ncnn/test_torchaudio_InverseSpectrogram.py View File

@@ -0,0 +1,77 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from packaging import version

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.s0 = torchaudio.transforms.InverseSpectrogram(n_fft=64, window_fn=torch.hann_window, win_length=44, hop_length=16, pad=0, center=True, normalized='window')
self.s1 = torchaudio.transforms.InverseSpectrogram(n_fft=128, window_fn=torch.hann_window, win_length=128, hop_length=3, pad=0, center=True, onesided=True, normalized=False)
self.s2 = torchaudio.transforms.InverseSpectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=256, hop_length=128, pad=0, center=True, onesided=True, normalized='frame_length')
self.s3 = torchaudio.transforms.InverseSpectrogram(n_fft=1024, window_fn=torch.hamming_window, win_length=512, hop_length=128, pad=0, center=True, onesided=True, normalized=False)

def forward(self, x, y, z, w):
x = torch.view_as_complex(x)
y = torch.view_as_complex(y)
z = torch.view_as_complex(z)
w = torch.view_as_complex(w)
out0 = self.s0(x)
out1 = self.s1(y)
out2 = self.s2(z)
out3 = self.s3(w)
return out0, out1, out2, out3

def test():
if version.parse(torchaudio.__version__) < version.parse('0.10.0'):
return True

net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(33, 161, 2)
y = torch.rand(65, 77, 2)
z = torch.rand(257, 8, 2)
w = torch.rand(513, 4, 2)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torchaudio_InverseSpectrogram.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torchaudio_InverseSpectrogram.pt inputshape=[33,161,2],[65,77,2],[257,8,2],[513,4,2]")

# ncnn inference
import test_torchaudio_InverseSpectrogram_ncnn
b = test_torchaudio_InverseSpectrogram_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-3, 1e-3):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 68
- 0
tools/pnnx/tests/ncnn/test_torchaudio_Spectrogram.py View File

@@ -0,0 +1,68 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.s0 = torchaudio.transforms.Spectrogram(n_fft=64, window_fn=torch.hann_window, win_length=44, hop_length=16, pad=0, center=True, normalized='window', power=1)
self.s1 = torchaudio.transforms.Spectrogram(n_fft=128, window_fn=torch.hann_window, win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None)
self.s2 = torchaudio.transforms.Spectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2)
self.s3 = torchaudio.transforms.Spectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2)

def forward(self, x, y):
out0 = self.s0(x)
out1 = self.s1(x)
out2 = self.s2(y)
out3 = self.s3(y)
out1 = torch.view_as_real(out1)
return out0, out1, out2, out3

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(2560)
y = torch.rand(1000)

a = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_torchaudio_Spectrogram.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torchaudio_Spectrogram.pt inputshape=[2560],[1000]")

# ncnn inference
import test_torchaudio_Spectrogram_ncnn
b = test_torchaudio_Spectrogram_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-3, 1e-3):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 3
- 3
tools/pnnx/tests/test_torch_istft.py View File

@@ -21,9 +21,9 @@ class Model(nn.Module):
super(Model, self).__init__()

def forward(self, x, y, z, w):
out0 = torch.istft(x, n_fft=64, center=True, normalized=True, return_complex=False)
out0 = torch.istft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False)
out1 = torch.istft(y, n_fft=128, center=False, onesided=True, return_complex=False)
out2 = torch.istft(z, n_fft=512, center=True, onesided=True, return_complex=False)
out2 = torch.istft(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False)
out3 = torch.istft(w, n_fft=512, center=False, onesided=False, return_complex=True)
return out0, out1, out2, out3

@@ -52,7 +52,7 @@ def test():
b = test_torch_istft_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True



+ 4
- 4
tools/pnnx/tests/test_torch_stft.py View File

@@ -21,10 +21,10 @@ class Model(nn.Module):
super(Model, self).__init__()

def forward(self, x, y):
out0 = torch.stft(x, n_fft=64, center=True, pad_mode='reflect', normalized=True, return_complex=True)
out0 = torch.stft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=True)
out1 = torch.stft(x, n_fft=128, center=False, onesided=True, return_complex=True)
out2 = torch.stft(y, n_fft=512, center=True, pad_mode='constant', onesided=True, return_complex=True)
out3 = torch.stft(y, n_fft=512, center=False, onesided=False, return_complex=True)
out2 = torch.stft(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, pad_mode='constant', onesided=True, return_complex=True)
out3 = torch.stft(y, n_fft=512, center=True, onesided=False, return_complex=True)
return out0, out1, out2, out3

def test():
@@ -50,7 +50,7 @@ def test():
b = test_torch_stft_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True



+ 68
- 0
tools/pnnx/tests/test_torchaudio_F_inverse_spectrogram.py View File

@@ -0,0 +1,68 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from packaging import version

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
out0 = torchaudio.functional.inverse_spectrogram(x, n_fft=64, window=torch.hann_window(44), win_length=44, hop_length=16, pad=0, center=True, normalized='window', length=None)
out1 = torchaudio.functional.inverse_spectrogram(y, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=True, onesided=True, normalized=False, length=None)
out2 = torchaudio.functional.inverse_spectrogram(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, onesided=True, normalized='frame_length', length=None)
out3 = torchaudio.functional.inverse_spectrogram(w, n_fft=512, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=0, center=True, onesided=False, normalized=False, length=None)
return out0, out1, out2, out3

def test():
if version.parse(torchaudio.__version__) < version.parse('0.10.0'):
return True

net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 33, 161, dtype=torch.complex64)
y = torch.rand(1, 65, 77, dtype=torch.complex64)
z = torch.rand(257, 8, dtype=torch.complex64)
w = torch.rand(512, 4, dtype=torch.complex64)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torchaudio_F_inverse_spectrogram.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torchaudio_F_inverse_spectrogram.pt inputshape=[3,33,161]c64,[1,65,77]c64,[257,8]c64,[512,4]c64")

# pnnx inference
import test_torchaudio_F_inverse_spectrogram_pnnx
b = test_torchaudio_F_inverse_spectrogram_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 62
- 0
tools/pnnx/tests/test_torchaudio_F_spectrogram.py View File

@@ -0,0 +1,62 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
out0 = torchaudio.functional.spectrogram(x, n_fft=64, window=torch.hann_window(44), win_length=44, hop_length=16, pad=0, center=True, normalized='window', power=1)
out1 = torchaudio.functional.spectrogram(x, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None)
out2 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2)
out3 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2)
return out0, out1, out2, out3

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 2560)
y = torch.rand(1000)

a = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_torchaudio_F_spectrogram.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torchaudio_F_spectrogram.pt inputshape=[3,2560],[1000]")

# pnnx inference
import test_torchaudio_F_spectrogram_pnnx
b = test_torchaudio_F_spectrogram_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 73
- 0
tools/pnnx/tests/test_torchaudio_InverseSpectrogram.py View File

@@ -0,0 +1,73 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from packaging import version

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.s0 = torchaudio.transforms.InverseSpectrogram(n_fft=64, window_fn=torch.hann_window, win_length=44, hop_length=16, pad=0, center=True, normalized='window')
self.s1 = torchaudio.transforms.InverseSpectrogram(n_fft=128, window_fn=torch.hann_window, win_length=128, hop_length=3, pad=0, center=True, onesided=True, normalized=False)
self.s2 = torchaudio.transforms.InverseSpectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=256, hop_length=128, pad=0, center=True, onesided=True, normalized='frame_length')
self.s3 = torchaudio.transforms.InverseSpectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=512, hop_length=128, pad=0, center=True, onesided=False, normalized=False)

def forward(self, x, y, z, w):
out0 = self.s0(x)
out1 = self.s1(y)
out2 = self.s2(z)
out3 = self.s3(w)
return out0, out1, out2, out3

def test():
if version.parse(torchaudio.__version__) < version.parse('0.10.0'):
return True

net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 33, 161, dtype=torch.complex64)
y = torch.rand(1, 65, 77, dtype=torch.complex64)
z = torch.rand(257, 8, dtype=torch.complex64)
w = torch.rand(512, 4, dtype=torch.complex64)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torchaudio_InverseSpectrogram.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torchaudio_InverseSpectrogram.pt inputshape=[3,33,161]c64,[1,65,77]c64,[257,8]c64,[512,4]c64")

# pnnx inference
import test_torchaudio_InverseSpectrogram_pnnx
b = test_torchaudio_InverseSpectrogram_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 67
- 0
tools/pnnx/tests/test_torchaudio_Spectrogram.py View File

@@ -0,0 +1,67 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.s0 = torchaudio.transforms.Spectrogram(n_fft=64, window_fn=torch.hann_window, win_length=44, hop_length=16, pad=0, center=True, normalized='window', power=1)
self.s1 = torchaudio.transforms.Spectrogram(n_fft=128, window_fn=torch.hann_window, win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None)
self.s2 = torchaudio.transforms.Spectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2)
self.s3 = torchaudio.transforms.Spectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2)

def forward(self, x, y):
out0 = self.s0(x)
out1 = self.s1(x)
out2 = self.s2(y)
out3 = self.s3(y)
return out0, out1, out2, out3

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 2560)
y = torch.rand(1000)

a = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_torchaudio_Spectrogram.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torchaudio_Spectrogram.pt inputshape=[3,2560],[1000]")

# pnnx inference
import test_torchaudio_Spectrogram_pnnx
b = test_torchaudio_Spectrogram_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save