From 04b7eba285dae7026a08b0136ccea7ba31319f6b Mon Sep 17 00:00:00 2001 From: "bin.xue" Date: Tue, 28 Jun 2022 14:41:08 +0800 Subject: [PATCH] [to #42322933] Merge ANS pipeline into master Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9178339 * refactor: move aec models to audio/aec * refactor: move aec models to audio/aec * refactor: move aec models to audio/aec * refactor: move aec models to audio/aec * feat: add unittest for ANS pipeline * Merge branch 'master' into dev/ans * add new SoundFile to audio dependency * Merge branch 'master' into dev/ans * use ANS pipeline name from metainfo * Merge branch 'master' into dev/ans * chore: update docstring of ANS module * Merge branch 'master' into dev/ans * refactor: use names from metainfo * refactor: enable ans unittest * refactor: add more log message in unittest --- modelscope/metainfo.py | 2 + modelscope/models/__init__.py | 1 + .../models/audio/{layers => aec}/__init__.py | 0 .../audio/{network => aec/layers}/__init__.py | 0 .../audio/{ => aec}/layers/activations.py | 0 .../{ => aec}/layers/affine_transform.py | 0 .../audio/{ => aec}/layers/deep_fsmn.py | 0 .../audio/{ => aec}/layers/layer_base.py | 0 .../audio/{ => aec}/layers/uni_deep_fsmn.py | 0 .../models/audio/aec/network/__init__.py | 0 .../models/audio/{ => aec}/network/loss.py | 0 .../{ => aec}/network/modulation_loss.py | 0 .../models/audio/{ => aec}/network/se_net.py | 0 modelscope/models/audio/ans/__init__.py | 0 modelscope/models/audio/ans/complex_nn.py | 248 ++++++++++++++ modelscope/models/audio/ans/conv_stft.py | 112 +++++++ modelscope/models/audio/ans/frcrn.py | 309 ++++++++++++++++++ .../models/audio/ans/se_module_complex.py | 26 ++ modelscope/models/audio/ans/unet.py | 269 +++++++++++++++ modelscope/pipelines/__init__.py | 1 + modelscope/pipelines/audio/ans_pipeline.py | 117 +++++++ requirements/audio.txt | 1 + tests/pipelines/test_speech_signal_process.py | 32 +- 23 files changed, 1112 insertions(+), 6 deletions(-) rename modelscope/models/audio/{layers => aec}/__init__.py (100%) rename modelscope/models/audio/{network => aec/layers}/__init__.py (100%) rename modelscope/models/audio/{ => aec}/layers/activations.py (100%) rename modelscope/models/audio/{ => aec}/layers/affine_transform.py (100%) rename modelscope/models/audio/{ => aec}/layers/deep_fsmn.py (100%) rename modelscope/models/audio/{ => aec}/layers/layer_base.py (100%) rename modelscope/models/audio/{ => aec}/layers/uni_deep_fsmn.py (100%) create mode 100644 modelscope/models/audio/aec/network/__init__.py rename modelscope/models/audio/{ => aec}/network/loss.py (100%) rename modelscope/models/audio/{ => aec}/network/modulation_loss.py (100%) rename modelscope/models/audio/{ => aec}/network/se_net.py (100%) create mode 100644 modelscope/models/audio/ans/__init__.py create mode 100644 modelscope/models/audio/ans/complex_nn.py create mode 100644 modelscope/models/audio/ans/conv_stft.py create mode 100644 modelscope/models/audio/ans/frcrn.py create mode 100644 modelscope/models/audio/ans/se_module_complex.py create mode 100644 modelscope/models/audio/ans/unet.py create mode 100644 modelscope/pipelines/audio/ans_pipeline.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 9fad45e2..eda590ac 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -21,6 +21,7 @@ class Models(object): sambert_hifi_16k = 'sambert-hifi-16k' generic_tts_frontend = 'generic-tts-frontend' hifigan16k = 'hifigan16k' + speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' kws_kwsbp = 
'kws-kwsbp' # multi-modal models @@ -55,6 +56,7 @@ class Pipelines(object): # audio tasks sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' + speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' kws_kwsbp = 'kws-kwsbp' # multi-modal tasks diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py index ebf81c32..816c44e2 100644 --- a/modelscope/models/__init__.py +++ b/modelscope/models/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .audio.ans.frcrn import FRCRNModel from .audio.kws import GenericKeyWordSpotting from .audio.tts.am import SambertNetHifi16k from .audio.tts.vocoder import Hifigan16k diff --git a/modelscope/models/audio/layers/__init__.py b/modelscope/models/audio/aec/__init__.py similarity index 100% rename from modelscope/models/audio/layers/__init__.py rename to modelscope/models/audio/aec/__init__.py diff --git a/modelscope/models/audio/network/__init__.py b/modelscope/models/audio/aec/layers/__init__.py similarity index 100% rename from modelscope/models/audio/network/__init__.py rename to modelscope/models/audio/aec/layers/__init__.py diff --git a/modelscope/models/audio/layers/activations.py b/modelscope/models/audio/aec/layers/activations.py similarity index 100% rename from modelscope/models/audio/layers/activations.py rename to modelscope/models/audio/aec/layers/activations.py diff --git a/modelscope/models/audio/layers/affine_transform.py b/modelscope/models/audio/aec/layers/affine_transform.py similarity index 100% rename from modelscope/models/audio/layers/affine_transform.py rename to modelscope/models/audio/aec/layers/affine_transform.py diff --git a/modelscope/models/audio/layers/deep_fsmn.py b/modelscope/models/audio/aec/layers/deep_fsmn.py similarity index 100% rename from modelscope/models/audio/layers/deep_fsmn.py rename to modelscope/models/audio/aec/layers/deep_fsmn.py diff --git a/modelscope/models/audio/layers/layer_base.py b/modelscope/models/audio/aec/layers/layer_base.py similarity index 100% rename from modelscope/models/audio/layers/layer_base.py rename to modelscope/models/audio/aec/layers/layer_base.py diff --git a/modelscope/models/audio/layers/uni_deep_fsmn.py b/modelscope/models/audio/aec/layers/uni_deep_fsmn.py similarity index 100% rename from modelscope/models/audio/layers/uni_deep_fsmn.py rename to modelscope/models/audio/aec/layers/uni_deep_fsmn.py diff --git a/modelscope/models/audio/aec/network/__init__.py b/modelscope/models/audio/aec/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/network/loss.py b/modelscope/models/audio/aec/network/loss.py similarity index 100% rename from modelscope/models/audio/network/loss.py rename to modelscope/models/audio/aec/network/loss.py diff --git a/modelscope/models/audio/network/modulation_loss.py b/modelscope/models/audio/aec/network/modulation_loss.py similarity index 100% rename from modelscope/models/audio/network/modulation_loss.py rename to modelscope/models/audio/aec/network/modulation_loss.py diff --git a/modelscope/models/audio/network/se_net.py b/modelscope/models/audio/aec/network/se_net.py similarity index 100% rename from modelscope/models/audio/network/se_net.py rename to modelscope/models/audio/aec/network/se_net.py diff --git a/modelscope/models/audio/ans/__init__.py b/modelscope/models/audio/ans/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/ans/complex_nn.py 
b/modelscope/models/audio/ans/complex_nn.py new file mode 100644 index 00000000..69dec41e --- /dev/null +++ b/modelscope/models/audio/ans/complex_nn.py @@ -0,0 +1,248 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UniDeepFsmn(nn.Module): + + def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None): + super(UniDeepFsmn, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if lorder is None: + return + + self.lorder = lorder + self.hidden_size = hidden_size + + self.linear = nn.Linear(input_dim, hidden_size) + + self.project = nn.Linear(hidden_size, output_dim, bias=False) + + self.conv1 = nn.Conv2d( + output_dim, + output_dim, [lorder, 1], [1, 1], + groups=output_dim, + bias=False) + + def forward(self, input): + r""" + + Args: + input: torch with shape: batch (b) x sequence(T) x feature (h) + + Returns: + batch (b) x channel (c) x sequence(T) x feature (h) + """ + f1 = F.relu(self.linear(input)) + + p1 = self.project(f1) + + x = torch.unsqueeze(p1, 1) + # x: batch (b) x channel (c) x sequence(T) x feature (h) + x_per = x.permute(0, 3, 2, 1) + # x_per: batch (b) x feature (h) x sequence(T) x channel (c) + y = F.pad(x_per, [0, 0, self.lorder - 1, 0]) + + out = x_per + self.conv1(y) + + out1 = out.permute(0, 3, 2, 1) + # out1: batch (b) x channel (c) x sequence(T) x feature (h) + return input + out1.squeeze() + + +class ComplexUniDeepFsmn(nn.Module): + + def __init__(self, nIn, nHidden=128, nOut=128): + super(ComplexUniDeepFsmn, self).__init__() + + self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + self.fsmn_re_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden) + self.fsmn_im_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden) + + def forward(self, x): + r""" + + Args: + x: torch with shape [batch, channel, feature, sequence, 2], eg: [6, 256, 1, 106, 2] + + Returns: + [batch, feature, sequence, 2], eg: [6, 99, 1024, 2] + """ + # + b, c, h, T, d = x.size() + x = torch.reshape(x, (b, c * h, T, d)) + # x: [b,h,T,2], [6, 256, 106, 2] + x = torch.transpose(x, 1, 2) + # x: [b,T,h,2], [6, 106, 256, 2] + + real_L1 = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1]) + imaginary_L1 = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0]) + # GRU output: [99, 6, 128] + real = self.fsmn_re_L2(real_L1) - self.fsmn_im_L2(imaginary_L1) + imaginary = self.fsmn_re_L2(imaginary_L1) + self.fsmn_im_L2(real_L1) + # output: [b,T,h,2], [99, 6, 1024, 2] + output = torch.stack((real, imaginary), dim=-1) + + # output: [b,h,T,2], [6, 99, 1024, 2] + output = torch.transpose(output, 1, 2) + output = torch.reshape(output, (b, c, h, T, d)) + + return output + + +class ComplexUniDeepFsmn_L1(nn.Module): + + def __init__(self, nIn, nHidden=128, nOut=128): + super(ComplexUniDeepFsmn_L1, self).__init__() + self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden) + + def forward(self, x): + r""" + + Args: + x: torch with shape [batch, channel, feature, sequence, 2], eg: [6, 256, 1, 106, 2] + """ + b, c, h, T, d = x.size() + # x : [b,T,h,c,2] + x = torch.transpose(x, 1, 3) + x = torch.reshape(x, (b * T, h, c, d)) + + real = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1]) + imaginary = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0]) + # output: [b*T,h,c,2], [6*106, h, 256, 2] + output = torch.stack((real, imaginary), dim=-1) + + output = torch.reshape(output, (b, T, h, c, d)) + output = torch.transpose(output, 1, 
3) + return output + + +class ComplexConv2d(nn.Module): + # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + **kwargs): + super().__init__() + + # Model components + self.conv_re = nn.Conv2d( + in_channel, + out_channel, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs) + self.conv_im = nn.Conv2d( + in_channel, + out_channel, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs) + + def forward(self, x): + r""" + + Args: + x: torch with shape: [batch,channel,axis1,axis2,2] + """ + real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1]) + imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0]) + output = torch.stack((real, imaginary), dim=-1) + return output + + +class ComplexConvTranspose2d(nn.Module): + + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + output_padding=0, + dilation=1, + groups=1, + bias=True, + **kwargs): + super().__init__() + + # Model components + self.tconv_re = nn.ConvTranspose2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + **kwargs) + self.tconv_im = nn.ConvTranspose2d( + in_channel, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + **kwargs) + + def forward(self, x): # shpae of x : [batch,channel,axis1,axis2,2] + real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1]) + imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0]) + output = torch.stack((real, imaginary), dim=-1) + return output + + +class ComplexBatchNorm2d(nn.Module): + + def __init__(self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + **kwargs): + super().__init__() + self.bn_re = nn.BatchNorm2d( + num_features=num_features, + momentum=momentum, + affine=affine, + eps=eps, + track_running_stats=track_running_stats, + **kwargs) + self.bn_im = nn.BatchNorm2d( + num_features=num_features, + momentum=momentum, + affine=affine, + eps=eps, + track_running_stats=track_running_stats, + **kwargs) + + def forward(self, x): + real = self.bn_re(x[..., 0]) + imag = self.bn_im(x[..., 1]) + output = torch.stack((real, imag), dim=-1) + return output diff --git a/modelscope/models/audio/ans/conv_stft.py b/modelscope/models/audio/ans/conv_stft.py new file mode 100644 index 00000000..a47d7817 --- /dev/null +++ b/modelscope/models/audio/ans/conv_stft.py @@ -0,0 +1,112 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.signal import get_window + + +def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): + if win_type == 'None' or win_type is None: + window = np.ones(win_len) + else: + window = get_window(win_type, win_len, fftbins=True)**0.5 + + N = fft_len + fourier_basis = np.fft.rfft(np.eye(N))[:win_len] + real_kernel = np.real(fourier_basis) + imag_kernel = np.imag(fourier_basis) + kernel = np.concatenate([real_kernel, imag_kernel], 1).T + + if invers: + kernel = np.linalg.pinv(kernel).T + + kernel = kernel * window + kernel = kernel[:, None, :] + return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy( + 
window[None, :, None].astype(np.float32)) + + +class ConvSTFT(nn.Module): + + def __init__(self, + win_len, + win_inc, + fft_len=None, + win_type='hamming', + feature_type='real', + fix=True): + super(ConvSTFT, self).__init__() + + if fft_len is None: + self.fft_len = np.int(2**np.ceil(np.log2(win_len))) + else: + self.fft_len = fft_len + + kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) + self.weight = nn.Parameter(kernel, requires_grad=(not fix)) + self.feature_type = feature_type + self.stride = win_inc + self.win_len = win_len + self.dim = self.fft_len + + def forward(self, inputs): + if inputs.dim() == 2: + inputs = torch.unsqueeze(inputs, 1) + + outputs = F.conv1d(inputs, self.weight, stride=self.stride) + + if self.feature_type == 'complex': + return outputs + else: + dim = self.dim // 2 + 1 + real = outputs[:, :dim, :] + imag = outputs[:, dim:, :] + mags = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag, real) + return mags, phase + + +class ConviSTFT(nn.Module): + + def __init__(self, + win_len, + win_inc, + fft_len=None, + win_type='hamming', + feature_type='real', + fix=True): + super(ConviSTFT, self).__init__() + if fft_len is None: + self.fft_len = np.int(2**np.ceil(np.log2(win_len))) + else: + self.fft_len = fft_len + kernel, window = init_kernels( + win_len, win_inc, self.fft_len, win_type, invers=True) + self.weight = nn.Parameter(kernel, requires_grad=(not fix)) + self.feature_type = feature_type + self.win_type = win_type + self.win_len = win_len + self.win_inc = win_inc + self.stride = win_inc + self.dim = self.fft_len + self.register_buffer('window', window) + self.register_buffer('enframe', torch.eye(win_len)[:, None, :]) + + def forward(self, inputs, phase=None): + """ + Args: + inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) + phase: [B, N//2+1, T] (if not none) + """ + + if phase is not None: + real = inputs * torch.cos(phase) + imag = inputs * torch.sin(phase) + inputs = torch.cat([real, imag], 1) + outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) + + # this is from torch-stft: https://github.com/pseeth/torch-stft + t = self.window.repeat(1, 1, inputs.size(-1))**2 + coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) + outputs = outputs / (coff + 1e-8) + return outputs diff --git a/modelscope/models/audio/ans/frcrn.py b/modelscope/models/audio/ans/frcrn.py new file mode 100644 index 00000000..c56b8773 --- /dev/null +++ b/modelscope/models/audio/ans/frcrn.py @@ -0,0 +1,309 @@ +import os +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from ...base import Model, Tensor +from .conv_stft import ConviSTFT, ConvSTFT +from .unet import UNet + + +class FTB(nn.Module): + + def __init__(self, input_dim=257, in_channel=9, r_channel=5): + + super(FTB, self).__init__() + self.in_channel = in_channel + self.conv1 = nn.Sequential( + nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]), + nn.BatchNorm2d(r_channel), nn.ReLU()) + + self.conv1d = nn.Sequential( + nn.Conv1d( + r_channel * input_dim, in_channel, kernel_size=9, padding=4), + nn.BatchNorm1d(in_channel), nn.ReLU()) + self.freq_fc = nn.Linear(input_dim, input_dim, bias=False) + + self.conv2 = nn.Sequential( + nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]), + nn.BatchNorm2d(in_channel), nn.ReLU()) + + def forward(self, inputs): + ''' + inputs should be 
[Batch, Ca, Dim, Time] + ''' + # T-F attention + conv1_out = self.conv1(inputs) + B, C, D, T = conv1_out.size() + reshape1_out = torch.reshape(conv1_out, [B, C * D, T]) + conv1d_out = self.conv1d(reshape1_out) + conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T]) + + # now is also [B,C,D,T] + att_out = conv1d_out * inputs + + # tranpose to [B,C,T,D] + att_out = torch.transpose(att_out, 2, 3) + freqfc_out = self.freq_fc(att_out) + att_out = torch.transpose(freqfc_out, 2, 3) + + cat_out = torch.cat([att_out, inputs], 1) + outputs = self.conv2(cat_out) + return outputs + + +@MODELS.register_module( + Tasks.speech_signal_process, module_name=Models.speech_frcrn_ans_cirm_16k) +class FRCRNModel(Model): + r""" A decorator of FRCRN for integrating into modelscope framework """ + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the frcrn model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + self._model = FRCRN(*args, **kwargs) + model_bin_file = os.path.join(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + if os.path.exists(model_bin_file): + checkpoint = torch.load(model_bin_file) + self._model.load_state_dict(checkpoint, strict=False) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + output = self._model.forward(input) + return { + 'spec_l1': output[0], + 'wav_l1': output[1], + 'mask_l1': output[2], + 'spec_l2': output[3], + 'wav_l2': output[4], + 'mask_l2': output[5] + } + + def to(self, *args, **kwargs): + self._model = self._model.to(*args, **kwargs) + return self + + def eval(self): + self._model = self._model.train(False) + return self + + +class FRCRN(nn.Module): + r""" Frequency Recurrent CRN """ + + def __init__(self, + complex, + model_complexity, + model_depth, + log_amp, + padding_mode, + win_len=400, + win_inc=100, + fft_len=512, + win_type='hanning'): + r""" + Args: + complex: Whether to use complex networks. + model_complexity: define the model complexity with the number of layers + model_depth: Only two options are available : 10, 20 + log_amp: Whether to use log amplitude to estimate signals + padding_mode: Encoder's convolution filter. 'zeros', 'reflect' + win_len: length of window used for defining one frame of sample points + win_inc: length of window shifting (equivalent to hop_size) + fft_len: number of Short Time Fourier Transform (STFT) points + win_type: windowing type used in STFT, eg. 
'hanning', 'hamming' + """ + super().__init__() + self.feat_dim = fft_len // 2 + 1 + + self.win_len = win_len + self.win_inc = win_inc + self.fft_len = fft_len + self.win_type = win_type + + fix = True + self.stft = ConvSTFT( + self.win_len, + self.win_inc, + self.fft_len, + self.win_type, + feature_type='complex', + fix=fix) + self.istft = ConviSTFT( + self.win_len, + self.win_inc, + self.fft_len, + self.win_type, + feature_type='complex', + fix=fix) + self.unet = UNet( + 1, + complex=complex, + model_complexity=model_complexity, + model_depth=model_depth, + padding_mode=padding_mode) + self.unet2 = UNet( + 1, + complex=complex, + model_complexity=model_complexity, + model_depth=model_depth, + padding_mode=padding_mode) + + def forward(self, inputs): + out_list = [] + # [B, D*2, T] + cmp_spec = self.stft(inputs) + # [B, 1, D*2, T] + cmp_spec = torch.unsqueeze(cmp_spec, 1) + + # to [B, 2, D, T] real_part/imag_part + cmp_spec = torch.cat([ + cmp_spec[:, :, :self.feat_dim, :], + cmp_spec[:, :, self.feat_dim:, :], + ], 1) + + # [B, 2, D, T] + cmp_spec = torch.unsqueeze(cmp_spec, 4) + # [B, 1, D, T, 2] + cmp_spec = torch.transpose(cmp_spec, 1, 4) + unet1_out = self.unet(cmp_spec) + cmp_mask1 = torch.tanh(unet1_out) + unet2_out = self.unet2(unet1_out) + cmp_mask2 = torch.tanh(unet2_out) + est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1) + out_list.append(est_spec) + out_list.append(est_wav) + out_list.append(est_mask) + cmp_mask2 = cmp_mask2 + cmp_mask1 + est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2) + out_list.append(est_spec) + out_list.append(est_wav) + out_list.append(est_mask) + return out_list + + def apply_mask(self, cmp_spec, cmp_mask): + est_spec = torch.cat([ + cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 0] + - cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 1], + cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 1] + + cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 0] + ], 1) + est_spec = torch.cat([est_spec[:, 0, :, :], est_spec[:, 1, :, :]], 1) + cmp_mask = torch.squeeze(cmp_mask, 1) + cmp_mask = torch.cat([cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], 1) + + est_wav = self.istft(est_spec) + est_wav = torch.squeeze(est_wav, 1) + return est_spec, est_wav, cmp_mask + + def get_params(self, weight_decay=0.0): + # add L2 penalty + weights, biases = [], [] + for name, param in self.named_parameters(): + if 'bias' in name: + biases += [param] + else: + weights += [param] + params = [{ + 'params': weights, + 'weight_decay': weight_decay, + }, { + 'params': biases, + 'weight_decay': 0.0, + }] + return params + + def loss(self, noisy, labels, out_list, mode='Mix'): + if mode == 'SiSNR': + count = 0 + while count < len(out_list): + est_spec = out_list[count] + count = count + 1 + est_wav = out_list[count] + count = count + 1 + est_mask = out_list[count] + count = count + 1 + if count != 3: + loss = self.loss_1layer(noisy, est_spec, est_wav, labels, + est_mask, mode) + return loss + + elif mode == 'Mix': + count = 0 + while count < len(out_list): + est_spec = out_list[count] + count = count + 1 + est_wav = out_list[count] + count = count + 1 + est_mask = out_list[count] + count = count + 1 + if count != 3: + amp_loss, phase_loss, SiSNR_loss = self.loss_1layer( + noisy, est_spec, est_wav, labels, est_mask, mode) + loss = amp_loss + phase_loss + SiSNR_loss + return loss, amp_loss, phase_loss + + def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'): + r""" Compute the loss by mode + mode == 'Mix' + est: [B, F*2, T] + labels: [B, F*2,T] + 
mode == 'SiSNR' + est: [B, T] + labels: [B, T] + """ + if mode == 'SiSNR': + if labels.dim() == 3: + labels = torch.squeeze(labels, 1) + if est_wav.dim() == 3: + est_wav = torch.squeeze(est_wav, 1) + return -si_snr(est_wav, labels) + elif mode == 'Mix': + + if labels.dim() == 3: + labels = torch.squeeze(labels, 1) + if est_wav.dim() == 3: + est_wav = torch.squeeze(est_wav, 1) + SiSNR_loss = -si_snr(est_wav, labels) + + b, d, t = est.size() + S = self.stft(labels) + Sr = S[:, :self.feat_dim, :] + Si = S[:, self.feat_dim:, :] + Y = self.stft(noisy) + Yr = Y[:, :self.feat_dim, :] + Yi = Y[:, self.feat_dim:, :] + Y_pow = Yr**2 + Yi**2 + gth_mask = torch.cat([(Sr * Yr + Si * Yi) / (Y_pow + 1e-8), + (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)], 1) + gth_mask[gth_mask > 2] = 1 + gth_mask[gth_mask < -2] = -1 + amp_loss = F.mse_loss(gth_mask[:, :self.feat_dim, :], + cmp_mask[:, :self.feat_dim, :]) * d + phase_loss = F.mse_loss(gth_mask[:, self.feat_dim:, :], + cmp_mask[:, self.feat_dim:, :]) * d + return amp_loss, phase_loss, SiSNR_loss + + +def l2_norm(s1, s2): + norm = torch.sum(s1 * s2, -1, keepdim=True) + return norm + + +def si_snr(s1, s2, eps=1e-8): + s1_s2_norm = l2_norm(s1, s2) + s2_s2_norm = l2_norm(s2, s2) + s_target = s1_s2_norm / (s2_s2_norm + eps) * s2 + e_nosie = s1 - s_target + target_norm = l2_norm(s_target, s_target) + noise_norm = l2_norm(e_nosie, e_nosie) + snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps) + return torch.mean(snr) diff --git a/modelscope/models/audio/ans/se_module_complex.py b/modelscope/models/audio/ans/se_module_complex.py new file mode 100644 index 00000000..f62fe523 --- /dev/null +++ b/modelscope/models/audio/ans/se_module_complex.py @@ -0,0 +1,26 @@ +import torch +from torch import nn + + +class SELayer(nn.Module): + + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc_r = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + self.fc_i = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + + def forward(self, x): + b, c, _, _, _ = x.size() + x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c) + x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c) + y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view( + b, c, 1, 1, 1) + y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view( + b, c, 1, 1, 1) + y = torch.cat([y_r, y_i], 4) + return x * y diff --git a/modelscope/models/audio/ans/unet.py b/modelscope/models/audio/ans/unet.py new file mode 100644 index 00000000..aa5a4254 --- /dev/null +++ b/modelscope/models/audio/ans/unet.py @@ -0,0 +1,269 @@ +import torch +import torch.nn as nn + +from . 
import complex_nn +from .se_module_complex import SELayer + + +class Encoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding=None, + complex=False, + padding_mode='zeros'): + super().__init__() + if padding is None: + padding = [(i - 1) // 2 for i in kernel_size] # 'SAME' padding + + if complex: + conv = complex_nn.ComplexConv2d + bn = complex_nn.ComplexBatchNorm2d + else: + conv = nn.Conv2d + bn = nn.BatchNorm2d + + self.conv = conv( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + padding_mode=padding_mode) + self.bn = bn(out_channels) + self.relu = nn.LeakyReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Decoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding=(0, 0), + complex=False): + super().__init__() + if complex: + tconv = complex_nn.ComplexConvTranspose2d + bn = complex_nn.ComplexBatchNorm2d + else: + tconv = nn.ConvTranspose2d + bn = nn.BatchNorm2d + + self.transconv = tconv( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + self.bn = bn(out_channels) + self.relu = nn.LeakyReLU(inplace=True) + + def forward(self, x): + x = self.transconv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class UNet(nn.Module): + + def __init__(self, + input_channels=1, + complex=False, + model_complexity=45, + model_depth=20, + padding_mode='zeros'): + super().__init__() + + if complex: + model_complexity = int(model_complexity // 1.414) + + self.set_size( + model_complexity=model_complexity, + input_channels=input_channels, + model_depth=model_depth) + self.encoders = [] + self.model_length = model_depth // 2 + self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128) + self.se_layers_enc = [] + self.fsmn_enc = [] + for i in range(self.model_length): + fsmn_enc = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128) + self.add_module('fsmn_enc{}'.format(i), fsmn_enc) + self.fsmn_enc.append(fsmn_enc) + module = Encoder( + self.enc_channels[i], + self.enc_channels[i + 1], + kernel_size=self.enc_kernel_sizes[i], + stride=self.enc_strides[i], + padding=self.enc_paddings[i], + complex=complex, + padding_mode=padding_mode) + self.add_module('encoder{}'.format(i), module) + self.encoders.append(module) + se_layer_enc = SELayer(self.enc_channels[i + 1], 8) + self.add_module('se_layer_enc{}'.format(i), se_layer_enc) + self.se_layers_enc.append(se_layer_enc) + self.decoders = [] + self.fsmn_dec = [] + self.se_layers_dec = [] + for i in range(self.model_length): + fsmn_dec = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128) + self.add_module('fsmn_dec{}'.format(i), fsmn_dec) + self.fsmn_dec.append(fsmn_dec) + module = Decoder( + self.dec_channels[i] * 2, + self.dec_channels[i + 1], + kernel_size=self.dec_kernel_sizes[i], + stride=self.dec_strides[i], + padding=self.dec_paddings[i], + complex=complex) + self.add_module('decoder{}'.format(i), module) + self.decoders.append(module) + if i < self.model_length - 1: + se_layer_dec = SELayer(self.dec_channels[i + 1], 8) + self.add_module('se_layer_dec{}'.format(i), se_layer_dec) + self.se_layers_dec.append(se_layer_dec) + if complex: + conv = complex_nn.ComplexConv2d + else: + conv = nn.Conv2d + + linear = conv(self.dec_channels[-1], 1, 1) + + self.add_module('linear', linear) + self.complex = complex + self.padding_mode = padding_mode + + self.decoders = nn.ModuleList(self.decoders) + self.encoders 
= nn.ModuleList(self.encoders) + self.se_layers_enc = nn.ModuleList(self.se_layers_enc) + self.se_layers_dec = nn.ModuleList(self.se_layers_dec) + self.fsmn_enc = nn.ModuleList(self.fsmn_enc) + self.fsmn_dec = nn.ModuleList(self.fsmn_dec) + + def forward(self, inputs): + x = inputs + # go down + xs = [] + xs_se = [] + xs_se.append(x) + for i, encoder in enumerate(self.encoders): + xs.append(x) + if i > 0: + x = self.fsmn_enc[i](x) + x = encoder(x) + xs_se.append(self.se_layers_enc[i](x)) + # xs : x0=input x1 ... x9 + x = self.fsmn(x) + + p = x + for i, decoder in enumerate(self.decoders): + p = decoder(p) + if i < self.model_length - 1: + p = self.fsmn_dec[i](p) + if i == self.model_length - 1: + break + if i < self.model_length - 2: + p = self.se_layers_dec[i](p) + p = torch.cat([p, xs_se[self.model_length - 1 - i]], dim=1) + + # cmp_spec: [12, 1, 513, 64, 2] + cmp_spec = self.linear(p) + return cmp_spec + + def set_size(self, model_complexity, model_depth=20, input_channels=1): + + if model_depth == 14: + self.enc_channels = [ + input_channels, 128, 128, 128, 128, 128, 128, 128 + ] + self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), + (5, 2), (2, 2)] + self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), + (2, 1)] + self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), + (0, 1), (0, 1)] + self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1] + self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2), + (5, 2), (5, 2)] + self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), + (2, 1)] + self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), + (0, 1), (0, 1)] + + elif model_depth == 10: + self.enc_channels = [ + input_channels, + 16, + 32, + 64, + 128, + 256, + ] + self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)] + self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + self.dec_channels = [128, 128, 64, 32, 16, 1] + self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)] + self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + + elif model_depth == 20: + self.enc_channels = [ + input_channels, model_complexity, model_complexity, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, 128 + ] + + self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3), + (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)] + + self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1), + (2, 2), (2, 1), (2, 2), (2, 1)] + + self.enc_paddings = [ + (3, 0), + (0, 3), + None, # (0, 2), + None, + None, # (3,1), + None, # (3,1), + None, # (1,2), + None, + None, + None + ] + + self.dec_channels = [ + 0, model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2 + ] + + self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3), + (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)] + + self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2), + (2, 1), (2, 2), (1, 1), (1, 1)] + + self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1), + (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)] + else: + raise ValueError('Unknown model depth : {}'.format(model_depth)) diff --git 
a/modelscope/pipelines/__init__.py b/modelscope/pipelines/__init__.py index 14865872..74f5507f 100644 --- a/modelscope/pipelines/__init__.py +++ b/modelscope/pipelines/__init__.py @@ -1,4 +1,5 @@ from .audio import LinearAECPipeline +from .audio.ans_pipeline import ANSPipeline from .base import Pipeline from .builder import pipeline from .cv import * # noqa F403 diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py new file mode 100644 index 00000000..d9a04a29 --- /dev/null +++ b/modelscope/pipelines/audio/ans_pipeline.py @@ -0,0 +1,117 @@ +import os.path +from typing import Any, Dict + +import librosa +import numpy as np +import soundfile as sf +import torch + +from modelscope.metainfo import Pipelines +from modelscope.utils.constant import Tasks +from ..base import Input, Pipeline +from ..builder import PIPELINES + + +def audio_norm(x): + rms = (x**2).mean()**0.5 + scalar = 10**(-25 / 20) / rms + x = x * scalar + pow_x = x**2 + avg_pow_x = pow_x.mean() + rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5 + scalarx = 10**(-25 / 20) / rmsx + x = x * scalarx + return x + + +@PIPELINES.register_module( + Tasks.speech_signal_process, + module_name=Pipelines.speech_frcrn_ans_cirm_16k) +class ANSPipeline(Pipeline): + r"""ANS (Acoustic Noise Suppression) Inference Pipeline. + + When invoking the class with pipeline.__call__(), it accepts only one parameter: + inputs (str): the path of the input wav file + """ + SAMPLE_RATE = 16000 + + def __init__(self, model): + r""" + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + self.model = self.model.to(self.device) + self.model.eval() + + def preprocess(self, inputs: Input) -> Dict[str, Any]: + assert isinstance(inputs, str) and os.path.exists(inputs) and os.path.isfile(inputs), \ + f'Input file does not exist: {inputs}' + data1, fs = sf.read(inputs) + data1 = audio_norm(data1) + if fs != self.SAMPLE_RATE: + data1 = librosa.resample(data1, fs, self.SAMPLE_RATE) + if len(data1.shape) > 1: + data1 = data1[:, 0] + data = data1.astype(np.float32) + inputs = np.reshape(data, [1, data.shape[0]]) + return {'ndarray': inputs, 'nsamples': data.shape[0]} + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + ndarray = inputs['ndarray'] + nsamples = inputs['nsamples'] + decode_do_segement = False + window = 16000 + stride = int(window * 0.75) + print('inputs:{}'.format(ndarray.shape)) + b, t = ndarray.shape # size() + if t > window * 120: + decode_do_segement = True + + if t < window: + ndarray = np.concatenate( + [ndarray, np.zeros((ndarray.shape[0], window - t))], 1) + elif t < window + stride: + padding = window + stride - t + print('padding: {}'.format(padding)) + ndarray = np.concatenate( + [ndarray, np.zeros((ndarray.shape[0], padding))], 1) + else: + if (t - window) % stride != 0: + padding = t - (t - window) // stride * stride + print('padding: {}'.format(padding)) + ndarray = np.concatenate( + [ndarray, np.zeros((ndarray.shape[0], padding))], 1) + print('inputs after padding:{}'.format(ndarray.shape)) + with torch.no_grad(): + ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device) + b, t = ndarray.shape + if decode_do_segement: + outputs = np.zeros(t) + give_up_length = (window - stride) // 2 + current_idx = 0 + while current_idx + window <= t: + print('current_idx: {}'.format(current_idx)) + tmp_input = ndarray[:, current_idx:current_idx + window] + tmp_output = self.model( + tmp_input,
)['wav_l2'][0].cpu().numpy() + end_index = current_idx + window - give_up_length + if current_idx == 0: + outputs[current_idx: + end_index] = tmp_output[:-give_up_length] + else: + outputs[current_idx + + give_up_length:end_index] = tmp_output[ + give_up_length:-give_up_length] + current_idx += stride + else: + outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy() + return {'output_pcm': outputs[:nsamples]} + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + if 'output_path' in kwargs.keys(): + sf.write(kwargs['output_path'], inputs['output_pcm'], + self.SAMPLE_RATE) + return inputs diff --git a/requirements/audio.txt b/requirements/audio.txt index c7b2b239..1f5984ca 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -16,6 +16,7 @@ protobuf>3,<=3.20 ptflops PyWavelets>=1.0.0 scikit-learn +SoundFile>0.10 sox tensorboard tensorflow==1.15.* diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py index bc3a542e..f317bc07 100644 --- a/tests/pipelines/test_speech_signal_process.py +++ b/tests/pipelines/test_speech_signal_process.py @@ -17,6 +17,9 @@ AEC_LIB_URL = 'http://isv-data.oss-cn-hangzhou.aliyuncs.com/ics%2FMaaS%2FAEC%2Fl '?Expires=1664085465&OSSAccessKeyId=LTAIxjQyZNde90zh&Signature=Y7gelmGEsQAJRK4yyHSYMrdWizk%3D' AEC_LIB_FILE = 'libmitaec_pyio.so' +NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav' +NOISE_SPEECH_FILE = 'speech_with_noise.wav' + def download(remote_path, local_path): local_dir = os.path.dirname(local_path) @@ -30,23 +33,40 @@ def download(remote_path, local_path): class SpeechSignalProcessTest(unittest.TestCase): def setUp(self) -> None: - self.model_id = 'damo/speech_dfsmn_aec_psm_16k' + pass + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_aec(self): # A temporary hack to provide c++ lib. Download it first. download(AEC_LIB_URL, AEC_LIB_FILE) - - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_run(self): + # Download audio files download(NEAREND_MIC_URL, NEAREND_MIC_FILE) download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE) + model_id = 'damo/speech_dfsmn_aec_psm_16k' input = { 'nearend_mic': NEAREND_MIC_FILE, 'farend_speech': FAREND_SPEECH_FILE } aec = pipeline( Tasks.speech_signal_process, - model=self.model_id, + model=model_id, pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k) - aec(input, output_path='output.wav') + output_path = os.path.abspath('output.wav') + aec(input, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ans(self): + # Download audio files + download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE) + model_id = 'damo/speech_frcrn_ans_cirm_16k' + ans = pipeline( + Tasks.speech_signal_process, + model=model_id, + pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k) + output_path = os.path.abspath('output.wav') + ans(NOISE_SPEECH_FILE, output_path=output_path) + print(f'Processed audio saved to {output_path}') if __name__ == '__main__':
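
Editor's note: for reference, a minimal usage sketch of the ANS pipeline introduced by this patch. It mirrors the test_ans case added above; the local wav filename and the output filename are placeholders, not part of the patch.

    # Minimal usage sketch of the new ANS pipeline (mirrors
    # tests/pipelines/test_speech_signal_process.py::test_ans).
    from modelscope.metainfo import Pipelines
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ans = pipeline(
        Tasks.speech_signal_process,
        model='damo/speech_frcrn_ans_cirm_16k',
        pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)

    # The pipeline takes a single wav path; passing output_path makes
    # ANSPipeline.postprocess write the denoised audio via soundfile,
    # and the denoised PCM is also returned under 'output_pcm'.
    result = ans('speech_with_noise.wav', output_path='denoised.wav')
    print(result['output_pcm'].shape)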
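Editor's note: the complex-valued layers in complex_nn.py (ComplexConv2d, ComplexConvTranspose2d, the FSMN wrappers) all follow the same pattern of pairing two real layers to realize complex multiplication. A small self-contained sketch of that pattern, with arbitrary toy sizes rather than the ones used by UNet:

    # For a complex weight W = Wr + i*Wi and input x = xr + i*xi:
    #   real part = Wr(xr) - Wi(xi),  imaginary part = Wr(xi) + Wi(xr),
    # with the real/imag parts stacked on the last tensor dimension.
    import torch
    import torch.nn as nn

    conv_re = nn.Conv2d(1, 4, kernel_size=3, padding=1)  # toy sizes
    conv_im = nn.Conv2d(1, 4, kernel_size=3, padding=1)

    x = torch.randn(2, 1, 8, 8, 2)  # [batch, channel, axis1, axis2, real/imag]
    real = conv_re(x[..., 0]) - conv_im(x[..., 1])
    imag = conv_re(x[..., 1]) + conv_im(x[..., 0])
    y = torch.stack((real, imag), dim=-1)  # [2, 4, 8, 8, 2], same layout as input
    print(y.shape)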