Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9200613 * [Add] add KWS code * [Fix] fix kws warning * [Add] add ROC for KWS * [Update] add some code check * [Update] refactor kws code, bug fix * [Add] add customized keywords setting for KWS * [Add] add data/test/audios for KWS (master)
| @@ -1,3 +1,4 @@ | |||||
| *.png filter=lfs diff=lfs merge=lfs -text | *.png filter=lfs diff=lfs merge=lfs -text | ||||
| *.jpg filter=lfs diff=lfs merge=lfs -text | *.jpg filter=lfs diff=lfs merge=lfs -text | ||||
| *.mp4 filter=lfs diff=lfs merge=lfs -text | *.mp4 filter=lfs diff=lfs merge=lfs -text | ||||
| *.wav filter=lfs diff=lfs merge=lfs -text | |||||
| @@ -124,7 +124,3 @@ replace.sh | |||||
| # Pytorch | # Pytorch | ||||
| *.pth | *.pth | ||||
| # audio | |||||
| *.wav | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:a72a7b8d1e8be6ebaa09aeee0d71472569bc62cc4872ecfdbd1651bb3d03eaba | |||||
| size 69110 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:c6b1671bcfa872278c99490cd1acb08297b8df4dc78f268e4b6a582b4364e4a1 | |||||
| size 297684 | |||||
| @@ -5,6 +5,8 @@ import stat | |||||
| import subprocess | import subprocess | ||||
| from typing import Any, Dict, List | from typing import Any, Dict, List | ||||
| import json | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.pipelines.base import Pipeline | from modelscope.pipelines.base import Pipeline | ||||
| @@ -39,6 +41,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| self._preprocessor = preprocessor | self._preprocessor = preprocessor | ||||
| self._model = model | self._model = model | ||||
| self._keywords = None | |||||
| if 'keywords' in kwargs.keys(): | |||||
| self._keywords = kwargs['keywords'] | |||||
| print('self._keywords len: ', len(self._keywords)) | |||||
| print('self._keywords: ', self._keywords) | |||||
| def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: | def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: | ||||
| assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', | assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', | ||||
| @@ -197,6 +205,16 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| return rst_dict | return rst_dict | ||||
| def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| opts: str = '' | |||||
| # setting customized keywords | |||||
| keywords_json = self._set_customized_keywords() | |||||
| if len(keywords_json) > 0: | |||||
| keywords_json_file = os.path.join(inputs['workspace'], | |||||
| 'keyword_custom.json') | |||||
| with open(keywords_json_file, 'w') as f: | |||||
| json.dump(keywords_json, f) | |||||
| opts = '--keyword-custom ' + keywords_json_file | |||||
| if inputs['kws_set'] == 'roc': | if inputs['kws_set'] == 'roc': | ||||
| inputs['keyword_grammar_path'] = os.path.join( | inputs['keyword_grammar_path'] = os.path.join( | ||||
| @@ -211,7 +229,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| ' --sample-rate=' + inputs['sample_rate'] + \ | ' --sample-rate=' + inputs['sample_rate'] + \ | ||||
| ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | ||||
| ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ | ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ | ||||
| ' --num-thread=1 > ' + dump_log_path + ' 2>&1' | |||||
| ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||||
| os.system(kws_cmd) | os.system(kws_cmd) | ||||
| if inputs['kws_set'] in ['pos_testsets', 'roc']: | if inputs['kws_set'] in ['pos_testsets', 'roc']: | ||||
| @@ -236,7 +254,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| ' --sample-rate=' + inputs['sample_rate'] + \ | ' --sample-rate=' + inputs['sample_rate'] + \ | ||||
| ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | ||||
| ' --wave-scp=' + wav_list_path + \ | ' --wave-scp=' + wav_list_path + \ | ||||
| ' --num-thread=1 > ' + dump_log_path + ' 2>&1' | |||||
| ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||||
| p = subprocess.Popen(kws_cmd, shell=True) | p = subprocess.Popen(kws_cmd, shell=True) | ||||
| process.append(p) | process.append(p) | ||||
| j += 1 | j += 1 | ||||
| @@ -268,7 +286,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| ' --sample-rate=' + inputs['sample_rate'] + \ | ' --sample-rate=' + inputs['sample_rate'] + \ | ||||
| ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | ||||
| ' --wave-scp=' + wav_list_path + \ | ' --wave-scp=' + wav_list_path + \ | ||||
| ' --num-thread=1 > ' + dump_log_path + ' 2>&1' | |||||
| ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' | |||||
| p = subprocess.Popen(kws_cmd, shell=True) | p = subprocess.Popen(kws_cmd, shell=True) | ||||
| process.append(p) | process.append(p) | ||||
| j += 1 | j += 1 | ||||
| @@ -447,3 +465,29 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| threshold_cur += step | threshold_cur += step | ||||
| return output | return output | ||||
| def _set_customized_keywords(self) -> Dict[str, Any]: | |||||
| if self._keywords is not None: | |||||
| word_list_inputs = self._keywords | |||||
| word_list = [] | |||||
| for i in range(len(word_list_inputs)): | |||||
| key = word_list_inputs[i] | |||||
| new_item = {} | |||||
| if key.__contains__('keyword'): | |||||
| name = key['keyword'] | |||||
| new_name: str = '' | |||||
| for n in range(0, len(name), 1): | |||||
| new_name += name[n] | |||||
| new_name += ' ' | |||||
| new_name = new_name.strip() | |||||
| new_item['name'] = new_name | |||||
| if key.__contains__('threshold'): | |||||
| threshold1: float = key['threshold'] | |||||
| new_item['threshold1'] = threshold1 | |||||
| word_list.append(new_item) | |||||
| out = {'word_list': word_list} | |||||
| return out | |||||
| else: | |||||
| return '' | |||||
| @@ -15,8 +15,8 @@ from modelscope.utils.test_utils import test_level | |||||
| KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' | KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' | ||||
| POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav' | |||||
| POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE | |||||
| POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' | |||||
| BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' | |||||
| POS_TESTSETS_FILE = 'pos_testsets.tar.gz' | POS_TESTSETS_FILE = 'pos_testsets.tar.gz' | ||||
| POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' | POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' | ||||
| @@ -47,12 +47,8 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
| # wav, neg_testsets, pos_testsets, roc | # wav, neg_testsets, pos_testsets, roc | ||||
| kws_set = 'wav' | kws_set = 'wav' | ||||
| # downloading wav file | |||||
| wav_file_path = os.path.join(self.workspace, POS_WAV_FILE) | |||||
| if not os.path.exists(wav_file_path): | |||||
| r = requests.get(POS_WAV_URL) | |||||
| with open(wav_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| # get wav file | |||||
| wav_file_path = POS_WAV_FILE | |||||
| # downloading kwsbp | # downloading kwsbp | ||||
| kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | ||||
| @@ -91,9 +87,72 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
| """ | """ | ||||
| if kws_result.__contains__('keywords'): | if kws_result.__contains__('keywords'): | ||||
| print('test_run_with_wav keywords: ', kws_result['keywords']) | print('test_run_with_wav keywords: ', kws_result['keywords']) | ||||
| print('test_run_with_wav confidence: ', kws_result['confidence']) | |||||
| print('test_run_with_wav detected result: ', kws_result['detected']) | print('test_run_with_wav detected result: ', kws_result['detected']) | ||||
| print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) | print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) | ||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
    """Run the KWS pipeline on one wav file with a customized keyword list.

    Downloads the ``kwsbp`` tool into ``self.workspace`` when it is not
    already there, builds the model, preprocessor and pipeline, passes
    ``keywords`` to the pipeline, and checks that the result contains a
    ``'detected'`` entry.
    """
    # One of: wav, neg_testsets, pos_testsets, roc
    kws_set = 'wav'

    # Wav file referenced by a repository-relative path.
    wav_file_path = BOFANGYINYUE_WAV_FILE

    # Download kwsbp only when it is missing from the workspace.
    kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
    if not os.path.exists(kwsbp_file_path):
        r = requests.get(KWSBP_URL)
        # Fail here on an HTTP error instead of saving the error body as
        # the tool, which the exists-check above would then keep reusing.
        r.raise_for_status()
        with open(kwsbp_file_path, 'wb') as f:
            f.write(r.content)

    model = Model.from_pretrained(self.model_id)
    self.assertIsNotNone(model)

    cfg_preprocessor = dict(
        type=Preprocessors.wav_to_lists, workspace=self.workspace)
    preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
    self.assertIsNotNone(preprocessor)

    # Customized keywords, if you need them. Full settings example:
    # keywords = [
    #     {'keyword':'你好电视', 'threshold': 0.008},
    #     {'keyword':'播放音乐', 'threshold': 0.008}
    # ]
    keywords = [{'keyword': '播放音乐'}]

    kwsbp_16k_pipline = pipeline(
        pipeline_name=Pipelines.kws_kwsbp,
        model=model,
        preprocessor=preprocessor,
        keywords=keywords)
    self.assertIsNotNone(kwsbp_16k_pipline)

    kws_result = kwsbp_16k_pipline(
        kws_type=kws_set, wav_path=[wav_file_path, None])
    self.assertIn('detected', kws_result)

    # kws result json format example:
    # {
    #     'wav_count': 1,
    #     'kws_set': 'wav',
    #     'wav_time': 9.132938,
    #     'keywords': ['播放音乐'],
    #     'detected': True,
    #     'confidence': 0.660368
    # }
    if 'keywords' in kws_result:
        print('test_run_with_wav_by_customized_keywords keywords: ',
              kws_result['keywords'])
        # NOTE(review): kept under the same guard as 'keywords'; confirm
        # that 'confidence' is only reported together with 'keywords'.
        print('test_run_with_wav_by_customized_keywords confidence: ',
              kws_result['confidence'])
    print('test_run_with_wav_by_customized_keywords detected result: ',
          kws_result['detected'])
    print('test_run_with_wav_by_customized_keywords wave time(seconds): ',
          kws_result['wav_time'])
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_run_with_pos_testsets(self): | def test_run_with_pos_testsets(self): | ||||
| # wav, neg_testsets, pos_testsets, roc | # wav, neg_testsets, pos_testsets, roc | ||||