Browse Source

[to #42322933] add pcm-bytes support for KWS

kws增加pcm bytes数据类型的支持
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9635439
master
shichen.fsc yingda.chen 3 years ago
parent
commit
c663dd8cf6
5 changed files with 194 additions and 125 deletions
  1. +68
    -31
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  2. +30
    -23
      modelscope/preprocessors/kws.py
  3. +1
    -1
      requirements/audio.txt
  4. +0
    -5
      tests/pipelines/test_automatic_speech_recognition.py
  5. +95
    -65
      tests/pipelines/test_key_word_spotting.py

+ 68
- 31
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -30,7 +30,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
"""
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def __call__(self, wav_path: Union[List[str], str],
def __call__(self, audio_in: Union[List[str], str, bytes],
**kwargs) -> Dict[str, Any]:
if 'keywords' in kwargs.keys():
self.keywords = kwargs['keywords']
@@ -40,7 +40,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
if self.preprocessor is None:
self.preprocessor = WavToLists()

output = self.preprocessor.forward(self.model.forward(), wav_path)
output = self.preprocessor.forward(self.model.forward(), audio_in)
output = self.forward(output)
rst = self.postprocess(output)
return rst
@@ -49,7 +49,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
"""Decoding
"""

logger.info(f"Decoding with {inputs['kws_set']} mode ...")
logger.info(f"Decoding with {inputs['kws_type']} mode ...")

# will generate kws result
out = self.run_with_kwsbp(inputs)
@@ -80,60 +80,97 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
pos_kws_list = inputs['pos_kws_list']
if 'neg_kws_list' in inputs:
neg_kws_list = inputs['neg_kws_list']

rst_dict = kws_util.common.parsing_kws_result(
kws_type=inputs['kws_set'],
kws_type=inputs['kws_type'],
pos_list=pos_kws_list,
neg_list=neg_kws_list)

return rst_dict

def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
cmd = {
'sys_dir': inputs['model_workspace'],
'cfg_file': inputs['cfg_file_path'],
'sample_rate': inputs['sample_rate'],
'keyword_custom': ''
}

import kwsbp
import kws_util.common
kws_inference = kwsbp.KwsbpEngine()

# setting customized keywords
cmd['customized_keywords'] = kws_util.common.generate_customized_keywords(
self.keywords)
cmd = {
'sys_dir':
inputs['model_workspace'],
'cfg_file':
inputs['cfg_file_path'],
'sample_rate':
inputs['sample_rate'],
'keyword_custom':
'',
'pcm_data':
None,
'pcm_data_len':
0,
'list_flag':
True,
# setting customized keywords
'customized_keywords':
kws_util.common.generate_customized_keywords(self.keywords)
}

if inputs['kws_type'] == 'pcm':
cmd['pcm_data'] = inputs['pos_data']
cmd['pcm_data_len'] = len(inputs['pos_data'])
cmd['list_flag'] = False

if inputs['kws_set'] == 'roc':
if inputs['kws_type'] == 'roc':
inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], 'keywords_roc.json')

if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
if inputs['kws_type'] in ['wav', 'pcm', 'pos_testsets', 'roc']:
cmd['wave_scp'] = inputs['pos_wav_list']
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path']
cmd['num_thread'] = inputs['pos_num_thread']

# run and get inference result
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']),
cmd['sample_rate'],
cmd['num_thread'])
if hasattr(kws_inference, 'inference_new'):
# run and get inference result
result = kws_inference.inference_new(
cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']), cmd['pcm_data'],
cmd['pcm_data_len'], cmd['sample_rate'], cmd['num_thread'],
cmd['list_flag'])
else:
# in order to support kwsbp-0.0.1
result = kws_inference.inference(
cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']), cmd['sample_rate'],
cmd['num_thread'])

pos_result = json.loads(result)
inputs['pos_kws_list'] = pos_result['kws_list']

if inputs['kws_set'] in ['neg_testsets', 'roc']:
if inputs['kws_type'] in ['neg_testsets', 'roc']:
cmd['wave_scp'] = inputs['neg_wav_list']
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path']
cmd['num_thread'] = inputs['neg_num_thread']

# run and get inference result
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']),
cmd['sample_rate'],
cmd['num_thread'])
if hasattr(kws_inference, 'inference_new'):
# run and get inference result
result = kws_inference.inference_new(
cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']), cmd['pcm_data'],
cmd['pcm_data_len'], cmd['sample_rate'], cmd['num_thread'],
cmd['list_flag'])
else:
# in order to support kwsbp-0.0.1
result = kws_inference.inference(
cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']), cmd['sample_rate'],
cmd['num_thread'])

neg_result = json.loads(result)
inputs['neg_kws_list'] = neg_result['kws_list']



+ 30
- 23
modelscope/preprocessors/kws.py View File

@@ -21,23 +21,26 @@ class WavToLists(Preprocessor):
def __init__(self):
pass

def __call__(self, model: Model, wav_path: Union[List[str],
str]) -> Dict[str, Any]:
def __call__(self, model: Model, audio_in: Union[List[str], str,
bytes]) -> Dict[str, Any]:
"""Call functions to load model and wav.

Args:
model (Model): model should be provided
wav_path (Union[List[str], str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path
audio_in (Union[List[str], str, bytes]):
audio_in[0] is positive wav path, audio_in[1] is negative wav path;
audio_in (str) is positive wav path;
audio_in (bytes) is audio pcm data;
Returns:
Dict[str, Any]: the kws result
"""

self.model = model
out = self.forward(self.model.forward(), wav_path)
out = self.forward(self.model.forward(), audio_in)
return out

def forward(self, model: Dict[str, Any],
wav_path: Union[List[str], str]) -> Dict[str, Any]:
audio_in: Union[List[str], str, bytes]) -> Dict[str, Any]:
assert len(
model['config_path']) > 0, 'preprocess model[config_path] is empty'
assert os.path.exists(
@@ -45,22 +48,21 @@ class WavToLists(Preprocessor):

inputs = model.copy()

wav_list = [None, None]
if isinstance(wav_path, str):
wav_list[0] = wav_path
else:
wav_list = wav_path

import kws_util.common
kws_type = kws_util.common.type_checking(wav_list)
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'
], f'preprocess kws_type {kws_type} is invalid'

inputs['kws_set'] = kws_type
if wav_list[0] is not None:
inputs['pos_wav_path'] = wav_list[0]
if wav_list[1] is not None:
inputs['neg_wav_path'] = wav_list[1]
kws_type = kws_util.common.type_checking(audio_in)
assert kws_type in [
'wav', 'pcm', 'pos_testsets', 'neg_testsets', 'roc'
], f'kws_type {kws_type} is invalid, please check audio data'

inputs['kws_type'] = kws_type
if kws_type == 'wav':
inputs['pos_wav_path'] = audio_in
elif kws_type == 'pcm':
inputs['pos_data'] = audio_in
if kws_type in ['pos_testsets', 'roc']:
inputs['pos_wav_path'] = audio_in[0]
if kws_type in ['neg_testsets', 'roc']:
inputs['neg_wav_path'] = audio_in[1]

out = self.read_config(inputs)
out = self.generate_wav_lists(out)
@@ -93,7 +95,7 @@ class WavToLists(Preprocessor):
"""
import kws_util.common

if inputs['kws_set'] == 'wav':
if inputs['kws_type'] == 'wav':
wav_list = []
wave_scp_content: str = inputs['pos_wav_path']
wav_list.append(wave_scp_content)
@@ -101,7 +103,12 @@ class WavToLists(Preprocessor):
inputs['pos_wav_count'] = 1
inputs['pos_num_thread'] = 1

if inputs['kws_set'] in ['pos_testsets', 'roc']:
if inputs['kws_type'] == 'pcm':
inputs['pos_wav_list'] = ['pcm_data']
inputs['pos_wav_count'] = 1
inputs['pos_num_thread'] = 1

if inputs['kws_type'] in ['pos_testsets', 'roc']:
# find all positive wave
wav_list = []
wav_dir = inputs['pos_wav_path']
@@ -116,7 +123,7 @@ class WavToLists(Preprocessor):
else:
inputs['pos_num_thread'] = 128

if inputs['kws_set'] in ['neg_testsets', 'roc']:
if inputs['kws_type'] in ['neg_testsets', 'roc']:
# find all negative wave
wav_list = []
wav_dir = inputs['neg_wav_path']


+ 1
- 1
requirements/audio.txt View File

@@ -4,7 +4,7 @@ espnet>=202204
h5py
inflect
keras
kwsbp
kwsbp>=0.0.2
librosa
lxml
matplotlib


+ 0
- 5
tests/pipelines/test_automatic_speech_recognition.py View File

@@ -30,11 +30,6 @@ TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'


def un_tar_gz(fname, dirs):
    """Extract every member of the tar archive ``fname`` into ``dirs``.

    Args:
        fname: Path of the archive, handed to ``tarfile.open``.
        dirs: Destination directory, handed to ``extractall`` as ``path``.
    """
    # A context manager guarantees the archive handle is closed even when
    # extraction raises; previously the handle was left open.
    # NOTE(review): extractall trusts the member paths stored in the archive,
    # so this should only be used on archives from a trusted source.
    with tarfile.open(fname) as archive:
        archive.extractall(path=dirs)


class AutomaticSpeechRecognitionTest(unittest.TestCase):
action_info = {
'test_run_with_wav_pytorch': {


+ 95
- 65
tests/pipelines/test_key_word_spotting.py View File

@@ -5,8 +5,11 @@ import tarfile
import unittest
from typing import Any, Dict, List, Union

import numpy as np
import requests
import soundfile

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.logger import get_logger
@@ -27,12 +30,12 @@ NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/n
class KeyWordSpottingTest(unittest.TestCase):
action_info = {
'test_run_with_wav': {
'checking_item': 'kws_list',
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
'checking_value': '小云小云',
'example': {
'wav_count':
1,
'kws_set':
'kws_type':
'wav',
'kws_list': [{
'keyword': '小云小云',
@@ -42,13 +45,29 @@ class KeyWordSpottingTest(unittest.TestCase):
}]
}
},
'test_run_with_pcm': {
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
'checking_value': '小云小云',
'example': {
'wav_count':
1,
'kws_type':
'pcm',
'kws_list': [{
'keyword': '小云小云',
'offset': 5.76,
'length': 9.132938,
'confidence': 0.990368
}]
}
},
'test_run_with_wav_by_customized_keywords': {
'checking_item': 'kws_list',
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
'checking_value': '播放音乐',
'example': {
'wav_count':
1,
'kws_set':
'kws_type':
'wav',
'kws_list': [{
'keyword': '播放音乐',
@@ -59,10 +78,10 @@ class KeyWordSpottingTest(unittest.TestCase):
}
},
'test_run_with_pos_testsets': {
'checking_item': 'recall',
'checking_item': ['recall'],
'example': {
'wav_count': 450,
'kws_set': 'pos_testsets',
'kws_type': 'pos_testsets',
'wav_time': 3013.75925,
'keywords': ['小云小云'],
'recall': 0.953333,
@@ -72,11 +91,11 @@ class KeyWordSpottingTest(unittest.TestCase):
}
},
'test_run_with_neg_testsets': {
'checking_item': 'fa_rate',
'checking_item': ['fa_rate'],
'example': {
'wav_count':
751,
'kws_set':
'kws_type':
'neg_testsets',
'wav_time':
3572.180813,
@@ -98,10 +117,10 @@ class KeyWordSpottingTest(unittest.TestCase):
}
},
'test_run_with_roc': {
'checking_item': 'keywords',
'checking_item': ['keywords', 0],
'checking_value': '小云小云',
'example': {
'kws_set':
'kws_type':
'roc',
'keywords': ['小云小云'],
'小云小云': [{
@@ -129,21 +148,20 @@ class KeyWordSpottingTest(unittest.TestCase):

def tearDown(self) -> None:
# remove workspace dir (.tmp)
if os.path.exists(self.workspace):
shutil.rmtree(self.workspace, ignore_errors=True)
shutil.rmtree(self.workspace, ignore_errors=True)

def run_pipeline(self,
model_id: str,
wav_path: Union[List[str], str],
audio_in: Union[List[str], str, bytes],
keywords: List[str] = None) -> Dict[str, Any]:
kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=model_id)

kws_result = kwsbp_16k_pipline(wav_path=wav_path, keywords=keywords)
kws_result = kwsbp_16k_pipline(audio_in=audio_in, keywords=keywords)

return kws_result

def print_error(self, functions: str, result: Dict[str, Any]) -> None:
def log_error(self, functions: str, result: Dict[str, Any]) -> None:
logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
+ ColorCodes.END)
logger.error(ColorCodes.MAGENTA + functions
@@ -153,49 +171,61 @@ class KeyWordSpottingTest(unittest.TestCase):

raise ValueError('kws result is mismatched')

def check_and_print_result(self, functions: str,
result: Dict[str, Any]) -> None:
if result.__contains__(self.action_info[functions]['checking_item']):
checking_item = result[self.action_info[functions]
['checking_item']]
if functions == 'test_run_with_roc':
if checking_item[0] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

elif functions == 'test_run_with_wav':
if checking_item[0]['keyword'] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

elif functions == 'test_run_with_wav_by_customized_keywords':
if checking_item[0]['keyword'] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
+ ColorCodes.END)
if functions == 'test_run_with_roc':
find_keyword = result['keywords'][0]
keyword_list = result[find_keyword]
for item in iter(keyword_list):
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
logger.info(ColorCodes.YELLOW + ' threshold:'
+ str(threshold) + ' recall:' + str(recall)
+ ' fa_per_hour:' + str(fa_per_hour)
+ ColorCodes.END)
else:
logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)
def check_result(self, functions: str, result: Dict[str, Any]) -> None:
result_item = result
check_list = self.action_info[functions]['checking_item']
for check_item in check_list:
result_item = result_item[check_item]
if result_item is None or result_item == 'None':
self.log_error(functions, result)

if self.action_info[functions].__contains__('checking_value'):
check_value = self.action_info[functions]['checking_value']
if result_item != check_value:
self.log_error(functions, result)

logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
+ ColorCodes.END)
if functions == 'test_run_with_roc':
find_keyword = result['keywords'][0]
keyword_list = result[find_keyword]
for item in iter(keyword_list):
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
logger.info(ColorCodes.YELLOW + ' threshold:' + str(threshold)
+ ' recall:' + str(recall) + ' fa_per_hour:'
+ str(fa_per_hour) + ColorCodes.END)
else:
self.print_error(functions, result)
logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)

def wav2bytes(self, wav_file) -> bytes:
    """Read an audio file and return its samples as raw int16 PCM bytes.

    Args:
        wav_file: Audio file handed to ``soundfile.read``.

    Returns:
        The samples scaled to the int16 range, clipped, and serialized
        with ``ndarray.tobytes``.
    """
    # The sample rate returned alongside the data is not needed here.
    samples, _ = soundfile.read(wav_file)

    # Scale the decoded samples to int16 (assumes soundfile returned
    # normalized floating-point data -- TODO confirm for non-default dtypes).
    pcm_type = np.dtype('int16')
    limits = np.iinfo(pcm_type)
    scale = 2**(limits.bits - 1)
    # For a signed integer type this evaluates to 0.
    bias = limits.min + scale
    scaled = np.asarray(samples) * scale + bias
    pcm = scaled.clip(limits.min, limits.max).astype(pcm_type)

    # int16 (PCM_16) samples -> raw bytes
    return pcm.tobytes()

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=POS_WAV_FILE)
self.check_and_print_result('test_run_with_wav', kws_result)
model_id=self.model_id, audio_in=POS_WAV_FILE)
self.check_result('test_run_with_wav', kws_result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_pcm(self):
    """Run the pipeline on raw PCM bytes of the positive wav and check the result."""
    # Resolve the sample relative to the current working directory, then
    # convert it to the bytes payload the pipeline accepts as ``audio_in``.
    wav_file = os.path.join(os.getcwd(), POS_WAV_FILE)
    pcm_bytes = self.wav2bytes(wav_file)

    kws_result = self.run_pipeline(
        model_id=self.model_id, audio_in=pcm_bytes)
    self.check_result('test_run_with_pcm', kws_result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
@@ -203,32 +233,32 @@ class KeyWordSpottingTest(unittest.TestCase):

kws_result = self.run_pipeline(
model_id=self.model_id,
wav_path=BOFANGYINYUE_WAV_FILE,
audio_in=BOFANGYINYUE_WAV_FILE,
keywords=keywords)
self.check_and_print_result('test_run_with_wav_by_customized_keywords',
kws_result)
self.check_result('test_run_with_wav_by_customized_keywords',
kws_result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self):
wav_file_path = download_and_untar(
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
self.workspace)
wav_path = [wav_file_path, None]
audio_list = [wav_file_path, None]

kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_pos_testsets', kws_result)
model_id=self.model_id, audio_in=audio_list)
self.check_result('test_run_with_pos_testsets', kws_result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_neg_testsets(self):
wav_file_path = download_and_untar(
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
self.workspace)
wav_path = [None, wav_file_path]
audio_list = [None, wav_file_path]

kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_neg_testsets', kws_result)
model_id=self.model_id, audio_in=audio_list)
self.check_result('test_run_with_neg_testsets', kws_result)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_roc(self):
@@ -238,11 +268,11 @@ class KeyWordSpottingTest(unittest.TestCase):
neg_file_path = download_and_untar(
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
self.workspace)
wav_path = [pos_file_path, neg_file_path]
audio_list = [pos_file_path, neg_file_path]

kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_roc', kws_result)
model_id=self.model_id, audio_in=audio_list)
self.check_result('test_run_with_roc', kws_result)


if __name__ == '__main__':


Loading…
Cancel
Save