Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9767151master
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:3ad1a268c614076614a2ae6528abc29cc85ae35826d172079d7d9b26a0299559 | |||||
| size 4325096 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:3637ee0628d0953f77d5a32327980af542c43230c4127d2a72b4df1ea2ffb0be | |||||
| size 320042 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:cc116af609a66f431f94df6b385ff2aa362f8a2d437c2279f5401e47f9178469 | |||||
| size 320042 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:9354345a6297f4522e690d337546aa9a686a7e61eefcd935478a2141b924db8f | |||||
| size 76770 | |||||
| @@ -38,6 +38,7 @@ class Models(object): | |||||
| # audio models | # audio models | ||||
| sambert_hifigan = 'sambert-hifigan' | sambert_hifigan = 'sambert-hifigan' | ||||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | ||||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||||
| kws_kwsbp = 'kws-kwsbp' | kws_kwsbp = 'kws-kwsbp' | ||||
| generic_asr = 'generic-asr' | generic_asr = 'generic-asr' | ||||
| @@ -133,6 +134,7 @@ class Pipelines(object): | |||||
| sambert_hifigan_tts = 'sambert-hifigan-tts' | sambert_hifigan_tts = 'sambert-hifigan-tts' | ||||
| speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' | speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' | ||||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | ||||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||||
| kws_kwsbp = 'kws-kwsbp' | kws_kwsbp = 'kws-kwsbp' | ||||
| asr_inference = 'asr-inference' | asr_inference = 'asr-inference' | ||||
| @@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .generic_key_word_spotting import GenericKeyWordSpotting | from .generic_key_word_spotting import GenericKeyWordSpotting | ||||
| from .farfield.model import FSMNSeleNetV2Decorator | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'generic_key_word_spotting': ['GenericKeyWordSpotting'], | 'generic_key_word_spotting': ['GenericKeyWordSpotting'], | ||||
| 'farfield.model': ['FSMNSeleNetV2Decorator'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,495 @@ | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from .model_def import (HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32, | |||||
| printNeonMatrix, printNeonVector) | |||||
| DEBUG = False | |||||
| def to_kaldi_matrix(np_mat): | |||||
| """ function that transform as str numpy mat to standard kaldi str matrix | |||||
| Args: | |||||
| np_mat: numpy mat | |||||
| Returns: str | |||||
| """ | |||||
| np.set_printoptions(threshold=np.inf, linewidth=np.nan) | |||||
| out_str = str(np_mat) | |||||
| out_str = out_str.replace('[', '') | |||||
| out_str = out_str.replace(']', '') | |||||
| return '[ %s ]\n' % out_str | |||||
| def print_tensor(torch_tensor): | |||||
| """ print torch tensor for debug | |||||
| Args: | |||||
| torch_tensor: a tensor | |||||
| """ | |||||
| re_str = '' | |||||
| x = torch_tensor.detach().squeeze().numpy() | |||||
| re_str += to_kaldi_matrix(x) | |||||
| re_str += '<!EndOfComponent>\n' | |||||
| print(re_str) | |||||
| class LinearTransform(nn.Module): | |||||
| def __init__(self, input_dim, output_dim): | |||||
| super(LinearTransform, self).__init__() | |||||
| self.input_dim = input_dim | |||||
| self.output_dim = output_dim | |||||
| self.linear = nn.Linear(input_dim, output_dim, bias=False) | |||||
| self.debug = False | |||||
| self.dataout = None | |||||
| def forward(self, input): | |||||
| output = self.linear(input) | |||||
| if self.debug: | |||||
| self.dataout = output | |||||
| return output | |||||
| def print_model(self): | |||||
| printNeonMatrix(self.linear.weight) | |||||
| def to_kaldi_nnet(self): | |||||
| re_str = '' | |||||
| re_str += '<LinearTransform> %d %d\n' % (self.output_dim, | |||||
| self.input_dim) | |||||
| re_str += '<LearnRateCoef> 1\n' | |||||
| linear_weights = self.state_dict()['linear.weight'] | |||||
| x = linear_weights.squeeze().numpy() | |||||
| re_str += to_kaldi_matrix(x) | |||||
| re_str += '<!EndOfComponent>\n' | |||||
| return re_str | |||||
| class AffineTransform(nn.Module): | |||||
| def __init__(self, input_dim, output_dim): | |||||
| super(AffineTransform, self).__init__() | |||||
| self.input_dim = input_dim | |||||
| self.output_dim = output_dim | |||||
| self.linear = nn.Linear(input_dim, output_dim) | |||||
| self.debug = False | |||||
| self.dataout = None | |||||
| def forward(self, input): | |||||
| output = self.linear(input) | |||||
| if self.debug: | |||||
| self.dataout = output | |||||
| return output | |||||
| def print_model(self): | |||||
| printNeonMatrix(self.linear.weight) | |||||
| printNeonVector(self.linear.bias) | |||||
| def to_kaldi_nnet(self): | |||||
| re_str = '' | |||||
| re_str += '<AffineTransform> %d %d\n' % (self.output_dim, | |||||
| self.input_dim) | |||||
| re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n' | |||||
| linear_weights = self.state_dict()['linear.weight'] | |||||
| x = linear_weights.squeeze().numpy() | |||||
| re_str += to_kaldi_matrix(x) | |||||
| linear_bias = self.state_dict()['linear.bias'] | |||||
| x = linear_bias.squeeze().numpy() | |||||
| re_str += to_kaldi_matrix(x) | |||||
| re_str += '<!EndOfComponent>\n' | |||||
| return re_str | |||||
| class Fsmn(nn.Module): | |||||
| """ | |||||
| FSMN implementation. | |||||
| """ | |||||
| def __init__(self, | |||||
| input_dim, | |||||
| output_dim, | |||||
| lorder=None, | |||||
| rorder=None, | |||||
| lstride=None, | |||||
| rstride=None): | |||||
| super(Fsmn, self).__init__() | |||||
| self.dim = input_dim | |||||
| if lorder is None: | |||||
| return | |||||
| self.lorder = lorder | |||||
| self.rorder = rorder | |||||
| self.lstride = lstride | |||||
| self.rstride = rstride | |||||
| self.conv_left = nn.Conv2d( | |||||
| self.dim, | |||||
| self.dim, (lorder, 1), | |||||
| dilation=(lstride, 1), | |||||
| groups=self.dim, | |||||
| bias=False) | |||||
| if rorder > 0: | |||||
| self.conv_right = nn.Conv2d( | |||||
| self.dim, | |||||
| self.dim, (rorder, 1), | |||||
| dilation=(rstride, 1), | |||||
| groups=self.dim, | |||||
| bias=False) | |||||
| else: | |||||
| self.conv_right = None | |||||
| self.debug = False | |||||
| self.dataout = None | |||||
| def forward(self, input): | |||||
| x = torch.unsqueeze(input, 1) | |||||
| x_per = x.permute(0, 3, 2, 1) | |||||
| y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) | |||||
| if self.conv_right is not None: | |||||
| y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) | |||||
| y_right = y_right[:, :, self.rstride:, :] | |||||
| out = x_per + self.conv_left(y_left) + self.conv_right(y_right) | |||||
| else: | |||||
| out = x_per + self.conv_left(y_left) | |||||
| out1 = out.permute(0, 3, 2, 1) | |||||
| output = out1.squeeze(1) | |||||
| if self.debug: | |||||
| self.dataout = output | |||||
| return output | |||||
| def print_model(self): | |||||
| tmpw = self.conv_left.weight | |||||
| tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) | |||||
| for j in range(tmpw.shape[0]): | |||||
| tmpwm[:, j] = tmpw[j, 0, :, 0] | |||||
| printNeonMatrix(tmpwm) | |||||
| if self.conv_right is not None: | |||||
| tmpw = self.conv_right.weight | |||||
| tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) | |||||
| for j in range(tmpw.shape[0]): | |||||
| tmpwm[:, j] = tmpw[j, 0, :, 0] | |||||
| printNeonMatrix(tmpwm) | |||||
| def to_kaldi_nnet(self): | |||||
| re_str = '' | |||||
| re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim) | |||||
| re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % ( | |||||
| 1, self.lorder, self.rorder, self.lstride, self.rstride) | |||||
| lfiters = self.state_dict()['conv_left.weight'] | |||||
| x = np.flipud(lfiters.squeeze().numpy().T) | |||||
| re_str += to_kaldi_matrix(x) | |||||
| if self.conv_right is not None: | |||||
| rfiters = self.state_dict()['conv_right.weight'] | |||||
| x = (rfiters.squeeze().numpy().T) | |||||
| re_str += to_kaldi_matrix(x) | |||||
| re_str += '<!EndOfComponent>\n' | |||||
| return re_str | |||||
| class RectifiedLinear(nn.Module): | |||||
| def __init__(self, input_dim, output_dim): | |||||
| super(RectifiedLinear, self).__init__() | |||||
| self.dim = input_dim | |||||
| self.relu = nn.ReLU() | |||||
| def forward(self, input): | |||||
| return self.relu(input) | |||||
| def to_kaldi_nnet(self): | |||||
| re_str = '' | |||||
| re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim) | |||||
| re_str += '<!EndOfComponent>\n' | |||||
| return re_str | |||||
| class FSMNNet(nn.Module): | |||||
| """ | |||||
| FSMN net for keyword spotting | |||||
| """ | |||||
| def __init__(self, | |||||
| input_dim=200, | |||||
| linear_dim=128, | |||||
| proj_dim=128, | |||||
| lorder=10, | |||||
| rorder=1, | |||||
| num_syn=5, | |||||
| fsmn_layers=4): | |||||
| """ | |||||
| Args: | |||||
| input_dim: input dimension | |||||
| linear_dim: fsmn input dimension | |||||
| proj_dim: fsmn projection dimension | |||||
| lorder: fsmn left order | |||||
| rorder: fsmn right order | |||||
| num_syn: output dimension | |||||
| fsmn_layers: no. of sequential fsmn layers | |||||
| """ | |||||
| super(FSMNNet, self).__init__() | |||||
| self.input_dim = input_dim | |||||
| self.linear_dim = linear_dim | |||||
| self.proj_dim = proj_dim | |||||
| self.lorder = lorder | |||||
| self.rorder = rorder | |||||
| self.num_syn = num_syn | |||||
| self.fsmn_layers = fsmn_layers | |||||
| self.linear1 = AffineTransform(input_dim, linear_dim) | |||||
| self.relu = RectifiedLinear(linear_dim, linear_dim) | |||||
| self.fsmn = self._build_repeats(linear_dim, proj_dim, lorder, rorder, | |||||
| fsmn_layers) | |||||
| self.linear2 = AffineTransform(linear_dim, num_syn) | |||||
| @staticmethod | |||||
| def _build_repeats(linear_dim=136, | |||||
| proj_dim=68, | |||||
| lorder=3, | |||||
| rorder=2, | |||||
| fsmn_layers=5): | |||||
| repeats = [ | |||||
| nn.Sequential( | |||||
| LinearTransform(linear_dim, proj_dim), | |||||
| Fsmn(proj_dim, proj_dim, lorder, rorder, 1, 1), | |||||
| AffineTransform(proj_dim, linear_dim), | |||||
| RectifiedLinear(linear_dim, linear_dim)) | |||||
| for i in range(fsmn_layers) | |||||
| ] | |||||
| return nn.Sequential(*repeats) | |||||
| def forward(self, input): | |||||
| x1 = self.linear1(input) | |||||
| x2 = self.relu(x1) | |||||
| x3 = self.fsmn(x2) | |||||
| x4 = self.linear2(x3) | |||||
| return x4 | |||||
| def print_model(self): | |||||
| self.linear1.print_model() | |||||
| for layer in self.fsmn: | |||||
| layer[0].print_model() | |||||
| layer[1].print_model() | |||||
| layer[2].print_model() | |||||
| self.linear2.print_model() | |||||
| def print_header(self): | |||||
| # | |||||
| # write total header | |||||
| # | |||||
| header = [0.0] * HEADER_BLOCK_SIZE * 4 | |||||
| # numins | |||||
| header[0] = 0.0 | |||||
| # numouts | |||||
| header[1] = 0.0 | |||||
| # dimins | |||||
| header[2] = self.input_dim | |||||
| # dimouts | |||||
| header[3] = self.num_syn | |||||
| # numlayers | |||||
| header[4] = 3 | |||||
| # | |||||
| # write each layer's header | |||||
| # | |||||
| hidx = 1 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 0] = float( | |||||
| LayerType.LAYER_DENSE.value) | |||||
| header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 2] = self.input_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 3] = self.linear_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 5] = float( | |||||
| ActivationType.ACTIVATION_RELU.value) | |||||
| hidx += 1 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 0] = float( | |||||
| LayerType.LAYER_SEQUENTIAL_FSMN.value) | |||||
| header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 3] = self.proj_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 4] = self.lorder | |||||
| header[HEADER_BLOCK_SIZE * hidx + 5] = self.rorder | |||||
| header[HEADER_BLOCK_SIZE * hidx + 6] = self.fsmn_layers | |||||
| header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0 | |||||
| hidx += 1 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 0] = float( | |||||
| LayerType.LAYER_DENSE.value) | |||||
| header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 3] = self.num_syn | |||||
| header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 5] = float( | |||||
| ActivationType.ACTIVATION_SOFTMAX.value) | |||||
| for h in header: | |||||
| print(f32ToI32(h)) | |||||
| def to_kaldi_nnet(self): | |||||
| re_str = '' | |||||
| re_str += '<Nnet>\n' | |||||
| re_str += self.linear1.to_kaldi_nnet() | |||||
| re_str += self.relu.to_kaldi_nnet() | |||||
| for fsmn in self.fsmn: | |||||
| re_str += fsmn[0].to_kaldi_nnet() | |||||
| re_str += fsmn[1].to_kaldi_nnet() | |||||
| re_str += fsmn[2].to_kaldi_nnet() | |||||
| re_str += fsmn[3].to_kaldi_nnet() | |||||
| re_str += self.linear2.to_kaldi_nnet() | |||||
| re_str += '<Softmax> %d %d\n' % (self.num_syn, self.num_syn) | |||||
| re_str += '<!EndOfComponent>\n' | |||||
| re_str += '</Nnet>\n' | |||||
| return re_str | |||||
| class DFSMN(nn.Module): | |||||
| """ | |||||
| One deep fsmn layer | |||||
| """ | |||||
| def __init__(self, | |||||
| dimproj=64, | |||||
| dimlinear=128, | |||||
| lorder=20, | |||||
| rorder=1, | |||||
| lstride=1, | |||||
| rstride=1): | |||||
| """ | |||||
| Args: | |||||
| dimproj: projection dimension, input and output dimension of memory blocks | |||||
| dimlinear: dimension of mapping layer | |||||
| lorder: left order | |||||
| rorder: right order | |||||
| lstride: left stride | |||||
| rstride: right stride | |||||
| """ | |||||
| super(DFSMN, self).__init__() | |||||
| self.lorder = lorder | |||||
| self.rorder = rorder | |||||
| self.lstride = lstride | |||||
| self.rstride = rstride | |||||
| self.expand = AffineTransform(dimproj, dimlinear) | |||||
| self.shrink = LinearTransform(dimlinear, dimproj) | |||||
| self.conv_left = nn.Conv2d( | |||||
| dimproj, | |||||
| dimproj, (lorder, 1), | |||||
| dilation=(lstride, 1), | |||||
| groups=dimproj, | |||||
| bias=False) | |||||
| if rorder > 0: | |||||
| self.conv_right = nn.Conv2d( | |||||
| dimproj, | |||||
| dimproj, (rorder, 1), | |||||
| dilation=(rstride, 1), | |||||
| groups=dimproj, | |||||
| bias=False) | |||||
| else: | |||||
| self.conv_right = None | |||||
| def forward(self, input): | |||||
| f1 = F.relu(self.expand(input)) | |||||
| p1 = self.shrink(f1) | |||||
| x = torch.unsqueeze(p1, 1) | |||||
| x_per = x.permute(0, 3, 2, 1) | |||||
| y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) | |||||
| if self.conv_right is not None: | |||||
| y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) | |||||
| y_right = y_right[:, :, self.rstride:, :] | |||||
| out = x_per + self.conv_left(y_left) + self.conv_right(y_right) | |||||
| else: | |||||
| out = x_per + self.conv_left(y_left) | |||||
| out1 = out.permute(0, 3, 2, 1) | |||||
| output = input + out1.squeeze(1) | |||||
| return output | |||||
| def print_model(self): | |||||
| self.expand.print_model() | |||||
| self.shrink.print_model() | |||||
| tmpw = self.conv_left.weight | |||||
| tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) | |||||
| for j in range(tmpw.shape[0]): | |||||
| tmpwm[:, j] = tmpw[j, 0, :, 0] | |||||
| printNeonMatrix(tmpwm) | |||||
| if self.conv_right is not None: | |||||
| tmpw = self.conv_right.weight | |||||
| tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0]) | |||||
| for j in range(tmpw.shape[0]): | |||||
| tmpwm[:, j] = tmpw[j, 0, :, 0] | |||||
| printNeonMatrix(tmpwm) | |||||
| def build_dfsmn_repeats(linear_dim=128, | |||||
| proj_dim=64, | |||||
| lorder=20, | |||||
| rorder=1, | |||||
| fsmn_layers=6): | |||||
| """ | |||||
| build stacked dfsmn layers | |||||
| Args: | |||||
| linear_dim: | |||||
| proj_dim: | |||||
| lorder: | |||||
| rorder: | |||||
| fsmn_layers: | |||||
| Returns: | |||||
| """ | |||||
| repeats = [ | |||||
| nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) | |||||
| for i in range(fsmn_layers) | |||||
| ] | |||||
| return nn.Sequential(*repeats) | |||||
| @@ -0,0 +1,236 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from .fsmn import AffineTransform, Fsmn, LinearTransform, RectifiedLinear | |||||
| from .model_def import HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32 | |||||
| class FSMNUnit(nn.Module): | |||||
| """ A multi-channel fsmn unit | |||||
| """ | |||||
| def __init__(self, dimlinear=128, dimproj=64, lorder=20, rorder=1): | |||||
| """ | |||||
| Args: | |||||
| dimlinear: input / output dimension | |||||
| dimproj: fsmn input / output dimension | |||||
| lorder: left ofder | |||||
| rorder: right order | |||||
| """ | |||||
| super(FSMNUnit, self).__init__() | |||||
| self.shrink = LinearTransform(dimlinear, dimproj) | |||||
| self.fsmn = Fsmn(dimproj, dimproj, lorder, rorder, 1, 1) | |||||
| self.expand = AffineTransform(dimproj, dimlinear) | |||||
| self.debug = False | |||||
| self.dataout = None | |||||
| ''' | |||||
| batch, time, channel, feature | |||||
| ''' | |||||
| def forward(self, input): | |||||
| if torch.cuda.is_available(): | |||||
| out = torch.zeros(input.shape).cuda() | |||||
| else: | |||||
| out = torch.zeros(input.shape) | |||||
| for n in range(input.shape[2]): | |||||
| out1 = self.shrink(input[:, :, n, :]) | |||||
| out2 = self.fsmn(out1) | |||||
| out[:, :, n, :] = F.relu(self.expand(out2)) | |||||
| if self.debug: | |||||
| self.dataout = out | |||||
| return out | |||||
| def print_model(self): | |||||
| self.shrink.print_model() | |||||
| self.fsmn.print_model() | |||||
| self.expand.print_model() | |||||
| def to_kaldi_nnet(self): | |||||
| re_str = self.shrink.to_kaldi_nnet() | |||||
| re_str += self.fsmn.to_kaldi_nnet() | |||||
| re_str += self.expand.to_kaldi_nnet() | |||||
| relu = RectifiedLinear(self.expand.linear.out_features, | |||||
| self.expand.linear.out_features) | |||||
| re_str += relu.to_kaldi_nnet() | |||||
| return re_str | |||||
| class FSMNSeleNetV2(nn.Module): | |||||
| """ FSMN model with channel selection. | |||||
| """ | |||||
| def __init__(self, | |||||
| input_dim=120, | |||||
| linear_dim=128, | |||||
| proj_dim=64, | |||||
| lorder=20, | |||||
| rorder=1, | |||||
| num_syn=5, | |||||
| fsmn_layers=5, | |||||
| sele_layer=0): | |||||
| """ | |||||
| Args: | |||||
| input_dim: input dimension | |||||
| linear_dim: fsmn input dimension | |||||
| proj_dim: fsmn projection dimension | |||||
| lorder: fsmn left order | |||||
| rorder: fsmn right order | |||||
| num_syn: output dimension | |||||
| fsmn_layers: no. of fsmn units | |||||
| sele_layer: channel selection layer index | |||||
| """ | |||||
| super(FSMNSeleNetV2, self).__init__() | |||||
| self.sele_layer = sele_layer | |||||
| self.featmap = AffineTransform(input_dim, linear_dim) | |||||
| self.mem = [] | |||||
| for i in range(fsmn_layers): | |||||
| unit = FSMNUnit(linear_dim, proj_dim, lorder, rorder) | |||||
| self.mem.append(unit) | |||||
| self.add_module('mem_{:d}'.format(i), unit) | |||||
| self.decision = AffineTransform(linear_dim, num_syn) | |||||
| def forward(self, input): | |||||
| # multi-channel feature mapping | |||||
| if torch.cuda.is_available(): | |||||
| x = torch.zeros(input.shape[0], input.shape[1], input.shape[2], | |||||
| self.featmap.linear.out_features).cuda() | |||||
| else: | |||||
| x = torch.zeros(input.shape[0], input.shape[1], input.shape[2], | |||||
| self.featmap.linear.out_features) | |||||
| for n in range(input.shape[2]): | |||||
| x[:, :, n, :] = F.relu(self.featmap(input[:, :, n, :])) | |||||
| for i, unit in enumerate(self.mem): | |||||
| y = unit(x) | |||||
| # perform channel selection | |||||
| if i == self.sele_layer: | |||||
| pool = nn.MaxPool2d((y.shape[2], 1), stride=(y.shape[2], 1)) | |||||
| y = pool(y) | |||||
| x = y | |||||
| # remove channel dimension | |||||
| y = torch.squeeze(y, -2) | |||||
| z = self.decision(y) | |||||
| return z | |||||
| def print_model(self): | |||||
| self.featmap.print_model() | |||||
| for unit in self.mem: | |||||
| unit.print_model() | |||||
| self.decision.print_model() | |||||
| def print_header(self): | |||||
| ''' | |||||
| get FSMN params | |||||
| ''' | |||||
| input_dim = self.featmap.linear.in_features | |||||
| linear_dim = self.featmap.linear.out_features | |||||
| proj_dim = self.mem[0].shrink.linear.out_features | |||||
| lorder = self.mem[0].fsmn.conv_left.kernel_size[0] | |||||
| rorder = 0 | |||||
| if self.mem[0].fsmn.conv_right is not None: | |||||
| rorder = self.mem[0].fsmn.conv_right.kernel_size[0] | |||||
| num_syn = self.decision.linear.out_features | |||||
| fsmn_layers = len(self.mem) | |||||
| # no. of output channels, 0.0 means the same as numins | |||||
| # numouts = 0.0 | |||||
| numouts = 1.0 | |||||
| # | |||||
| # write total header | |||||
| # | |||||
| header = [0.0] * HEADER_BLOCK_SIZE * 4 | |||||
| # numins | |||||
| header[0] = 0.0 | |||||
| # numouts | |||||
| header[1] = numouts | |||||
| # dimins | |||||
| header[2] = input_dim | |||||
| # dimouts | |||||
| header[3] = num_syn | |||||
| # numlayers | |||||
| header[4] = 3 | |||||
| # | |||||
| # write each layer's header | |||||
| # | |||||
| hidx = 1 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 0] = float( | |||||
| LayerType.LAYER_DENSE.value) | |||||
| header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 2] = input_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 3] = linear_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 5] = float( | |||||
| ActivationType.ACTIVATION_RELU.value) | |||||
| hidx += 1 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 0] = float( | |||||
| LayerType.LAYER_SEQUENTIAL_FSMN.value) | |||||
| header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 3] = proj_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 4] = lorder | |||||
| header[HEADER_BLOCK_SIZE * hidx + 5] = rorder | |||||
| header[HEADER_BLOCK_SIZE * hidx + 6] = fsmn_layers | |||||
| if numouts == 1.0: | |||||
| header[HEADER_BLOCK_SIZE * hidx + 7] = float(self.sele_layer) | |||||
| else: | |||||
| header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0 | |||||
| hidx += 1 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 0] = float( | |||||
| LayerType.LAYER_DENSE.value) | |||||
| header[HEADER_BLOCK_SIZE * hidx + 1] = numouts | |||||
| header[HEADER_BLOCK_SIZE * hidx + 2] = linear_dim | |||||
| header[HEADER_BLOCK_SIZE * hidx + 3] = num_syn | |||||
| header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0 | |||||
| header[HEADER_BLOCK_SIZE * hidx + 5] = float( | |||||
| ActivationType.ACTIVATION_SOFTMAX.value) | |||||
| for h in header: | |||||
| print(f32ToI32(h)) | |||||
| def to_kaldi_nnet(self): | |||||
| re_str = '<Nnet>\n' | |||||
| re_str = self.featmap.to_kaldi_nnet() | |||||
| relu = RectifiedLinear(self.featmap.linear.out_features, | |||||
| self.featmap.linear.out_features) | |||||
| re_str += relu.to_kaldi_nnet() | |||||
| for unit in self.mem: | |||||
| re_str += unit.to_kaldi_nnet() | |||||
| re_str += self.decision.to_kaldi_nnet() | |||||
| re_str += '<Softmax> %d %d\n' % (self.decision.linear.out_features, | |||||
| self.decision.linear.out_features) | |||||
| re_str += '<!EndOfComponent>\n' | |||||
| re_str += '</Nnet>\n' | |||||
| return re_str | |||||
| @@ -0,0 +1,74 @@ | |||||
| import os | |||||
| from typing import Dict | |||||
| import torch | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models import TorchModel | |||||
| from modelscope.models.base import Tensor | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from .fsmn_sele_v2 import FSMNSeleNetV2 | |||||
| @MODELS.register_module( | |||||
| Tasks.keyword_spotting, module_name=Models.speech_dfsmn_kws_char_farfield) | |||||
| class FSMNSeleNetV2Decorator(TorchModel): | |||||
| r""" A decorator of FSMNSeleNetV2 for integrating into modelscope framework """ | |||||
| MODEL_TXT = 'model.txt' | |||||
| SC_CONFIG = 'sound_connect.conf' | |||||
| SC_CONF_ITEM_KWS_MODEL = '${kws_model}' | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the dfsmn model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||||
| model_bin_file = os.path.join(model_dir, | |||||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||||
| self._model = None | |||||
| if os.path.exists(model_bin_file): | |||||
| self._model = FSMNSeleNetV2(*args, **kwargs) | |||||
| checkpoint = torch.load(model_bin_file) | |||||
| self._model.load_state_dict(checkpoint, strict=False) | |||||
| self._sc = None | |||||
| if os.path.exists(model_txt_file): | |||||
| with open(sc_config_file) as f: | |||||
| lines = f.readlines() | |||||
| with open(sc_config_file, 'w') as f: | |||||
| for line in lines: | |||||
| if self.SC_CONF_ITEM_KWS_MODEL in line: | |||||
| line = line.replace(self.SC_CONF_ITEM_KWS_MODEL, | |||||
| model_txt_file) | |||||
| f.write(line) | |||||
| import py_sound_connect | |||||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||||
| self.size_in = self._sc.bytesPerBlockIn() | |||||
| self.size_out = self._sc.bytesPerBlockOut() | |||||
| if self._model is None and self._sc is None: | |||||
| raise Exception( | |||||
| f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.' | |||||
| ) | |||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| ... | |||||
| def forward_decode(self, data: bytes): | |||||
| result = {'pcm': self._sc.process(data, self.size_out)} | |||||
| state = self._sc.kwsState() | |||||
| if state == 2: | |||||
| result['kws'] = { | |||||
| 'keyword': | |||||
| self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()), | |||||
| 'offset': self._sc.kwsKeywordOffset(), | |||||
| 'length': self._sc.kwsKeywordLength(), | |||||
| 'confidence': self._sc.kwsConfidence() | |||||
| } | |||||
| return result | |||||
| @@ -0,0 +1,121 @@ | |||||
| import math | |||||
| import struct | |||||
| from enum import Enum | |||||
| HEADER_BLOCK_SIZE = 10 | |||||
| class LayerType(Enum): | |||||
| LAYER_DENSE = 1 | |||||
| LAYER_GRU = 2 | |||||
| LAYER_ATTENTION = 3 | |||||
| LAYER_FSMN = 4 | |||||
| LAYER_SEQUENTIAL_FSMN = 5 | |||||
| LAYER_FSMN_SELE = 6 | |||||
| LAYER_GRU_ATTENTION = 7 | |||||
| LAYER_DFSMN = 8 | |||||
| class ActivationType(Enum): | |||||
| ACTIVATION_NONE = 0 | |||||
| ACTIVATION_RELU = 1 | |||||
| ACTIVATION_TANH = 2 | |||||
| ACTIVATION_SIGMOID = 3 | |||||
| ACTIVATION_SOFTMAX = 4 | |||||
| ACTIVATION_LOGSOFTMAX = 5 | |||||
| def f32ToI32(f): | |||||
| """ | |||||
| print layer | |||||
| """ | |||||
| bs = struct.pack('f', f) | |||||
| ba = bytearray() | |||||
| ba.append(bs[0]) | |||||
| ba.append(bs[1]) | |||||
| ba.append(bs[2]) | |||||
| ba.append(bs[3]) | |||||
| return struct.unpack('i', ba)[0] | |||||
| def printNeonMatrix(w): | |||||
| """ | |||||
| print matrix with neon padding | |||||
| """ | |||||
| numrows, numcols = w.shape | |||||
| numnecols = math.ceil(numcols / 4) | |||||
| for i in range(numrows): | |||||
| for j in range(numcols): | |||||
| print(f32ToI32(w[i, j])) | |||||
| for j in range(numnecols * 4 - numcols): | |||||
| print(0) | |||||
| def printNeonVector(b): | |||||
| """ | |||||
| print vector with neon padding | |||||
| """ | |||||
| size = b.shape[0] | |||||
| nesize = math.ceil(size / 4) | |||||
| for i in range(size): | |||||
| print(f32ToI32(b[i])) | |||||
| for i in range(nesize * 4 - size): | |||||
| print(0) | |||||
| def printDense(layer): | |||||
| """ | |||||
| save dense layer | |||||
| """ | |||||
| statedict = layer.state_dict() | |||||
| printNeonMatrix(statedict['weight']) | |||||
| printNeonVector(statedict['bias']) | |||||
| def printGRU(layer): | |||||
| """ | |||||
| save gru layer | |||||
| """ | |||||
| statedict = layer.state_dict() | |||||
| weight = [statedict['weight_ih_l0'], statedict['weight_hh_l0']] | |||||
| bias = [statedict['bias_ih_l0'], statedict['bias_hh_l0']] | |||||
| numins, numouts = weight[0].shape | |||||
| numins = numins // 3 | |||||
| # output input weights | |||||
| w_rx = weight[0][:numins, :] | |||||
| w_zx = weight[0][numins:numins * 2, :] | |||||
| w_x = weight[0][numins * 2:, :] | |||||
| printNeonMatrix(w_zx) | |||||
| printNeonMatrix(w_rx) | |||||
| printNeonMatrix(w_x) | |||||
| # output recurrent weights | |||||
| w_rh = weight[1][:numins, :] | |||||
| w_zh = weight[1][numins:numins * 2, :] | |||||
| w_h = weight[1][numins * 2:, :] | |||||
| printNeonMatrix(w_zh) | |||||
| printNeonMatrix(w_rh) | |||||
| printNeonMatrix(w_h) | |||||
| # output input bias | |||||
| b_rx = bias[0][:numins] | |||||
| b_zx = bias[0][numins:numins * 2] | |||||
| b_x = bias[0][numins * 2:] | |||||
| printNeonVector(b_zx) | |||||
| printNeonVector(b_rx) | |||||
| printNeonVector(b_x) | |||||
| # output recurrent bias | |||||
| b_rh = bias[1][:numins] | |||||
| b_zh = bias[1][numins:numins * 2] | |||||
| b_h = bias[1][numins * 2:] | |||||
| printNeonVector(b_zh) | |||||
| printNeonVector(b_rh) | |||||
| printNeonVector(b_h) | |||||
| @@ -405,7 +405,7 @@ TASK_OUTPUTS = { | |||||
| # audio processed for single file in PCM format | # audio processed for single file in PCM format | ||||
| # { | # { | ||||
| # "output_pcm": np.array with shape(samples,) and dtype float32 | |||||
| # "output_pcm": pcm encoded audio bytes | |||||
| # } | # } | ||||
| Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM], | Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM], | ||||
| Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM], | Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM], | ||||
| @@ -417,6 +417,19 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM], | Tasks.text_to_speech: [OutputKeys.OUTPUT_PCM], | ||||
| # { | |||||
| # "kws_list": [ | |||||
| # { | |||||
| # 'keyword': '', # the keyword spotted | |||||
| # 'offset': 19.4, # the keyword start time in second | |||||
| # 'length': 0.68, # the keyword length in second | |||||
| # 'confidence': 0.85 # the possibility if it is the keyword | |||||
| # }, | |||||
| # ... | |||||
| # ] | |||||
| # } | |||||
| Tasks.keyword_spotting: [OutputKeys.KWS_LIST], | |||||
| # ============ multi-modal tasks =================== | # ============ multi-modal tasks =================== | ||||
| # image caption result for single sample | # image caption result for single sample | ||||
| @@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .ans_pipeline import ANSPipeline | from .ans_pipeline import ANSPipeline | ||||
| from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline | from .asr_inference_pipeline import AutomaticSpeechRecognitionPipeline | ||||
| from .kws_farfield_pipeline import KWSFarfieldPipeline | |||||
| from .kws_kwsbp_pipeline import KeyWordSpottingKwsbpPipeline | from .kws_kwsbp_pipeline import KeyWordSpottingKwsbpPipeline | ||||
| from .linear_aec_pipeline import LinearAECPipeline | from .linear_aec_pipeline import LinearAECPipeline | ||||
| from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline | from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline | ||||
| @@ -14,6 +15,7 @@ else: | |||||
| _import_structure = { | _import_structure = { | ||||
| 'ans_pipeline': ['ANSPipeline'], | 'ans_pipeline': ['ANSPipeline'], | ||||
| 'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'], | 'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'], | ||||
| 'kws_farfield_pipeline': ['KWSFarfieldPipeline'], | |||||
| 'kws_kwsbp_pipeline': ['KeyWordSpottingKwsbpPipeline'], | 'kws_kwsbp_pipeline': ['KeyWordSpottingKwsbpPipeline'], | ||||
| 'linear_aec_pipeline': ['LinearAECPipeline'], | 'linear_aec_pipeline': ['LinearAECPipeline'], | ||||
| 'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'], | 'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'], | ||||
| @@ -0,0 +1,81 @@ | |||||
| import io | |||||
| import wave | |||||
| from typing import Any, Dict | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.utils.constant import Tasks | |||||
| @PIPELINES.register_module( | |||||
| Tasks.keyword_spotting, | |||||
| module_name=Pipelines.speech_dfsmn_kws_char_farfield) | |||||
| class KWSFarfieldPipeline(Pipeline): | |||||
| r"""A Keyword Spotting Inference Pipeline . | |||||
| When invoke the class with pipeline.__call__(), it accept only one parameter: | |||||
| inputs(str): the path of wav file | |||||
| """ | |||||
| SAMPLE_RATE = 16000 | |||||
| SAMPLE_WIDTH = 2 | |||||
| INPUT_CHANNELS = 3 | |||||
| OUTPUT_CHANNELS = 2 | |||||
| def __init__(self, model, **kwargs): | |||||
| """ | |||||
| use `model` to create a kws far field pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, **kwargs) | |||||
| self.model = self.model.to(self.device) | |||||
| self.model.eval() | |||||
| frame_size = self.INPUT_CHANNELS * self.SAMPLE_WIDTH | |||||
| self._nframe = self.model.size_in // frame_size | |||||
| self.frame_count = 0 | |||||
| def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: | |||||
| if isinstance(inputs, bytes): | |||||
| return dict(input_file=inputs) | |||||
| elif isinstance(inputs, Dict): | |||||
| return inputs | |||||
| else: | |||||
| raise ValueError(f'Not supported input type: {type(inputs)}') | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| input_file = inputs['input_file'] | |||||
| if isinstance(input_file, bytes): | |||||
| input_file = io.BytesIO(input_file) | |||||
| self.frame_count = 0 | |||||
| kws_list = [] | |||||
| with wave.open(input_file, 'rb') as fin: | |||||
| if 'output_file' in inputs: | |||||
| with wave.open(inputs['output_file'], 'wb') as fout: | |||||
| fout.setframerate(self.SAMPLE_RATE) | |||||
| fout.setnchannels(self.OUTPUT_CHANNELS) | |||||
| fout.setsampwidth(self.SAMPLE_WIDTH) | |||||
| self._process(fin, kws_list, fout) | |||||
| else: | |||||
| self._process(fin, kws_list) | |||||
| return {OutputKeys.KWS_LIST: kws_list} | |||||
| def _process(self, | |||||
| fin: wave.Wave_read, | |||||
| kws_list, | |||||
| fout: wave.Wave_write = None): | |||||
| data = fin.readframes(self._nframe) | |||||
| while len(data) >= self.model.size_in: | |||||
| self.frame_count += self._nframe | |||||
| result = self.model.forward_decode(data) | |||||
| if fout: | |||||
| fout.writeframes(result['pcm']) | |||||
| if 'kws' in result: | |||||
| result['kws']['offset'] += self.frame_count / self.SAMPLE_RATE | |||||
| kws_list.append(result['kws']) | |||||
| data = fin.readframes(self._nframe) | |||||
| def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: | |||||
| return inputs | |||||
| @@ -255,7 +255,7 @@ class Pipeline(ABC): | |||||
| return self._collate_fn(torch.from_numpy(data)) | return self._collate_fn(torch.from_numpy(data)) | ||||
| elif isinstance(data, torch.Tensor): | elif isinstance(data, torch.Tensor): | ||||
| return data.to(self.device) | return data.to(self.device) | ||||
| elif isinstance(data, (str, int, float, bool, type(None))): | |||||
| elif isinstance(data, (bytes, str, int, float, bool, type(None))): | |||||
| return data | return data | ||||
| elif isinstance(data, InputFeatures): | elif isinstance(data, InputFeatures): | ||||
| return data | return data | ||||
| @@ -16,6 +16,7 @@ numpy<=1.18 | |||||
| # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | ||||
| protobuf>3,<3.21.0 | protobuf>3,<3.21.0 | ||||
| ptflops | ptflops | ||||
| py_sound_connect | |||||
| pytorch_wavelets | pytorch_wavelets | ||||
| PyWavelets>=1.0.0 | PyWavelets>=1.0.0 | ||||
| scikit-learn | scikit-learn | ||||
| @@ -0,0 +1,43 @@ | |||||
| import os.path | |||||
| import unittest | |||||
| from modelscope.fileio import File | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav' | |||||
| class KWSFarfieldTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_normal(self): | |||||
| kws = pipeline(Tasks.keyword_spotting, model=self.model_id) | |||||
| inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)} | |||||
| result = kws(inputs) | |||||
| self.assertEqual(len(result['kws_list']), 5) | |||||
| print(result['kws_list'][-1]) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_output(self): | |||||
| kws = pipeline(Tasks.keyword_spotting, model=self.model_id) | |||||
| inputs = { | |||||
| 'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE), | |||||
| 'output_file': 'output.wav' | |||||
| } | |||||
| result = kws(inputs) | |||||
| self.assertEqual(len(result['kws_list']), 5) | |||||
| print(result['kws_list'][-1]) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_input_bytes(self): | |||||
| with open(os.path.join(os.getcwd(), TEST_SPEECH_FILE), 'rb') as f: | |||||
| data = f.read() | |||||
| kws = pipeline(Tasks.keyword_spotting, model=self.model_id) | |||||
| result = kws(data) | |||||
| self.assertEqual(len(result['kws_list']), 5) | |||||
| print(result['kws_list'][-1]) | |||||
| @@ -8,22 +8,10 @@ from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| NEAREND_MIC_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/nearend_mic.wav' | |||||
| FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/sample_audio/farend_speech.wav' | |||||
| NEAREND_MIC_FILE = 'nearend_mic.wav' | |||||
| FAREND_SPEECH_FILE = 'farend_speech.wav' | |||||
| NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav' | |||||
| FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav' | |||||
| NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav' | |||||
| NOISE_SPEECH_FILE = 'speech_with_noise.wav' | |||||
| def download(remote_path, local_path): | |||||
| local_dir = os.path.dirname(local_path) | |||||
| if len(local_dir) > 0: | |||||
| if not os.path.exists(local_dir): | |||||
| os.makedirs(local_dir) | |||||
| with open(local_path, 'wb') as ofile: | |||||
| ofile.write(File.read(remote_path)) | |||||
| NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav' | |||||
| class SpeechSignalProcessTest(unittest.TestCase): | class SpeechSignalProcessTest(unittest.TestCase): | ||||
| @@ -33,13 +21,10 @@ class SpeechSignalProcessTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_aec(self): | def test_aec(self): | ||||
| # Download audio files | |||||
| download(NEAREND_MIC_URL, NEAREND_MIC_FILE) | |||||
| download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE) | |||||
| model_id = 'damo/speech_dfsmn_aec_psm_16k' | model_id = 'damo/speech_dfsmn_aec_psm_16k' | ||||
| input = { | input = { | ||||
| 'nearend_mic': NEAREND_MIC_FILE, | |||||
| 'farend_speech': FAREND_SPEECH_FILE | |||||
| 'nearend_mic': os.path.join(os.getcwd(), NEAREND_MIC_FILE), | |||||
| 'farend_speech': os.path.join(os.getcwd(), FAREND_SPEECH_FILE) | |||||
| } | } | ||||
| aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id) | aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id) | ||||
| output_path = os.path.abspath('output.wav') | output_path = os.path.abspath('output.wav') | ||||
| @@ -48,14 +33,11 @@ class SpeechSignalProcessTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_aec_bytes(self): | def test_aec_bytes(self): | ||||
| # Download audio files | |||||
| download(NEAREND_MIC_URL, NEAREND_MIC_FILE) | |||||
| download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE) | |||||
| model_id = 'damo/speech_dfsmn_aec_psm_16k' | model_id = 'damo/speech_dfsmn_aec_psm_16k' | ||||
| input = {} | input = {} | ||||
| with open(NEAREND_MIC_FILE, 'rb') as f: | |||||
| with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f: | |||||
| input['nearend_mic'] = f.read() | input['nearend_mic'] = f.read() | ||||
| with open(FAREND_SPEECH_FILE, 'rb') as f: | |||||
| with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f: | |||||
| input['farend_speech'] = f.read() | input['farend_speech'] = f.read() | ||||
| aec = pipeline( | aec = pipeline( | ||||
| Tasks.acoustic_echo_cancellation, | Tasks.acoustic_echo_cancellation, | ||||
| @@ -67,13 +49,10 @@ class SpeechSignalProcessTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_aec_tuple_bytes(self): | def test_aec_tuple_bytes(self): | ||||
| # Download audio files | |||||
| download(NEAREND_MIC_URL, NEAREND_MIC_FILE) | |||||
| download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE) | |||||
| model_id = 'damo/speech_dfsmn_aec_psm_16k' | model_id = 'damo/speech_dfsmn_aec_psm_16k' | ||||
| with open(NEAREND_MIC_FILE, 'rb') as f: | |||||
| with open(os.path.join(os.getcwd(), NEAREND_MIC_FILE), 'rb') as f: | |||||
| nearend_bytes = f.read() | nearend_bytes = f.read() | ||||
| with open(FAREND_SPEECH_FILE, 'rb') as f: | |||||
| with open(os.path.join(os.getcwd(), FAREND_SPEECH_FILE), 'rb') as f: | |||||
| farend_bytes = f.read() | farend_bytes = f.read() | ||||
| inputs = (nearend_bytes, farend_bytes) | inputs = (nearend_bytes, farend_bytes) | ||||
| aec = pipeline( | aec = pipeline( | ||||
| @@ -86,25 +65,22 @@ class SpeechSignalProcessTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_ans(self): | def test_ans(self): | ||||
| # Download audio files | |||||
| download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE) | |||||
| model_id = 'damo/speech_frcrn_ans_cirm_16k' | model_id = 'damo/speech_frcrn_ans_cirm_16k' | ||||
| ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) | ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) | ||||
| output_path = os.path.abspath('output.wav') | output_path = os.path.abspath('output.wav') | ||||
| ans(NOISE_SPEECH_FILE, output_path=output_path) | |||||
| ans(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), | |||||
| output_path=output_path) | |||||
| print(f'Processed audio saved to {output_path}') | print(f'Processed audio saved to {output_path}') | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_ans_bytes(self): | def test_ans_bytes(self): | ||||
| # Download audio files | |||||
| download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE) | |||||
| model_id = 'damo/speech_frcrn_ans_cirm_16k' | model_id = 'damo/speech_frcrn_ans_cirm_16k' | ||||
| ans = pipeline( | ans = pipeline( | ||||
| Tasks.acoustic_noise_suppression, | Tasks.acoustic_noise_suppression, | ||||
| model=model_id, | model=model_id, | ||||
| pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k) | pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k) | ||||
| output_path = os.path.abspath('output.wav') | output_path = os.path.abspath('output.wav') | ||||
| with open(NOISE_SPEECH_FILE, 'rb') as f: | |||||
| with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), 'rb') as f: | |||||
| data = f.read() | data = f.read() | ||||
| ans(data, output_path=output_path) | ans(data, output_path=output_path) | ||||
| print(f'Processed audio saved to {output_path}') | print(f'Processed audio saved to {output_path}') | ||||