diff --git a/src/layer/spectrogram.cpp b/src/layer/spectrogram.cpp index 2ac655f44..e4ff3af0d 100644 --- a/src/layer/spectrogram.cpp +++ b/src/layer/spectrogram.cpp @@ -207,4 +207,4 @@ int Spectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& op return 0; } -} // namespace ncnn +} // namespace ncnn \ No newline at end of file diff --git a/src/layer/spectrogram.h b/src/layer/spectrogram.h index 62c48dc7f..260aa008c 100644 --- a/src/layer/spectrogram.h +++ b/src/layer/spectrogram.h @@ -33,4 +33,4 @@ public: } // namespace ncnn -#endif // LAYER_SPECTROGRAM_H +#endif // LAYER_SPECTROGRAM_H \ No newline at end of file diff --git a/src/layer/x86/spectrogram_x86.cpp b/src/layer/x86/spectrogram_x86.cpp new file mode 100644 index 000000000..7acd68f3e --- /dev/null +++ b/src/layer/x86/spectrogram_x86.cpp @@ -0,0 +1,277 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "spectrogram_x86.h" + +namespace ncnn { + +Spectrogram_x86::Spectrogram_x86() + : conv_transpose(0) +{ + one_blob_only = true; + support_inplace = false; +} + +Spectrogram_x86::~Spectrogram_x86() +{ + delete conv_transpose; +} + +int Spectrogram_x86::load_param(const ParamDict& pd) +{ + n_fft = pd.get(0, 0); + power = pd.get(1, 0); + hoplen = pd.get(2, n_fft / 4); + winlen = pd.get(3, n_fft); + window_type = pd.get(4, 0); + center = pd.get(5, 1); + pad_type = pd.get(6, 2); + normalized = pd.get(7, 0); + onesided = pd.get(8, 1); + + // assert winlen <= n_fft + // generate window + window_data.create(n_fft); + { + float* p = window_data; + for (int i = 0; i < (n_fft - winlen) / 2; i++) + { + *p++ = 0.f; + } + if (window_type == 0) + { + // all ones + for (int i = 0; i < winlen; i++) + { + *p++ = 1.f; + } + } + if (window_type == 1) + { + // hann window + for (int i = 0; i < winlen; i++) + { + *p++ = 0.5f * (1 - cosf(2 * 3.14159265358979323846 * i / winlen)); + } + } + if (window_type == 2) + { + // hamming window + for (int i = 0; i < winlen; i++) + { + *p++ = 0.54f - 0.46f * cosf(2 * 3.14159265358979323846 * i / winlen); + } + } + for (int i = 0; i < n_fft - winlen - (n_fft - winlen) / 2; i++) + { + *p++ = 0.f; + } + + // pre-calculated window norm factor + if (normalized == 2) + { + float sqsum = 0.f; + for (int i = 0; i < n_fft; i++) + { + sqsum += window_data[i] * window_data[i]; + } + float scale = 1.f / sqrt(sqsum); + + for (int i = 0; i < n_fft; i++) + { + window_data[i] *= scale; + } + } + } + + Mat theta; + if (onesided) + { + n_freq = n_fft / 2 + 1; + } + else + { + n_freq = n_fft; + } + theta.create(n_fft, n_freq, size_t(8)); + + for (int i = 0; i < n_freq; i++) + { + for (int j = 0; j < n_fft; j++) + { + theta.row(i)[j] = 2 * 3.14159265358979323846 * i * j / n_fft; + } + } + + Mat real_basis, imag_basis; + real_basis.create(n_fft, n_freq, size_t(8)); + imag_basis.create(n_fft, n_freq, size_t(8)); + + for (int i = 0; i < n_freq; i++) + { + for (int j = 0; j < n_fft; j++) + { + real_basis.row(i)[j] = cos(theta.row(i)[j]); + imag_basis.row(i)[j] = -sin(theta.row(i)[j]); + } + } + + // multiply window + for (int i = 0; i < n_freq; i++) + { + for (int j = 0; j < n_fft; j++) + { + real_basis.row(i)[j] *= window_data[j]; + imag_basis.row(i)[j] *= window_data[j]; + } + } + + if (normalized == 1) + { + double scale = 1.f / sqrt(n_fft); + for (int i = 0; i < n_freq; i++) + { + for (int j = 0; j < n_fft; j++) + { + real_basis.row(i)[j] *= scale; + imag_basis.row(i)[j] *= scale; + } + } + } + + conv_data.create(n_fft, 1, n_freq * 2); + + for (int i = 0; i < n_freq; i++) + { + for (int j = 0; j < n_fft; j++) + { + conv_data.channel(i).row(0)[j] = (float)real_basis.row(i)[j]; + conv_data.channel(i + n_freq).row(0)[j] = (float)imag_basis.row(i)[j]; + } + } + + conv_transpose = ncnn::create_layer("Convolution1D"); + ncnn::ParamDict conv_transpose_pd; + + conv_transpose_pd.set(0, 2 * n_freq); // num_output + conv_transpose_pd.set(1, n_fft); // kernel_w + conv_transpose_pd.set(3, hoplen); // stride_w + conv_transpose_pd.set(19, 1); // dynamic_weight + + conv_transpose->load_param(conv_transpose_pd); + + return 0; +} + +int Spectrogram_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + // https://pytorch.org/audio/stable/generated/torchaudio.functional.spectrogram.html + + // TODO custom window + + Mat bottom_blob_bordered = bottom_blob; + if (center == 1) + { + Option opt_b = opt; + opt_b.blob_allocator = opt.workspace_allocator; + if (pad_type == 0) + copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_CONSTANT, 0.f, opt_b); + if (pad_type == 1) + copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_REPLICATE, 0.f, opt_b); + if (pad_type == 2) + copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_REFLECT, 0.f, opt_b); + } + + const int size = bottom_blob_bordered.w; + + // const int frames = size / hoplen + 1; + const int frames = (size - n_fft) / hoplen + 1; + + const size_t elemsize = bottom_blob_bordered.elemsize; + + if (elemsize != sizeof(float)) + { + return -100; + } + + if (power == 0) + { + top_blob.create(2, frames, n_freq, elemsize, opt.blob_allocator); + } + else + { + top_blob.create(frames, n_freq, elemsize, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + std::vector inputs; + inputs.push_back(bottom_blob_bordered); + inputs.push_back(conv_data); + + std::vector outputs; + outputs.push_back(Mat()); + + Option opt_conv = opt; + opt_conv.use_packing_layout = false; + + conv_transpose->create_pipeline(opt_conv); + conv_transpose->forward(inputs, outputs, opt_conv); + conv_transpose->destroy_pipeline(opt_conv); + + Mat conv_top_blob = outputs[0]; // (2 * n_freq, frames) + float* conv_top_data = conv_top_blob; + + if (power == 0) // as complex + { + // copy + for (int i = 0; i < frames; i++) + { + for (int j = 0; j < n_freq; j++) + { + top_blob.channel(j).row(i)[0] = conv_top_data[j * frames + i]; + top_blob.channel(j).row(i)[1] = conv_top_data[(j + n_freq) * frames + i]; + } + } + } + else + { + if (power == 1) // magnitude sqrt(re * re + im * im); + { + // copy + for (int i = 0; i < frames; i++) + { + for (int j = 0; j < n_freq; j++) + { + top_blob.row(j)[i] = sqrtf(conv_top_data[j * frames + i] * conv_top_data[j * frames + i] + conv_top_data[(j + n_freq) * frames + i] * conv_top_data[(j + n_freq) * frames + i]); + } + } + } + else if (power == 2) // power re * re + im * im; + { + // copy + for (int i = 0; i < frames; i++) + { + for (int j = 0; j < n_freq; j++) + { + top_blob.row(j)[i] = conv_top_data[j * frames + i] * conv_top_data[j * frames + i] + conv_top_data[(j + n_freq) * frames + i] * conv_top_data[(j + n_freq) * frames + i]; + } + } + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/x86/spectrogram_x86.h b/src/layer/x86/spectrogram_x86.h new file mode 100644 index 000000000..4678406c8 --- /dev/null +++ b/src/layer/x86/spectrogram_x86.h @@ -0,0 +1,53 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_SPECTROGRAM_X86_H +#define LAYER_SPECTROGRAM_X86_H + +#include "spectrogram.h" + +namespace ncnn { + +class Spectrogram_x86 : public Spectrogram +{ +public: + Spectrogram_x86(); + ~Spectrogram_x86(); + + virtual int load_param(const ParamDict& pd); + + virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + +public: + int n_fft; + int power; + int hoplen; + int winlen; + int window_type; // 0=ones 1=hann 2=hamming + int center; + int pad_type; // 0=CONSTANT 1=REPLICATE 2=REFLECT + int normalized; // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy + int onesided; + + int n_freq; + + Mat window_data; + Mat conv_data; + + Layer* conv_transpose; +}; + +} // namespace ncnn + +#endif // LAYER_SPECTROGRAM_X86_H diff --git a/tests/test_spectrogram.cpp b/tests/test_spectrogram.cpp index 6bf8e979f..8d6e7c9d7 100644 --- a/tests/test_spectrogram.cpp +++ b/tests/test_spectrogram.cpp @@ -39,9 +39,77 @@ static int test_spectrogram_0() || test_spectrogram(124, 55, 2, 12, 55, 1, 1, 2, 2, 0); } +static int test_spectrogram_eval(int size, int n_fft, int power, int hoplen, int winlen, int window_type, int center, int pad_type, int normalized, int onesided, float* in, float* std) +{ + ncnn::Layer* layer = ncnn::create_layer("Spectrogram"); + + ncnn::ParamDict pd; + pd.set(0, n_fft); + pd.set(1, power); + pd.set(2, hoplen); + pd.set(3, winlen); + pd.set(4, window_type); + pd.set(5, center); + pd.set(6, pad_type); + pd.set(7, normalized); + pd.set(8, onesided); + + ncnn::Mat input = ncnn::Mat(size); + memcpy(input, in, size * sizeof(float)); + + ncnn::Mat output; + + ncnn::Option opt; + opt.num_threads = 2; + + layer->load_param(pd); + layer->create_pipeline(opt); + layer->forward(input, output, opt); + layer->destroy_pipeline(opt); + + const float epsilon = 1e-6; + + for (int i = 0; i < output.c; i++) + { + float* output_data = output.channel(i); + for (int j = 0; j < output.h; j++) + { + for (int k = 0; k < output.w; k++) + { + if (fabs(output_data[j * output.w + k] - std[i * output.h * output.w + j * output.w + k]) > epsilon) + { + fprintf(stderr, "test_spectrogram failed size=%d n_fft=%d power=%d hoplen=%d winlen=%d window_type=%d center=%d pad_type=%d normalized=%d onesided=%d\n", size, n_fft, power, hoplen, winlen, window_type, center, pad_type, normalized, onesided); + return 1; + } + } + } + } + + delete layer; + return 0; +} + +static int test_spectrogram_1() +{ + float input_0[16] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f}; + float std_0[] = { + 0.05000000f, 0.40000001f, 0.80000001f, 1.20000005f, 1.59999990f, 2.00000000f, 2.40000010f, 2.79999995f, 0.75000000f, 0.05000000f, 0.22360681f, 0.41231057f, 0.60827625f, 0.80622578f, 1.00498760f, 1.20415950f, 1.40356684f, 0.75000000f, 0.05000000f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00000006f, 0.00000000f, 0.00000000f, 0.00000000f, 0.75000000f + }; + float std_1[] = { + 0.80000001f, 1.20000005f, 1.59999990f, 2.00000000f, 2.40000010f, 0.68649411f, 1.02670193f, 1.36751485f, 1.70857072f, 2.04974818f, 0.41231057f, 0.60827625f, 0.80622578f, 1.00498760f, 1.20415950f, 0.13684234f, 0.18942842f, 0.24475159f, 0.30130789f, 0.35851428f, 0.00000000f, 0.00000000f, 0.00000006f, 0.00000000f, 0.00000000f + }; + float std_2[] = { + 0.28284273f, 0.49497476f, 0.70710677f, 0.24271232f, 0.42322639f, 0.60407096f, 0.14577380f, 0.25000000f, 0.35531676f, 0.04838108f, 0.07667736f, 0.10652842f, 0.00000000f, 0.00000002f, 0.00000000f, 0.04838108f, 0.07667736f, 0.10652842f, 0.14577380f, 0.25000000f, 0.35531676f, 0.24271232f, 0.42322639f, 0.60407096f + }; + + return test_spectrogram_eval(16, 4, 1, 2, 4, 1, 1, 0, 0, 1, input_0, std_0) + || test_spectrogram_eval(16, 8, 1, 2, 4, 1, 0, 0, 0, 1, input_0, std_1) + || test_spectrogram_eval(16, 8, 1, 3, 4, 1, 0, 0, 1, 0, input_0, std_2); +} + int main() { SRAND(7767517); - return test_spectrogram_0(); + return test_spectrogram_0() || test_spectrogram_1(); }