diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index a6cc4d55..5d51593e 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -30,7 +30,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): """ super().__init__(model=model, preprocessor=preprocessor, **kwargs) - def __call__(self, wav_path: Union[List[str], str], + def __call__(self, audio_in: Union[List[str], str, bytes], **kwargs) -> Dict[str, Any]: if 'keywords' in kwargs.keys(): self.keywords = kwargs['keywords'] @@ -40,7 +40,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): if self.preprocessor is None: self.preprocessor = WavToLists() - output = self.preprocessor.forward(self.model.forward(), wav_path) + output = self.preprocessor.forward(self.model.forward(), audio_in) output = self.forward(output) rst = self.postprocess(output) return rst @@ -49,7 +49,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): """Decoding """ - logger.info(f"Decoding with {inputs['kws_set']} mode ...") + logger.info(f"Decoding with {inputs['kws_type']} mode ...") # will generate kws result out = self.run_with_kwsbp(inputs) @@ -80,60 +80,97 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): pos_kws_list = inputs['pos_kws_list'] if 'neg_kws_list' in inputs: neg_kws_list = inputs['neg_kws_list'] + rst_dict = kws_util.common.parsing_kws_result( - kws_type=inputs['kws_set'], + kws_type=inputs['kws_type'], pos_list=pos_kws_list, neg_list=neg_kws_list) return rst_dict def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - cmd = { - 'sys_dir': inputs['model_workspace'], - 'cfg_file': inputs['cfg_file_path'], - 'sample_rate': inputs['sample_rate'], - 'keyword_custom': '' - } - import kwsbp import kws_util.common kws_inference = kwsbp.KwsbpEngine() - # setting customized keywords - cmd['customized_keywords'] = kws_util.common.generate_customized_keywords( - self.keywords) + cmd = { + 'sys_dir': + inputs['model_workspace'], + 'cfg_file': + inputs['cfg_file_path'], + 'sample_rate': + inputs['sample_rate'], + 'keyword_custom': + '', + 'pcm_data': + None, + 'pcm_data_len': + 0, + 'list_flag': + True, + # setting customized keywords + 'customized_keywords': + kws_util.common.generate_customized_keywords(self.keywords) + } + + if inputs['kws_type'] == 'pcm': + cmd['pcm_data'] = inputs['pos_data'] + cmd['pcm_data_len'] = len(inputs['pos_data']) + cmd['list_flag'] = False - if inputs['kws_set'] == 'roc': + if inputs['kws_type'] == 'roc': inputs['keyword_grammar_path'] = os.path.join( inputs['model_workspace'], 'keywords_roc.json') - if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: + if inputs['kws_type'] in ['wav', 'pcm', 'pos_testsets', 'roc']: cmd['wave_scp'] = inputs['pos_wav_list'] cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] cmd['num_thread'] = inputs['pos_num_thread'] - # run and get inference result - result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'], - cmd['keyword_grammar_path'], - str(json.dumps(cmd['wave_scp'])), - str(cmd['customized_keywords']), - cmd['sample_rate'], - cmd['num_thread']) + if hasattr(kws_inference, 'inference_new'): + # run and get inference result + result = kws_inference.inference_new( + cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), cmd['pcm_data'], + cmd['pcm_data_len'], cmd['sample_rate'], cmd['num_thread'], + cmd['list_flag']) + else: + # in order to support kwsbp-0.0.1 + result = kws_inference.inference( + cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), cmd['sample_rate'], + cmd['num_thread']) + pos_result = json.loads(result) inputs['pos_kws_list'] = pos_result['kws_list'] - if inputs['kws_set'] in ['neg_testsets', 'roc']: + if inputs['kws_type'] in ['neg_testsets', 'roc']: cmd['wave_scp'] = inputs['neg_wav_list'] cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] cmd['num_thread'] = inputs['neg_num_thread'] - # run and get inference result - result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'], - cmd['keyword_grammar_path'], - str(json.dumps(cmd['wave_scp'])), - str(cmd['customized_keywords']), - cmd['sample_rate'], - cmd['num_thread']) + if hasattr(kws_inference, 'inference_new'): + # run and get inference result + result = kws_inference.inference_new( + cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), cmd['pcm_data'], + cmd['pcm_data_len'], cmd['sample_rate'], cmd['num_thread'], + cmd['list_flag']) + else: + # in order to support kwsbp-0.0.1 + result = kws_inference.inference( + cmd['sys_dir'], cmd['cfg_file'], + cmd['keyword_grammar_path'], + str(json.dumps(cmd['wave_scp'])), + str(cmd['customized_keywords']), cmd['sample_rate'], + cmd['num_thread']) + neg_result = json.loads(result) inputs['neg_kws_list'] = neg_result['kws_list'] diff --git a/modelscope/preprocessors/kws.py b/modelscope/preprocessors/kws.py index b406465a..9c370ed5 100644 --- a/modelscope/preprocessors/kws.py +++ b/modelscope/preprocessors/kws.py @@ -21,23 +21,26 @@ class WavToLists(Preprocessor): def __init__(self): pass - def __call__(self, model: Model, wav_path: Union[List[str], - str]) -> Dict[str, Any]: + def __call__(self, model: Model, audio_in: Union[List[str], str, + bytes]) -> Dict[str, Any]: """Call functions to load model and wav. Args: model (Model): model should be provided - wav_path (Union[List[str], str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path + audio_in (Union[List[str], str, bytes]): + audio_in[0] is positive wav path, audio_in[1] is negative wav path; + audio_in (str) is positive wav path; + audio_in (bytes) is audio pcm data; Returns: Dict[str, Any]: the kws result """ self.model = model - out = self.forward(self.model.forward(), wav_path) + out = self.forward(self.model.forward(), audio_in) return out def forward(self, model: Dict[str, Any], - wav_path: Union[List[str], str]) -> Dict[str, Any]: + audio_in: Union[List[str], str, bytes]) -> Dict[str, Any]: assert len( model['config_path']) > 0, 'preprocess model[config_path] is empty' assert os.path.exists( @@ -45,22 +48,21 @@ class WavToLists(Preprocessor): inputs = model.copy() - wav_list = [None, None] - if isinstance(wav_path, str): - wav_list[0] = wav_path - else: - wav_list = wav_path - import kws_util.common - kws_type = kws_util.common.type_checking(wav_list) - assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' - ], f'preprocess kws_type {kws_type} is invalid' - - inputs['kws_set'] = kws_type - if wav_list[0] is not None: - inputs['pos_wav_path'] = wav_list[0] - if wav_list[1] is not None: - inputs['neg_wav_path'] = wav_list[1] + kws_type = kws_util.common.type_checking(audio_in) + assert kws_type in [ + 'wav', 'pcm', 'pos_testsets', 'neg_testsets', 'roc' + ], f'kws_type {kws_type} is invalid, please check audio data' + + inputs['kws_type'] = kws_type + if kws_type == 'wav': + inputs['pos_wav_path'] = audio_in + elif kws_type == 'pcm': + inputs['pos_data'] = audio_in + if kws_type in ['pos_testsets', 'roc']: + inputs['pos_wav_path'] = audio_in[0] + if kws_type in ['neg_testsets', 'roc']: + inputs['neg_wav_path'] = audio_in[1] out = self.read_config(inputs) out = self.generate_wav_lists(out) @@ -93,7 +95,7 @@ class WavToLists(Preprocessor): """ import kws_util.common - if inputs['kws_set'] == 'wav': + if inputs['kws_type'] == 'wav': wav_list = [] wave_scp_content: str = inputs['pos_wav_path'] wav_list.append(wave_scp_content) @@ -101,7 +103,12 @@ class WavToLists(Preprocessor): inputs['pos_wav_count'] = 1 inputs['pos_num_thread'] = 1 - if inputs['kws_set'] in ['pos_testsets', 'roc']: + if inputs['kws_type'] == 'pcm': + inputs['pos_wav_list'] = ['pcm_data'] + inputs['pos_wav_count'] = 1 + inputs['pos_num_thread'] = 1 + + if inputs['kws_type'] in ['pos_testsets', 'roc']: # find all positive wave wav_list = [] wav_dir = inputs['pos_wav_path'] @@ -116,7 +123,7 @@ class WavToLists(Preprocessor): else: inputs['pos_num_thread'] = 128 - if inputs['kws_set'] in ['neg_testsets', 'roc']: + if inputs['kws_type'] in ['neg_testsets', 'roc']: # find all negative wave wav_list = [] wav_dir = inputs['neg_wav_path'] diff --git a/requirements/audio.txt b/requirements/audio.txt index 132b48ed..81d288bd 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -4,7 +4,7 @@ espnet>=202204 h5py inflect keras -kwsbp +kwsbp>=0.0.2 librosa lxml matplotlib diff --git a/tests/pipelines/test_automatic_speech_recognition.py b/tests/pipelines/test_automatic_speech_recognition.py index 9dad7573..1843d5dd 100644 --- a/tests/pipelines/test_automatic_speech_recognition.py +++ b/tests/pipelines/test_automatic_speech_recognition.py @@ -30,11 +30,6 @@ TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz' TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz' -def un_tar_gz(fname, dirs): - t = tarfile.open(fname) - t.extractall(path=dirs) - - class AutomaticSpeechRecognitionTest(unittest.TestCase): action_info = { 'test_run_with_wav_pytorch': { diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index 8b0e37e6..5b7d20d0 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -5,8 +5,11 @@ import tarfile import unittest from typing import Any, Dict, List, Union +import numpy as np import requests +import soundfile +from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline from modelscope.utils.constant import ColorCodes, Tasks from modelscope.utils.logger import get_logger @@ -27,12 +30,12 @@ NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/n class KeyWordSpottingTest(unittest.TestCase): action_info = { 'test_run_with_wav': { - 'checking_item': 'kws_list', + 'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], 'checking_value': '小云小云', 'example': { 'wav_count': 1, - 'kws_set': + 'kws_type': 'wav', 'kws_list': [{ 'keyword': '小云小云', @@ -42,13 +45,29 @@ class KeyWordSpottingTest(unittest.TestCase): }] } }, + 'test_run_with_pcm': { + 'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], + 'checking_value': '小云小云', + 'example': { + 'wav_count': + 1, + 'kws_type': + 'pcm', + 'kws_list': [{ + 'keyword': '小云小云', + 'offset': 5.76, + 'length': 9.132938, + 'confidence': 0.990368 + }] + } + }, 'test_run_with_wav_by_customized_keywords': { - 'checking_item': 'kws_list', + 'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], 'checking_value': '播放音乐', 'example': { 'wav_count': 1, - 'kws_set': + 'kws_type': 'wav', 'kws_list': [{ 'keyword': '播放音乐', @@ -59,10 +78,10 @@ class KeyWordSpottingTest(unittest.TestCase): } }, 'test_run_with_pos_testsets': { - 'checking_item': 'recall', + 'checking_item': ['recall'], 'example': { 'wav_count': 450, - 'kws_set': 'pos_testsets', + 'kws_type': 'pos_testsets', 'wav_time': 3013.75925, 'keywords': ['小云小云'], 'recall': 0.953333, @@ -72,11 +91,11 @@ class KeyWordSpottingTest(unittest.TestCase): } }, 'test_run_with_neg_testsets': { - 'checking_item': 'fa_rate', + 'checking_item': ['fa_rate'], 'example': { 'wav_count': 751, - 'kws_set': + 'kws_type': 'neg_testsets', 'wav_time': 3572.180813, @@ -98,10 +117,10 @@ class KeyWordSpottingTest(unittest.TestCase): } }, 'test_run_with_roc': { - 'checking_item': 'keywords', + 'checking_item': ['keywords', 0], 'checking_value': '小云小云', 'example': { - 'kws_set': + 'kws_type': 'roc', 'keywords': ['小云小云'], '小云小云': [{ @@ -129,21 +148,20 @@ class KeyWordSpottingTest(unittest.TestCase): def tearDown(self) -> None: # remove workspace dir (.tmp) - if os.path.exists(self.workspace): - shutil.rmtree(self.workspace, ignore_errors=True) + shutil.rmtree(self.workspace, ignore_errors=True) def run_pipeline(self, model_id: str, - wav_path: Union[List[str], str], + audio_in: Union[List[str], str, bytes], keywords: List[str] = None) -> Dict[str, Any]: kwsbp_16k_pipline = pipeline( task=Tasks.auto_speech_recognition, model=model_id) - kws_result = kwsbp_16k_pipline(wav_path=wav_path, keywords=keywords) + kws_result = kwsbp_16k_pipline(audio_in=audio_in, keywords=keywords) return kws_result - def print_error(self, functions: str, result: Dict[str, Any]) -> None: + def log_error(self, functions: str, result: Dict[str, Any]) -> None: logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' + ColorCodes.END) logger.error(ColorCodes.MAGENTA + functions @@ -153,49 +171,61 @@ class KeyWordSpottingTest(unittest.TestCase): raise ValueError('kws result is mismatched') - def check_and_print_result(self, functions: str, - result: Dict[str, Any]) -> None: - if result.__contains__(self.action_info[functions]['checking_item']): - checking_item = result[self.action_info[functions] - ['checking_item']] - if functions == 'test_run_with_roc': - if checking_item[0] != self.action_info[functions][ - 'checking_value']: - self.print_error(functions, result) - - elif functions == 'test_run_with_wav': - if checking_item[0]['keyword'] != self.action_info[functions][ - 'checking_value']: - self.print_error(functions, result) - - elif functions == 'test_run_with_wav_by_customized_keywords': - if checking_item[0]['keyword'] != self.action_info[functions][ - 'checking_value']: - self.print_error(functions, result) - - logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' - + ColorCodes.END) - if functions == 'test_run_with_roc': - find_keyword = result['keywords'][0] - keyword_list = result[find_keyword] - for item in iter(keyword_list): - threshold: float = item['threshold'] - recall: float = item['recall'] - fa_per_hour: float = item['fa_per_hour'] - logger.info(ColorCodes.YELLOW + ' threshold:' - + str(threshold) + ' recall:' + str(recall) - + ' fa_per_hour:' + str(fa_per_hour) - + ColorCodes.END) - else: - logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END) + def check_result(self, functions: str, result: Dict[str, Any]) -> None: + result_item = result + check_list = self.action_info[functions]['checking_item'] + for check_item in check_list: + result_item = result_item[check_item] + if result_item is None or result_item == 'None': + self.log_error(functions, result) + + if self.action_info[functions].__contains__('checking_value'): + check_value = self.action_info[functions]['checking_value'] + if result_item != check_value: + self.log_error(functions, result) + + logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' + + ColorCodes.END) + if functions == 'test_run_with_roc': + find_keyword = result['keywords'][0] + keyword_list = result[find_keyword] + for item in iter(keyword_list): + threshold: float = item['threshold'] + recall: float = item['recall'] + fa_per_hour: float = item['fa_per_hour'] + logger.info(ColorCodes.YELLOW + ' threshold:' + str(threshold) + + ' recall:' + str(recall) + ' fa_per_hour:' + + str(fa_per_hour) + ColorCodes.END) else: - self.print_error(functions, result) + logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END) + + def wav2bytes(self, wav_file) -> bytes: + audio, fs = soundfile.read(wav_file) + + # float32 -> int16 + audio = np.asarray(audio) + dtype = np.dtype('int16') + i = np.iinfo(dtype) + abs_max = 2**(i.bits - 1) + offset = i.min + abs_max + audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) + + # int16(PCM_16) -> byte + audio = audio.tobytes() + return audio @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_wav(self): kws_result = self.run_pipeline( - model_id=self.model_id, wav_path=POS_WAV_FILE) - self.check_and_print_result('test_run_with_wav', kws_result) + model_id=self.model_id, audio_in=POS_WAV_FILE) + self.check_result('test_run_with_wav', kws_result) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_pcm(self): + audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE)) + + kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio) + self.check_result('test_run_with_pcm', kws_result) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_wav_by_customized_keywords(self): @@ -203,32 +233,32 @@ class KeyWordSpottingTest(unittest.TestCase): kws_result = self.run_pipeline( model_id=self.model_id, - wav_path=BOFANGYINYUE_WAV_FILE, + audio_in=BOFANGYINYUE_WAV_FILE, keywords=keywords) - self.check_and_print_result('test_run_with_wav_by_customized_keywords', - kws_result) + self.check_result('test_run_with_wav_by_customized_keywords', + kws_result) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_pos_testsets(self): wav_file_path = download_and_untar( os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, self.workspace) - wav_path = [wav_file_path, None] + audio_list = [wav_file_path, None] kws_result = self.run_pipeline( - model_id=self.model_id, wav_path=wav_path) - self.check_and_print_result('test_run_with_pos_testsets', kws_result) + model_id=self.model_id, audio_in=audio_list) + self.check_result('test_run_with_pos_testsets', kws_result) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_neg_testsets(self): wav_file_path = download_and_untar( os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, self.workspace) - wav_path = [None, wav_file_path] + audio_list = [None, wav_file_path] kws_result = self.run_pipeline( - model_id=self.model_id, wav_path=wav_path) - self.check_and_print_result('test_run_with_neg_testsets', kws_result) + model_id=self.model_id, audio_in=audio_list) + self.check_result('test_run_with_neg_testsets', kws_result) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_roc(self): @@ -238,11 +268,11 @@ class KeyWordSpottingTest(unittest.TestCase): neg_file_path = download_and_untar( os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, self.workspace) - wav_path = [pos_file_path, neg_file_path] + audio_list = [pos_file_path, neg_file_path] kws_result = self.run_pipeline( - model_id=self.model_id, wav_path=wav_path) - self.check_and_print_result('test_run_with_roc', kws_result) + model_id=self.model_id, audio_in=audio_list) + self.check_result('test_run_with_roc', kws_result) if __name__ == '__main__':