Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9200613 * [Add] add KWS code * [Fix] fix kws warning * [Add] add ROC for KWS * [Update] add some code check * [Update] refactor kws code, bug fix * [Add] add customized keywords setting for KWS * [Add] add data/test/audios for KWSmaster
| @@ -1,3 +1,4 @@ | |||
| *.png filter=lfs diff=lfs merge=lfs -text | |||
| *.jpg filter=lfs diff=lfs merge=lfs -text | |||
| *.mp4 filter=lfs diff=lfs merge=lfs -text | |||
| *.wav filter=lfs diff=lfs merge=lfs -text | |||
| @@ -124,7 +124,3 @@ replace.sh | |||
| # Pytorch | |||
| *.pth | |||
| # audio | |||
| *.wav | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:a72a7b8d1e8be6ebaa09aeee0d71472569bc62cc4872ecfdbd1651bb3d03eaba | |||
| size 69110 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:c6b1671bcfa872278c99490cd1acb08297b8df4dc78f268e4b6a582b4364e4a1 | |||
| size 297684 | |||
| @@ -5,6 +5,8 @@ import stat | |||
| import subprocess | |||
| from typing import Any, Dict, List | |||
| import json | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.pipelines.base import Pipeline | |||
| @@ -39,6 +41,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
| self._preprocessor = preprocessor | |||
| self._model = model | |||
| self._keywords = None | |||
| if 'keywords' in kwargs.keys(): | |||
| self._keywords = kwargs['keywords'] | |||
| print('self._keywords len: ', len(self._keywords)) | |||
| print('self._keywords: ', self._keywords) | |||
| def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: | |||
| assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', | |||
| @@ -197,6 +205,16 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
| return rst_dict | |||
| def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| opts: str = '' | |||
| # setting customized keywords | |||
| keywords_json = self._set_customized_keywords() | |||
| if len(keywords_json) > 0: | |||
| keywords_json_file = os.path.join(inputs['workspace'], | |||
| 'keyword_custom.json') | |||
| with open(keywords_json_file, 'w') as f: | |||
| json.dump(keywords_json, f) | |||
| opts = '--keyword-custom ' + keywords_json_file | |||
| if inputs['kws_set'] == 'roc': | |||
| inputs['keyword_grammar_path'] = os.path.join( | |||
| @@ -211,7 +229,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
| ' --sample-rate=' + inputs['sample_rate'] + \ | |||
| ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||
| ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ | |||
| ' --num-thread=1 > ' + dump_log_path + ' 2>&1' | |||
| ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||
| os.system(kws_cmd) | |||
| if inputs['kws_set'] in ['pos_testsets', 'roc']: | |||
| @@ -236,7 +254,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
| ' --sample-rate=' + inputs['sample_rate'] + \ | |||
| ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||
| ' --wave-scp=' + wav_list_path + \ | |||
| ' --num-thread=1 > ' + dump_log_path + ' 2>&1' | |||
| ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||
| p = subprocess.Popen(kws_cmd, shell=True) | |||
| process.append(p) | |||
| j += 1 | |||
| @@ -268,7 +286,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
| ' --sample-rate=' + inputs['sample_rate'] + \ | |||
| ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||
| ' --wave-scp=' + wav_list_path + \ | |||
| ' --num-thread=1 > ' + dump_log_path + ' 2>&1' | |||
| ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||
| p = subprocess.Popen(kws_cmd, shell=True) | |||
| process.append(p) | |||
| j += 1 | |||
| @@ -447,3 +465,29 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
| threshold_cur += step | |||
| return output | |||
| def _set_customized_keywords(self) -> Dict[str, Any]: | |||
| if self._keywords is not None: | |||
| word_list_inputs = self._keywords | |||
| word_list = [] | |||
| for i in range(len(word_list_inputs)): | |||
| key = word_list_inputs[i] | |||
| new_item = {} | |||
| if key.__contains__('keyword'): | |||
| name = key['keyword'] | |||
| new_name: str = '' | |||
| for n in range(0, len(name), 1): | |||
| new_name += name[n] | |||
| new_name += ' ' | |||
| new_name = new_name.strip() | |||
| new_item['name'] = new_name | |||
| if key.__contains__('threshold'): | |||
| threshold1: float = key['threshold'] | |||
| new_item['threshold1'] = threshold1 | |||
| word_list.append(new_item) | |||
| out = {'word_list': word_list} | |||
| return out | |||
| else: | |||
| return '' | |||
| @@ -15,8 +15,8 @@ from modelscope.utils.test_utils import test_level | |||
| KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' | |||
| POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav' | |||
| POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE | |||
| POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' | |||
| BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' | |||
| POS_TESTSETS_FILE = 'pos_testsets.tar.gz' | |||
| POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' | |||
| @@ -47,12 +47,8 @@ class KeyWordSpottingTest(unittest.TestCase): | |||
| # wav, neg_testsets, pos_testsets, roc | |||
| kws_set = 'wav' | |||
| # downloading wav file | |||
| wav_file_path = os.path.join(self.workspace, POS_WAV_FILE) | |||
| if not os.path.exists(wav_file_path): | |||
| r = requests.get(POS_WAV_URL) | |||
| with open(wav_file_path, 'wb') as f: | |||
| f.write(r.content) | |||
| # get wav file | |||
| wav_file_path = POS_WAV_FILE | |||
| # downloading kwsbp | |||
| kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||
| @@ -91,9 +87,72 @@ class KeyWordSpottingTest(unittest.TestCase): | |||
| """ | |||
| if kws_result.__contains__('keywords'): | |||
| print('test_run_with_wav keywords: ', kws_result['keywords']) | |||
| print('test_run_with_wav confidence: ', kws_result['confidence']) | |||
| print('test_run_with_wav detected result: ', kws_result['detected']) | |||
| print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_wav_by_customized_keywords(self): | |||
| # wav, neg_testsets, pos_testsets, roc | |||
| kws_set = 'wav' | |||
| # get wav file | |||
| wav_file_path = BOFANGYINYUE_WAV_FILE | |||
| # downloading kwsbp | |||
| kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||
| if not os.path.exists(kwsbp_file_path): | |||
| r = requests.get(KWSBP_URL) | |||
| with open(kwsbp_file_path, 'wb') as f: | |||
| f.write(r.content) | |||
| model = Model.from_pretrained(self.model_id) | |||
| self.assertTrue(model is not None) | |||
| cfg_preprocessor = dict( | |||
| type=Preprocessors.wav_to_lists, workspace=self.workspace) | |||
| preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) | |||
| self.assertTrue(preprocessor is not None) | |||
| # customized keyword if you need. | |||
| # full settings eg. | |||
| # keywords = [ | |||
| # {'keyword':'你好电视', 'threshold': 0.008}, | |||
| # {'keyword':'播放音乐', 'threshold': 0.008} | |||
| # ] | |||
| keywords = [{'keyword': '播放音乐'}] | |||
| kwsbp_16k_pipline = pipeline( | |||
| pipeline_name=Pipelines.kws_kwsbp, | |||
| model=model, | |||
| preprocessor=preprocessor, | |||
| keywords=keywords) | |||
| self.assertTrue(kwsbp_16k_pipline is not None) | |||
| kws_result = kwsbp_16k_pipline( | |||
| kws_type=kws_set, wav_path=[wav_file_path, None]) | |||
| self.assertTrue(kws_result.__contains__('detected')) | |||
| """ | |||
| kws result json format example: | |||
| { | |||
| 'wav_count': 1, | |||
| 'kws_set': 'wav', | |||
| 'wav_time': 9.132938, | |||
| 'keywords': ['播放音乐'], | |||
| 'detected': True, | |||
| 'confidence': 0.660368 | |||
| } | |||
| """ | |||
| if kws_result.__contains__('keywords'): | |||
| print('test_run_with_wav_by_customized_keywords keywords: ', | |||
| kws_result['keywords']) | |||
| print('test_run_with_wav_by_customized_keywords confidence: ', | |||
| kws_result['confidence']) | |||
| print('test_run_with_wav_by_customized_keywords detected result: ', | |||
| kws_result['detected']) | |||
| print('test_run_with_wav_by_customized_keywords wave time(seconds): ', | |||
| kws_result['wav_time']) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_pos_testsets(self): | |||
| # wav, neg_testsets, pos_testsets, roc | |||