
[to #42322933] aec pipeline: switch the C++ library dependency to MinDAEC

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9563105

    * use MinDAEC instead of loading the C++ library via ctypes.cdll (see the sketch below)

* feat: the ANS pipeline can now accept bytes as input, and the processing order is adjusted to reduce the amount of computation
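For context, a minimal sketch of what the dependency swap looks like from the caller's side. It is based only on the usage visible in the diff below (MinDAEC.load() returning an object whose do_linear_aec(mic, ref) replaces the old do_linear_aec helper); the buffer sizes and contents are illustrative placeholders, not part of this commit:

    import numpy as np
    import MinDAEC  # new dependency declared in requirements/audio.txt

    # Previously the shared library was loaded by hand with ctypes.cdll and
    # fe_process_inst was called with six pre-allocated float32 buffers.
    aec = MinDAEC.load()

    mic = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of near-end microphone audio
    ref = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of far-end reference audio
    out_mic, out_ref, out_linear, out_echo = aec.do_linear_aec(mic, ref)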
master · bin.xue, yingda.chen · 3 years ago · commit e3bffedb87
4 changed files with 27 additions and 56 deletions:

    modelscope/preprocessors/__init__.py            +1   -0
    modelscope/preprocessors/audio.py               +25  -53
    requirements/audio.txt                          +1   -0
    tests/pipelines/test_speech_signal_process.py   +0   -3

modelscope/preprocessors/__init__.py (+1, -0)

@@ -32,6 +32,7 @@ else:
         'base': ['Preprocessor'],
         'builder': ['PREPROCESSORS', 'build_preprocessor'],
         'common': ['Compose'],
+        'audio': ['LinearAECAndFbank'],
         'asr': ['WavToScp'],
         'video': ['ReadVideoData'],
         'image': [

modelscope/preprocessors/audio.py (+25, -53)

@@ -1,58 +1,15 @@
-import ctypes
+import io
 import os
 from typing import Any, Dict
 
 import numpy as np
 import scipy.io.wavfile as wav
 import torch
-from numpy.ctypeslib import ndpointer
 
 from modelscope.utils.constant import Fields
 from .builder import PREPROCESSORS
 
 
-def load_wav(path):
-    samp_rate, data = wav.read(path)
-    return np.float32(data), samp_rate
-
-
-def load_library(libaec):
-    libaec_in_cwd = os.path.join('.', libaec)
-    if os.path.exists(libaec_in_cwd):
-        libaec = libaec_in_cwd
-    mitaec = ctypes.cdll.LoadLibrary(libaec)
-    fe_process = mitaec.fe_process_inst
-    fe_process.argtypes = [
-        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'),
-        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'), ctypes.c_int,
-        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'),
-        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS'),
-        ndpointer(ctypes.c_float, flags='C_CONTIGUOUS')
-    ]
-    return fe_process
-
-
-def do_linear_aec(fe_process, mic, ref, int16range=True):
-    mic = np.float32(mic)
-    ref = np.float32(ref)
-    if len(mic) > len(ref):
-        mic = mic[:len(ref)]
-    out_mic = np.zeros_like(mic)
-    out_linear = np.zeros_like(mic)
-    out_echo = np.zeros_like(mic)
-    out_ref = np.zeros_like(mic)
-    if int16range:
-        mic /= 32768
-        ref /= 32768
-    fe_process(mic, ref, len(mic), out_mic, out_linear, out_echo)
-    # out_ref not in use here
-    if int16range:
-        out_mic *= 32768
-        out_linear *= 32768
-        out_echo *= 32768
-    return out_mic, out_ref, out_linear, out_echo
-
-
 def load_kaldi_feature_transform(filename):
     fp = open(filename, 'r')
     all_str = fp.read()
@@ -162,11 +119,12 @@ class LinearAECAndFbank:
     SAMPLE_RATE = 16000
 
     def __init__(self, io_config):
+        import MinDAEC
         self.trunc_length = 7200 * self.SAMPLE_RATE
         self.linear_aec_delay = io_config['linear_aec_delay']
         self.feature = Feature(io_config['fbank_config'],
                                io_config['feat_type'], io_config['mvn'])
-        self.mitaec = load_library(io_config['mitaec_library'])
+        self.mitaec = MinDAEC.load()
         self.mask_on_mic = io_config['mask_on'] == 'nearend_mic'
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -175,18 +133,15 @@ class LinearAECAndFbank:
         :return: dict with two keys and Tensor values: "base" linear filtered audio,and "feature"
         """
         # read files
-        nearend_mic, fs = load_wav(data['nearend_mic'])
-        assert fs == self.SAMPLE_RATE, f'The sample rate should be {self.SAMPLE_RATE}'
-        farend_speech, fs = load_wav(data['farend_speech'])
-        assert fs == self.SAMPLE_RATE, f'The sample rate should be {self.SAMPLE_RATE}'
+        nearend_mic, fs = self.load_wav(data['nearend_mic'])
+        farend_speech, fs = self.load_wav(data['farend_speech'])
         if 'nearend_speech' in data:
-            nearend_speech, fs = load_wav(data['nearend_speech'])
-            assert fs == self.SAMPLE_RATE, f'The sample rate should be {self.SAMPLE_RATE}'
+            nearend_speech, fs = self.load_wav(data['nearend_speech'])
         else:
             nearend_speech = np.zeros_like(nearend_mic)
 
-        out_mic, out_ref, out_linear, out_echo = do_linear_aec(
-            self.mitaec, nearend_mic, farend_speech)
+        out_mic, out_ref, out_linear, out_echo = self.mitaec.do_linear_aec(
+            nearend_mic, farend_speech)
         # fix 20ms linear aec delay by delaying the target speech
         extra_zeros = np.zeros([int(self.linear_aec_delay * fs)])
         nearend_speech = np.concatenate([extra_zeros, nearend_speech])
@@ -229,3 +184,20 @@ class LinearAECAndFbank:
             base = out_linear
         out_data = {'base': base, 'target': nearend_speech, 'feature': feat}
         return out_data
+
+    @staticmethod
+    def load_wav(inputs):
+        import librosa
+        if isinstance(inputs, bytes):
+            inputs = io.BytesIO(inputs)
+        elif isinstance(inputs, str):
+            pass
+        else:
+            raise TypeError(f'Unsupported input type: {type(inputs)}.')
+        sample_rate, data = wav.read(inputs)
+        if len(data.shape) > 1:
+            raise ValueError('modelscope error:The audio must be mono.')
+        if sample_rate != LinearAECAndFbank.SAMPLE_RATE:
+            data = librosa.resample(data, sample_rate,
+                                    LinearAECAndFbank.SAMPLE_RATE)
+        return data.astype(np.float32), LinearAECAndFbank.SAMPLE_RATE
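Since load_wav is now a static method that accepts either a file path or raw bytes (resampling to 16 kHz when needed), the preprocessor can be fed in-memory audio directly. A minimal usage sketch follows; the io_config values are illustrative placeholders, only the key names and the input/output dict layout come from this diff:

    from modelscope.preprocessors import LinearAECAndFbank

    io_config = {
        'linear_aec_delay': 0.02,      # seconds, placeholder value
        'fbank_config': 'fbank.conf',  # placeholder path
        'feat_type': 'fbank',          # placeholder
        'mvn': 'mvn.txt',              # placeholder path
        'mask_on': 'nearend_mic',
    }
    preprocessor = LinearAECAndFbank(io_config)

    with open('nearend_mic.wav', 'rb') as f:
        mic_bytes = f.read()           # raw bytes are now accepted
    with open('farend_speech.wav', 'rb') as f:
        ref_bytes = f.read()

    result = preprocessor({'nearend_mic': mic_bytes, 'farend_speech': ref_bytes})
    # result holds 'base' (linear-filtered audio), 'target' and 'feature'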

requirements/audio.txt (+1, -0)

@@ -8,6 +8,7 @@ kwsbp
 librosa
 lxml
 matplotlib
+MinDAEC
 nara_wpe
 nltk
 # numpy requirements should be declared with tensorflow 1.15 but not here


tests/pipelines/test_speech_signal_process.py (+0, -3)

@@ -13,9 +13,6 @@ FAREND_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/AEC/
 NEAREND_MIC_FILE = 'nearend_mic.wav'
 FAREND_SPEECH_FILE = 'farend_speech.wav'
 
-AEC_LIB_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/dependencies/ics_MaaS_AEC_lib_libmitaec_pyio.so'
-AEC_LIB_FILE = 'libmitaec_pyio.so'
-
 NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav'
 NOISE_SPEECH_FILE = 'speech_with_noise.wav'


