Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9178339

Commit messages (dev/ans -> master):

* refactor: move aec models to audio/aec
* feat: add unittest for ANS pipeline
* Merge branch 'master' into dev/ans
* add new SoundFile to audio dependency
* Merge branch 'master' into dev/ans
* use ANS pipeline name from metainfo
* Merge branch 'master' into dev/ans
* chore: update docstring of ANS module
* Merge branch 'master' into dev/ans
* refactor: use names from metainfo
* refactor: enable ans unittest
* refactor: add more log message in unittest
@@ -21,6 +21,7 @@ class Models(object):
    sambert_hifi_16k = 'sambert-hifi-16k'
    generic_tts_frontend = 'generic-tts-frontend'
    hifigan16k = 'hifigan16k'
    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
    kws_kwsbp = 'kws-kwsbp'

    # multi-modal models

@@ -55,6 +56,7 @@ class Pipelines(object):
    # audio tasks
    sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
    speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
    kws_kwsbp = 'kws-kwsbp'

    # multi-modal tasks
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .audio.ans.frcrn import FRCRNModel
from .audio.kws import GenericKeyWordSpotting
from .audio.tts.am import SambertNetHifi16k
from .audio.tts.vocoder import Hifigan16k
@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class UniDeepFsmn(nn.Module):

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.hidden_size = hidden_size
        self.linear = nn.Linear(input_dim, hidden_size)
        self.project = nn.Linear(hidden_size, output_dim, bias=False)
        # depthwise causal convolution over the time axis: the FSMN memory
        # block of order `lorder`
        self.conv1 = nn.Conv2d(
            output_dim,
            output_dim, [lorder, 1], [1, 1],
            groups=output_dim,
            bias=False)

    def forward(self, input):
        r"""
        Args:
            input: tensor with shape batch (b) x sequence (T) x feature (h)

        Returns:
            tensor with shape batch (b) x sequence (T) x feature (h)
        """
        f1 = F.relu(self.linear(input))
        p1 = self.project(f1)
        x = torch.unsqueeze(p1, 1)
        # x: batch (b) x channel (c) x sequence (T) x feature (h)
        x_per = x.permute(0, 3, 2, 1)
        # x_per: batch (b) x feature (h) x sequence (T) x channel (c)
        y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
        out = x_per + self.conv1(y)
        out1 = out.permute(0, 3, 2, 1)
        # out1: batch (b) x channel (c) x sequence (T) x feature (h)
        # squeeze only the singleton channel dim, so a batch or sequence of
        # size 1 is not collapsed as well
        return input + out1.squeeze(1)
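
A quick shape check of the FSMN block above (sizes here are illustrative; the residual connection requires input_dim == output_dim):

import torch

# Hypothetical sizes; lorder=3 gives a causal memory over 3 past frames.
fsmn = UniDeepFsmn(input_dim=8, output_dim=8, lorder=3, hidden_size=16)
x = torch.randn(2, 50, 8)  # batch (b) x sequence (T) x feature (h)
y = fsmn(x)
print(y.shape)  # torch.Size([2, 50, 8]): the memory block preserves the shape
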
class ComplexUniDeepFsmn(nn.Module):

    def __init__(self, nIn, nHidden=128, nOut=128):
        super(ComplexUniDeepFsmn, self).__init__()
        self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
        self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
        self.fsmn_re_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)
        self.fsmn_im_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)

    def forward(self, x):
        r"""
        Args:
            x: tensor with shape [batch, channel, feature, sequence, 2], e.g. [6, 256, 1, 106, 2]

        Returns:
            tensor with the same shape [batch, channel, feature, sequence, 2]
        """
        b, c, h, T, d = x.size()
        x = torch.reshape(x, (b, c * h, T, d))
        # x: [b, h, T, 2], e.g. [6, 256, 106, 2]
        x = torch.transpose(x, 1, 2)
        # x: [b, T, h, 2], e.g. [6, 106, 256, 2]
        # first complex FSMN layer: (re + j*im) * (W_re + j*W_im)
        real_L1 = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
        imaginary_L1 = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
        # second complex FSMN layer
        real = self.fsmn_re_L2(real_L1) - self.fsmn_im_L2(imaginary_L1)
        imaginary = self.fsmn_re_L2(imaginary_L1) + self.fsmn_im_L2(real_L1)
        # output: [b, T, h, 2]
        output = torch.stack((real, imaginary), dim=-1)
        # output: [b, h, T, 2]
        output = torch.transpose(output, 1, 2)
        output = torch.reshape(output, (b, c, h, T, d))
        return output
class ComplexUniDeepFsmn_L1(nn.Module):

    def __init__(self, nIn, nHidden=128, nOut=128):
        super(ComplexUniDeepFsmn_L1, self).__init__()
        self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
        self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)

    def forward(self, x):
        r"""
        Args:
            x: tensor with shape [batch, channel, feature, sequence, 2], e.g. [6, 256, 1, 106, 2]
        """
        b, c, h, T, d = x.size()
        # x: [b, T, h, c, 2]
        x = torch.transpose(x, 1, 3)
        x = torch.reshape(x, (b * T, h, c, d))
        real = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
        imaginary = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
        # output: [b*T, h, c, 2], e.g. [6*106, h, 256, 2]
        output = torch.stack((real, imaginary), dim=-1)
        output = torch.reshape(output, (b, T, h, c, d))
        output = torch.transpose(output, 1, 3)
        return output
class ComplexConv2d(nn.Module):
    # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 **kwargs):
        super().__init__()
        # Model components
        self.conv_re = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs)
        self.conv_im = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs)

    def forward(self, x):
        r"""
        Args:
            x: tensor with shape [batch, channel, axis1, axis2, 2]
        """
        # complex multiplication with two real-valued convolutions
        real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
        imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)
        return output
class ComplexConvTranspose2d(nn.Module):

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 **kwargs):
        super().__init__()
        # Model components
        self.tconv_re = nn.ConvTranspose2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs)
        self.tconv_im = nn.ConvTranspose2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs)

    def forward(self, x):  # shape of x: [batch, channel, axis1, axis2, 2]
        real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
        imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)
        return output
class ComplexBatchNorm2d(nn.Module):

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True,
                 **kwargs):
        super().__init__()
        self.bn_re = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs)
        self.bn_im = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs)

    def forward(self, x):
        real = self.bn_re(x[..., 0])
        imag = self.bn_im(x[..., 1])
        output = torch.stack((real, imag), dim=-1)
        return output
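
All of these layers share one convention: a complex tensor is stored as a real tensor whose trailing dimension of size 2 holds the real and imaginary parts, and each layer applies the complex product (a + jb)(c + jd) = (ac - bd) + j(ad + bc) using two real-valued sublayers. A minimal sketch with illustrative shapes:

import torch

conv = ComplexConv2d(in_channel=3, out_channel=8, kernel_size=3, padding=1)
x = torch.randn(4, 3, 32, 32, 2)  # [..., 0] real part, [..., 1] imaginary part
y = conv(x)
print(y.shape)  # torch.Size([4, 8, 32, 32, 2])
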
@@ -0,0 +1,112 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(win_len, win_inc, fft_len, win_type=None, inverse=False):
    if win_type == 'None' or win_type is None:
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)**0.5
    N = fft_len
    # rows of the DFT matrix, truncated to the window length
    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T
    if inverse:
        kernel = np.linalg.pinv(kernel).T
    kernel = kernel * window
    kernel = kernel[:, None, :]
    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(
        window[None, :, None].astype(np.float32))
class ConvSTFT(nn.Module):

    def __init__(self,
                 win_len,
                 win_inc,
                 fft_len=None,
                 win_type='hamming',
                 feature_type='real',
                 fix=True):
        super(ConvSTFT, self).__init__()
        if fft_len is None:
            # next power of two at or above the window length
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
        self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)
        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase
class ConviSTFT(nn.Module):

    def __init__(self,
                 win_len,
                 win_inc,
                 fft_len=None,
                 win_type='hamming',
                 feature_type='real',
                 fix=True):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(
            win_len, win_inc, self.fft_len, win_type, inverse=True)
        self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.win_inc = win_inc
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """
        Args:
            inputs: [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
            phase: [B, N//2+1, T] (if not None)
        """
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
        # normalize by the squared, overlap-added analysis window;
        # this is from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1, 1, inputs.size(-1))**2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        return outputs
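
ConvSTFT and ConviSTFT implement STFT analysis and synthesis as plain 1-D (transposed) convolutions whose kernels are rows of the DFT matrix, so no FFT op is needed at inference time. A round-trip sketch using FRCRN's frame settings (win_len=400, win_inc=100, fft_len=512; the default 'hamming' window is assumed here):

import torch

stft = ConvSTFT(400, 100, 512, feature_type='complex')
istft = ConviSTFT(400, 100, 512, feature_type='complex')

wav = torch.randn(1, 16000)  # one second of 16 kHz audio
spec = stft(wav)             # [1, 514, 157]: real and imaginary parts stacked
rec = istft(spec)            # [1, 1, 16000]: approximately reconstructs wav
print(spec.shape, rec.shape)
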
@@ -0,0 +1,309 @@
import os
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from ...base import Model, Tensor
from .conv_stft import ConviSTFT, ConvSTFT
from .unet import UNet


class FTB(nn.Module):

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
        super(FTB, self).__init__()
        self.in_channel = in_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(r_channel), nn.ReLU())
        self.conv1d = nn.Sequential(
            nn.Conv1d(
                r_channel * input_dim, in_channel, kernel_size=9, padding=4),
            nn.BatchNorm1d(in_channel), nn.ReLU())
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(in_channel), nn.ReLU())

    def forward(self, inputs):
        '''
        inputs should be [Batch, Ca, Dim, Time]
        '''
        # T-F attention
        conv1_out = self.conv1(inputs)
        B, C, D, T = conv1_out.size()
        reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
        conv1d_out = self.conv1d(reshape1_out)
        conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])
        # attention weights applied per channel; still [B, C, D, T]
        att_out = conv1d_out * inputs
        # transpose to [B, C, T, D]
        att_out = torch.transpose(att_out, 2, 3)
        freqfc_out = self.freq_fc(att_out)
        att_out = torch.transpose(freqfc_out, 2, 3)
        cat_out = torch.cat([att_out, inputs], 1)
        outputs = self.conv2(cat_out)
        return outputs
@MODELS.register_module(
    Tasks.speech_signal_process, module_name=Models.speech_frcrn_ans_cirm_16k)
class FRCRNModel(Model):
    r"""A wrapper that integrates FRCRN into the ModelScope framework."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the FRCRN model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self._model = FRCRN(*args, **kwargs)
        model_bin_file = os.path.join(model_dir,
                                      ModelFile.TORCH_MODEL_BIN_FILE)
        if os.path.exists(model_bin_file):
            checkpoint = torch.load(model_bin_file)
            self._model.load_state_dict(checkpoint, strict=False)

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        output = self._model.forward(input)
        return {
            'spec_l1': output[0],
            'wav_l1': output[1],
            'mask_l1': output[2],
            'spec_l2': output[3],
            'wav_l2': output[4],
            'mask_l2': output[5]
        }

    def to(self, *args, **kwargs):
        self._model = self._model.to(*args, **kwargs)
        return self

    def eval(self):
        self._model = self._model.train(False)
        return self
class FRCRN(nn.Module):
    r"""Frequency Recurrent CRN."""

    def __init__(self,
                 complex,
                 model_complexity,
                 model_depth,
                 log_amp,
                 padding_mode,
                 win_len=400,
                 win_inc=100,
                 fft_len=512,
                 win_type='hanning'):
        r"""
        Args:
            complex: whether to use complex networks.
            model_complexity: base channel count that defines the model size
            model_depth: depth of the UNet; available options: 10, 14, 20
            log_amp: whether to use log amplitude to estimate signals
            padding_mode: padding mode of the encoder convolutions, 'zeros' or 'reflect'
            win_len: length of the window defining one frame of samples
            win_inc: window shift (equivalent to hop_size)
            fft_len: number of Short Time Fourier Transform (STFT) points
            win_type: window type used in the STFT, e.g. 'hanning', 'hamming'
        """
        super().__init__()
        self.feat_dim = fft_len // 2 + 1
        self.win_len = win_len
        self.win_inc = win_inc
        self.fft_len = fft_len
        self.win_type = win_type
        fix = True
        self.stft = ConvSTFT(
            self.win_len,
            self.win_inc,
            self.fft_len,
            self.win_type,
            feature_type='complex',
            fix=fix)
        self.istft = ConviSTFT(
            self.win_len,
            self.win_inc,
            self.fft_len,
            self.win_type,
            feature_type='complex',
            fix=fix)
        self.unet = UNet(
            1,
            complex=complex,
            model_complexity=model_complexity,
            model_depth=model_depth,
            padding_mode=padding_mode)
        self.unet2 = UNet(
            1,
            complex=complex,
            model_complexity=model_complexity,
            model_depth=model_depth,
            padding_mode=padding_mode)
    def forward(self, inputs):
        out_list = []
        # [B, D*2, T]
        cmp_spec = self.stft(inputs)
        # [B, 1, D*2, T]
        cmp_spec = torch.unsqueeze(cmp_spec, 1)
        # split into real/imaginary parts: [B, 2, D, T]
        cmp_spec = torch.cat([
            cmp_spec[:, :, :self.feat_dim, :],
            cmp_spec[:, :, self.feat_dim:, :],
        ], 1)
        # [B, 2, D, T, 1]
        cmp_spec = torch.unsqueeze(cmp_spec, 4)
        # [B, 1, D, T, 2]
        cmp_spec = torch.transpose(cmp_spec, 1, 4)
        unet1_out = self.unet(cmp_spec)
        cmp_mask1 = torch.tanh(unet1_out)
        unet2_out = self.unet2(unet1_out)
        cmp_mask2 = torch.tanh(unet2_out)
        est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1)
        out_list.append(est_spec)
        out_list.append(est_wav)
        out_list.append(est_mask)
        # the second mask refines the first one
        cmp_mask2 = cmp_mask2 + cmp_mask1
        est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)
        out_list.append(est_spec)
        out_list.append(est_wav)
        out_list.append(est_mask)
        return out_list

    def apply_mask(self, cmp_spec, cmp_mask):
        # complex multiplication of the spectrum with the estimated cIRM mask
        est_spec = torch.cat([
            cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 0]
            - cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 1],
            cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 1]
            + cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 0]
        ], 1)
        est_spec = torch.cat([est_spec[:, 0, :, :], est_spec[:, 1, :, :]], 1)
        cmp_mask = torch.squeeze(cmp_mask, 1)
        cmp_mask = torch.cat([cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], 1)
        est_wav = self.istft(est_spec)
        est_wav = torch.squeeze(est_wav, 1)
        return est_spec, est_wav, cmp_mask
    def get_params(self, weight_decay=0.0):
        # add L2 penalty to weights only, not to biases
        weights, biases = [], []
        for name, param in self.named_parameters():
            if 'bias' in name:
                biases += [param]
            else:
                weights += [param]
        params = [{
            'params': weights,
            'weight_decay': weight_decay,
        }, {
            'params': biases,
            'weight_decay': 0.0,
        }]
        return params

    def loss(self, noisy, labels, out_list, mode='Mix'):
        if mode == 'SiSNR':
            count = 0
            while count < len(out_list):
                est_spec = out_list[count]
                count = count + 1
                est_wav = out_list[count]
                count = count + 1
                est_mask = out_list[count]
                count = count + 1
                # only the second (refined) mask layer contributes to the loss
                if count != 3:
                    loss = self.loss_1layer(noisy, est_spec, est_wav, labels,
                                            est_mask, mode)
            return loss
        elif mode == 'Mix':
            count = 0
            while count < len(out_list):
                est_spec = out_list[count]
                count = count + 1
                est_wav = out_list[count]
                count = count + 1
                est_mask = out_list[count]
                count = count + 1
                if count != 3:
                    amp_loss, phase_loss, SiSNR_loss = self.loss_1layer(
                        noisy, est_spec, est_wav, labels, est_mask, mode)
                    loss = amp_loss + phase_loss + SiSNR_loss
            return loss, amp_loss, phase_loss
    def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'):
        r"""Compute the loss for one mask layer, depending on the mode.

        mode == 'Mix'
            est: [B, F*2, T]
            labels: [B, F*2, T]
        mode == 'SiSNR'
            est: [B, T]
            labels: [B, T]
        """
        if mode == 'SiSNR':
            if labels.dim() == 3:
                labels = torch.squeeze(labels, 1)
            if est_wav.dim() == 3:
                est_wav = torch.squeeze(est_wav, 1)
            return -si_snr(est_wav, labels)
        elif mode == 'Mix':
            if labels.dim() == 3:
                labels = torch.squeeze(labels, 1)
            if est_wav.dim() == 3:
                est_wav = torch.squeeze(est_wav, 1)
            SiSNR_loss = -si_snr(est_wav, labels)
            b, d, t = est.size()
            S = self.stft(labels)
            Sr = S[:, :self.feat_dim, :]
            Si = S[:, self.feat_dim:, :]
            Y = self.stft(noisy)
            Yr = Y[:, :self.feat_dim, :]
            Yi = Y[:, self.feat_dim:, :]
            Y_pow = Yr**2 + Yi**2
            # ground-truth complex ratio mask: S/Y in real/imaginary parts
            gth_mask = torch.cat([(Sr * Yr + Si * Yi) / (Y_pow + 1e-8),
                                  (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)], 1)
            gth_mask[gth_mask > 2] = 1
            gth_mask[gth_mask < -2] = -1
            amp_loss = F.mse_loss(gth_mask[:, :self.feat_dim, :],
                                  cmp_mask[:, :self.feat_dim, :]) * d
            phase_loss = F.mse_loss(gth_mask[:, self.feat_dim:, :],
                                    cmp_mask[:, self.feat_dim:, :]) * d
            return amp_loss, phase_loss, SiSNR_loss
def l2_norm(s1, s2):
    norm = torch.sum(s1 * s2, -1, keepdim=True)
    return norm


def si_snr(s1, s2, eps=1e-8):
    # project the estimate s1 onto the target s2, then compare the energies
    # of the projected signal and the residual
    s1_s2_norm = l2_norm(s1, s2)
    s2_s2_norm = l2_norm(s2, s2)
    s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
    e_noise = s1 - s_target
    target_norm = l2_norm(s_target, s_target)
    noise_norm = l2_norm(e_noise, e_noise)
    snr = 10 * torch.log10(target_norm / (noise_norm + eps) + eps)
    return torch.mean(snr)
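
A small sanity check of the metric (toy values): SI-SNR is scale-invariant in the estimate, because the projection onto the target absorbs any gain:

import torch

s_ref = torch.randn(1, 16000)
s_est = s_ref + 0.1 * torch.randn(1, 16000)  # estimate with mild residual noise
print(si_snr(s_est, s_ref))        # roughly 20 dB for this noise level
print(si_snr(3.0 * s_est, s_ref))  # same value: rescaling does not change it
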
@@ -0,0 +1,26 @@
import torch
from torch import nn


class SELayer(nn.Module):

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc_r = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel), nn.Sigmoid())
        self.fc_i = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel), nn.Sigmoid())

    def forward(self, x):
        b, c, _, _, _ = x.size()
        # squeeze: channel-wise statistics of the real and imaginary parts
        x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)
        x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)
        # excitation: channel gates combined with the complex product rule
        y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(
            b, c, 1, 1, 1)
        y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(
            b, c, 1, 1, 1)
        y = torch.cat([y_r, y_i], 4)
        return x * y
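
This is a squeeze-and-excitation gate adapted to the stacked real/imaginary layout: the gates themselves are formed with the complex product rule and then applied componentwise to each part. A shape sketch with illustrative sizes:

import torch

se = SELayer(channel=16, reduction=8)
x = torch.randn(2, 16, 10, 20, 2)  # complex feature map, last dim = (re, im)
y = se(x)
print(y.shape)  # torch.Size([2, 16, 10, 20, 2]): gating preserves the shape
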
@@ -0,0 +1,269 @@
import torch
import torch.nn as nn

from . import complex_nn
from .se_module_complex import SELayer


class Encoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding=None,
                 complex=False,
                 padding_mode='zeros'):
        super().__init__()
        if padding is None:
            padding = [(i - 1) // 2 for i in kernel_size]  # 'SAME' padding
        if complex:
            conv = complex_nn.ComplexConv2d
            bn = complex_nn.ComplexBatchNorm2d
        else:
            conv = nn.Conv2d
            bn = nn.BatchNorm2d
        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode)
        self.bn = bn(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
class Decoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding=(0, 0),
                 complex=False):
        super().__init__()
        if complex:
            tconv = complex_nn.ComplexConvTranspose2d
            bn = complex_nn.ComplexBatchNorm2d
        else:
            tconv = nn.ConvTranspose2d
            bn = nn.BatchNorm2d
        self.transconv = tconv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)
        self.bn = bn(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        x = self.transconv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
class UNet(nn.Module):

    def __init__(self,
                 input_channels=1,
                 complex=False,
                 model_complexity=45,
                 model_depth=20,
                 padding_mode='zeros'):
        super().__init__()
        if complex:
            # scale the width by 1/sqrt(2): every complex layer carries two
            # real-valued sublayers, so this keeps the parameter count comparable
            model_complexity = int(model_complexity // 1.414)
        self.set_size(
            model_complexity=model_complexity,
            input_channels=input_channels,
            model_depth=model_depth)
        self.encoders = []
        self.model_length = model_depth // 2
        self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128)
        self.se_layers_enc = []
        self.fsmn_enc = []
        for i in range(self.model_length):
            fsmn_enc = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
            self.add_module('fsmn_enc{}'.format(i), fsmn_enc)
            self.fsmn_enc.append(fsmn_enc)
            module = Encoder(
                self.enc_channels[i],
                self.enc_channels[i + 1],
                kernel_size=self.enc_kernel_sizes[i],
                stride=self.enc_strides[i],
                padding=self.enc_paddings[i],
                complex=complex,
                padding_mode=padding_mode)
            self.add_module('encoder{}'.format(i), module)
            self.encoders.append(module)
            se_layer_enc = SELayer(self.enc_channels[i + 1], 8)
            self.add_module('se_layer_enc{}'.format(i), se_layer_enc)
            self.se_layers_enc.append(se_layer_enc)
        self.decoders = []
        self.fsmn_dec = []
        self.se_layers_dec = []
        for i in range(self.model_length):
            fsmn_dec = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
            self.add_module('fsmn_dec{}'.format(i), fsmn_dec)
            self.fsmn_dec.append(fsmn_dec)
            module = Decoder(
                self.dec_channels[i] * 2,
                self.dec_channels[i + 1],
                kernel_size=self.dec_kernel_sizes[i],
                stride=self.dec_strides[i],
                padding=self.dec_paddings[i],
                complex=complex)
            self.add_module('decoder{}'.format(i), module)
            self.decoders.append(module)
            if i < self.model_length - 1:
                se_layer_dec = SELayer(self.dec_channels[i + 1], 8)
                self.add_module('se_layer_dec{}'.format(i), se_layer_dec)
                self.se_layers_dec.append(se_layer_dec)
        if complex:
            conv = complex_nn.ComplexConv2d
        else:
            conv = nn.Conv2d
        linear = conv(self.dec_channels[-1], 1, 1)
        self.add_module('linear', linear)
        self.complex = complex
        self.padding_mode = padding_mode
        self.decoders = nn.ModuleList(self.decoders)
        self.encoders = nn.ModuleList(self.encoders)
        self.se_layers_enc = nn.ModuleList(self.se_layers_enc)
        self.se_layers_dec = nn.ModuleList(self.se_layers_dec)
        self.fsmn_enc = nn.ModuleList(self.fsmn_enc)
        self.fsmn_dec = nn.ModuleList(self.fsmn_dec)
    def forward(self, inputs):
        x = inputs
        # go down
        xs = []
        xs_se = []
        xs_se.append(x)
        for i, encoder in enumerate(self.encoders):
            xs.append(x)
            if i > 0:
                x = self.fsmn_enc[i](x)
            x = encoder(x)
            xs_se.append(self.se_layers_enc[i](x))
        # xs : x0=input x1 ... x9
        x = self.fsmn(x)
        p = x
        for i, decoder in enumerate(self.decoders):
            p = decoder(p)
            if i < self.model_length - 1:
                p = self.fsmn_dec[i](p)
            if i == self.model_length - 1:
                break
            if i < self.model_length - 2:
                p = self.se_layers_dec[i](p)
            # skip connection: concatenate the SE-gated encoder output
            p = torch.cat([p, xs_se[self.model_length - 1 - i]], dim=1)
        # cmp_spec: [12, 1, 513, 64, 2]
        cmp_spec = self.linear(p)
        return cmp_spec
    def set_size(self, model_complexity, model_depth=20, input_channels=1):
        if model_depth == 14:
            self.enc_channels = [
                input_channels, 128, 128, 128, 128, 128, 128, 128
            ]
            self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2),
                                     (5, 2), (2, 2)]
            self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
                                (2, 1)]
            self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
                                 (0, 1), (0, 1)]
            self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1]
            self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2),
                                     (5, 2), (5, 2)]
            self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
                                (2, 1)]
            self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
                                 (0, 1), (0, 1)]
        elif model_depth == 10:
            self.enc_channels = [
                input_channels,
                16,
                32,
                64,
                128,
                256,
            ]
            self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
            self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
            self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
            self.dec_channels = [128, 128, 64, 32, 16, 1]
            self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
            self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
            self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
        elif model_depth == 20:
            self.enc_channels = [
                input_channels, model_complexity, model_complexity,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, 128
            ]
            self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
                                     (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
            self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1),
                                (2, 2), (2, 1), (2, 2), (2, 1)]
            self.enc_paddings = [
                (3, 0),
                (0, 3),
                None,  # (0, 2),
                None,
                None,  # (3, 1),
                None,  # (3, 1),
                None,  # (1, 2),
                None,
                None,
                None
            ]
            self.dec_channels = [
                0, model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2
            ]
            self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3),
                                     (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]
            self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2),
                                (2, 1), (2, 2), (1, 1), (1, 1)]
            self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1),
                                 (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
        else:
            raise ValueError('Unknown model depth : {}'.format(model_depth))
@@ -1,4 +1,5 @@
from .audio import LinearAECPipeline
from .audio.ans_pipeline import ANSPipeline
from .base import Pipeline
from .builder import pipeline
from .cv import *  # noqa F403
@@ -0,0 +1,117 @@
import os.path
from typing import Any, Dict

import librosa
import numpy as np
import soundfile as sf
import torch

from modelscope.metainfo import Pipelines
from modelscope.utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES


def audio_norm(x):
    # normalize the signal to -25 dB RMS
    rms = (x**2).mean()**0.5
    scalar = 10**(-25 / 20) / rms
    x = x * scalar
    # re-normalize using only the above-average-power samples, so long
    # stretches of silence do not bias the estimated level
    pow_x = x**2
    avg_pow_x = pow_x.mean()
    rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5
    scalarx = 10**(-25 / 20) / rmsx
    x = x * scalarx
    return x
@PIPELINES.register_module(
    Tasks.speech_signal_process,
    module_name=Pipelines.speech_frcrn_ans_cirm_16k)
class ANSPipeline(Pipeline):
    r"""ANS (Acoustic Noise Suppression) inference pipeline.

    When the class is invoked via pipeline.__call__(), it accepts only one
    parameter:
        inputs (str): the path of a wav file
    """
    SAMPLE_RATE = 16000

    def __init__(self, model):
        r"""
        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model)
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

    def preprocess(self, inputs: Input) -> Dict[str, Any]:
        assert isinstance(inputs, str) and os.path.exists(inputs) and os.path.isfile(inputs), \
            f'Input file does not exist: {inputs}'
        data1, fs = sf.read(inputs)
        data1 = audio_norm(data1)
        if fs != self.SAMPLE_RATE:
            data1 = librosa.resample(data1, fs, self.SAMPLE_RATE)
        if len(data1.shape) > 1:
            # keep only the first channel of multi-channel input
            data1 = data1[:, 0]
        data = data1.astype(np.float32)
        inputs = np.reshape(data, [1, data.shape[0]])
        return {'ndarray': inputs, 'nsamples': data.shape[0]}
    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        ndarray = inputs['ndarray']
        nsamples = inputs['nsamples']
        decode_do_segment = False
        window = 16000
        stride = int(window * 0.75)
        print('inputs:{}'.format(ndarray.shape))
        b, t = ndarray.shape
        # long signals are processed in overlapping windows
        if t > window * 120:
            decode_do_segment = True
        # pad the signal so the sliding window covers it completely
        if t < window:
            ndarray = np.concatenate(
                [ndarray, np.zeros((ndarray.shape[0], window - t))], 1)
        elif t < window + stride:
            padding = window + stride - t
            print('padding: {}'.format(padding))
            ndarray = np.concatenate(
                [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
        else:
            if (t - window) % stride != 0:
                padding = t - (t - window) // stride * stride
                print('padding: {}'.format(padding))
                ndarray = np.concatenate(
                    [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
        print('inputs after padding:{}'.format(ndarray.shape))
        with torch.no_grad():
            ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device)
            b, t = ndarray.shape
            if decode_do_segment:
                outputs = np.zeros(t)
                # discard half of the overlap on each side of a window
                give_up_length = (window - stride) // 2
                current_idx = 0
                while current_idx + window <= t:
                    print('current_idx: {}'.format(current_idx))
                    tmp_input = ndarray[:, current_idx:current_idx + window]
                    tmp_output = self.model(
                        tmp_input)['wav_l2'][0].cpu().numpy()
                    end_index = current_idx + window - give_up_length
                    if current_idx == 0:
                        outputs[current_idx:end_index] = \
                            tmp_output[:-give_up_length]
                    else:
                        outputs[current_idx + give_up_length:end_index] = \
                            tmp_output[give_up_length:-give_up_length]
                    current_idx += stride
            else:
                outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy()
        return {'output_pcm': outputs[:nsamples]}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        if 'output_path' in kwargs.keys():
            sf.write(kwargs['output_path'], inputs['output_pcm'],
                     self.SAMPLE_RATE)
        return inputs
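
For reference, a minimal invocation sketch mirroring the unittest further below (the model id comes from that test; the file names here are illustrative). The output_path keyword is forwarded to postprocess(), which writes the denoised wav:

from modelscope.metainfo import Pipelines
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ans = pipeline(
    Tasks.speech_signal_process,
    model='damo/speech_frcrn_ans_cirm_16k',
    pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
ans('speech_with_noise.wav', output_path='denoised.wav')
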
@@ -16,6 +16,7 @@ protobuf>3,<=3.20
ptflops
PyWavelets>=1.0.0
scikit-learn
SoundFile>0.10
sox
tensorboard
tensorflow==1.15.*
@@ -17,6 +17,9 @@ AEC_LIB_URL = 'http://isv-data.oss-cn-hangzhou.aliyuncs.com/ics%2FMaaS%2FAEC%2Fl
    '?Expires=1664085465&OSSAccessKeyId=LTAIxjQyZNde90zh&Signature=Y7gelmGEsQAJRK4yyHSYMrdWizk%3D'
AEC_LIB_FILE = 'libmitaec_pyio.so'
NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav'
NOISE_SPEECH_FILE = 'speech_with_noise.wav'


def download(remote_path, local_path):
    local_dir = os.path.dirname(local_path)

@@ -30,23 +33,40 @@ def download(remote_path, local_path):
class SpeechSignalProcessTest(unittest.TestCase):

    def setUp(self) -> None:
        pass

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_aec(self):
        # A temporary hack to provide the C++ lib. Download it first.
        download(AEC_LIB_URL, AEC_LIB_FILE)
        # Download audio files
        download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
        download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
        model_id = 'damo/speech_dfsmn_aec_psm_16k'
        input = {
            'nearend_mic': NEAREND_MIC_FILE,
            'farend_speech': FAREND_SPEECH_FILE
        }
        aec = pipeline(
            Tasks.speech_signal_process,
            model=model_id,
            pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k)
        output_path = os.path.abspath('output.wav')
        aec(input, output_path=output_path)
        print(f'Processed audio saved to {output_path}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_ans(self):
        # Download audio files
        download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
        model_id = 'damo/speech_frcrn_ans_cirm_16k'
        ans = pipeline(
            Tasks.speech_signal_process,
            model=model_id,
            pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
        output_path = os.path.abspath('output.wav')
        ans(NOISE_SPEECH_FILE, output_path=output_path)
        print(f'Processed audio saved to {output_path}')


if __name__ == '__main__':
    unittest.main()