Browse Source

[to #42322933] add customized keywords setting for KWS

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9200613

    * [Add] add KWS code

    * [Fix] fix kws warning

    * [Add] add ROC for KWS

    * [Update] add some code check

    * [Update] refactor kws code, bug fix

    * [Add] add customized keywords setting for KWS

    * [Add] add data/test/audios for KWS
master
shichen.fsc 3 years ago
parent
commit
40cb1043b8
6 changed files with 121 additions and 15 deletions
  1. +1
    -0
      .gitattributes
  2. +0
    -4
      .gitignore
  3. +3
    -0
      data/test/audios/kws_bofangyinyue.wav
  4. +3
    -0
      data/test/audios/kws_xiaoyunxiaoyun.wav
  5. +47
    -3
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  6. +67
    -8
      tests/pipelines/test_key_word_spotting.py

+ 1
- 0
.gitattributes View File

@@ -1,3 +1,4 @@
*.png filter=lfs diff=lfs merge=lfs -text *.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text *.jpg filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text *.mp4 filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text

+ 0
- 4
.gitignore View File

@@ -124,7 +124,3 @@ replace.sh


# Pytorch # Pytorch
*.pth *.pth


# audio
*.wav

+ 3
- 0
data/test/audios/kws_bofangyinyue.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a72a7b8d1e8be6ebaa09aeee0d71472569bc62cc4872ecfdbd1651bb3d03eaba
size 69110

+ 3
- 0
data/test/audios/kws_xiaoyunxiaoyun.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6b1671bcfa872278c99490cd1acb08297b8df4dc78f268e4b6a582b4364e4a1
size 297684

+ 47
- 3
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -5,6 +5,8 @@ import stat
import subprocess import subprocess
from typing import Any, Dict, List from typing import Any, Dict, List


import json

from modelscope.metainfo import Pipelines from modelscope.metainfo import Pipelines
from modelscope.models import Model from modelscope.models import Model
from modelscope.pipelines.base import Pipeline from modelscope.pipelines.base import Pipeline
@@ -39,6 +41,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):


self._preprocessor = preprocessor self._preprocessor = preprocessor
self._model = model self._model = model
self._keywords = None

if 'keywords' in kwargs.keys():
self._keywords = kwargs['keywords']
print('self._keywords len: ', len(self._keywords))
print('self._keywords: ', self._keywords)


def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]:
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
@@ -197,6 +205,16 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
return rst_dict return rst_dict


def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
opts: str = ''

# setting customized keywords
keywords_json = self._set_customized_keywords()
if len(keywords_json) > 0:
keywords_json_file = os.path.join(inputs['workspace'],
'keyword_custom.json')
with open(keywords_json_file, 'w') as f:
json.dump(keywords_json, f)
opts = '--keyword-custom ' + keywords_json_file


if inputs['kws_set'] == 'roc': if inputs['kws_set'] == 'roc':
inputs['keyword_grammar_path'] = os.path.join( inputs['keyword_grammar_path'] = os.path.join(
@@ -211,7 +229,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
' --sample-rate=' + inputs['sample_rate'] + \ ' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
os.system(kws_cmd) os.system(kws_cmd)


if inputs['kws_set'] in ['pos_testsets', 'roc']: if inputs['kws_set'] in ['pos_testsets', 'roc']:
@@ -236,7 +254,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
' --sample-rate=' + inputs['sample_rate'] + \ ' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \ ' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True) p = subprocess.Popen(kws_cmd, shell=True)
process.append(p) process.append(p)
j += 1 j += 1
@@ -268,7 +286,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
' --sample-rate=' + inputs['sample_rate'] + \ ' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \ ' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True) p = subprocess.Popen(kws_cmd, shell=True)
process.append(p) process.append(p)
j += 1 j += 1
@@ -447,3 +465,29 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
threshold_cur += step threshold_cur += step


return output return output

def _set_customized_keywords(self) -> Dict[str, Any]:
    """Build the customized-keyword configuration from ``self._keywords``.

    ``self._keywords`` is expected to be ``None`` or a list of dicts such
    as ``{'keyword': '播放音乐', 'threshold': 0.008}``. Each entry is
    converted into an item of the returned ``word_list``:

    * ``'keyword'``   -> ``'name'``: the keyword with a single space
      inserted between its characters (surrounding whitespace stripped).
    * ``'threshold'`` -> ``'threshold1'``: copied unchanged.

    Keys missing from an entry are simply omitted from its item.

    Returns:
        ``{'word_list': [...]}`` when keywords were configured, otherwise
        an empty dict. The empty result always has length 0, so callers
        can test ``len(result) > 0`` to decide whether a customized
        keyword file has to be written.
    """
    if self._keywords is None:
        # Return an empty dict (not an empty string) so the value always
        # matches the declared return type.
        return {}

    word_list = []
    for entry in self._keywords:
        item = {}
        if 'keyword' in entry:
            # Separate every character of the keyword with one space.
            item['name'] = ' '.join(entry['keyword']).strip()
        if 'threshold' in entry:
            item['threshold1'] = entry['threshold']
        word_list.append(item)

    return {'word_list': word_list}

+ 67
- 8
tests/pipelines/test_key_word_spotting.py View File

@@ -15,8 +15,8 @@ from modelscope.utils.test_utils import test_level


KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp'


POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav'
POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'


POS_TESTSETS_FILE = 'pos_testsets.tar.gz' POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz'
@@ -47,12 +47,8 @@ class KeyWordSpottingTest(unittest.TestCase):
# wav, neg_testsets, pos_testsets, roc # wav, neg_testsets, pos_testsets, roc
kws_set = 'wav' kws_set = 'wav'


# downloading wav file
wav_file_path = os.path.join(self.workspace, POS_WAV_FILE)
if not os.path.exists(wav_file_path):
r = requests.get(POS_WAV_URL)
with open(wav_file_path, 'wb') as f:
f.write(r.content)
# get wav file
wav_file_path = POS_WAV_FILE


# downloading kwsbp # downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
@@ -91,9 +87,72 @@ class KeyWordSpottingTest(unittest.TestCase):
""" """
if kws_result.__contains__('keywords'): if kws_result.__contains__('keywords'):
print('test_run_with_wav keywords: ', kws_result['keywords']) print('test_run_with_wav keywords: ', kws_result['keywords'])
print('test_run_with_wav confidence: ', kws_result['confidence'])
print('test_run_with_wav detected result: ', kws_result['detected']) print('test_run_with_wav detected result: ', kws_result['detected'])
print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) print('test_run_with_wav wave time(seconds): ', kws_result['wav_time'])


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
    """Run keyword spotting on one wav file using customized keywords."""
    # Supported sets: wav, neg_testsets, pos_testsets, roc
    kws_set = 'wav'

    # The wav file is read from the repository test data.
    wav_file_path = BOFANGYINYUE_WAV_FILE

    # Download the kwsbp tool into the workspace if it is not there yet.
    kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
    if not os.path.exists(kwsbp_file_path):
        r = requests.get(KWSBP_URL)
        # Fail here instead of saving an HTTP error page as the tool,
        # which the existence check above would then reuse.
        r.raise_for_status()
        with open(kwsbp_file_path, 'wb') as f:
            f.write(r.content)

    model = Model.from_pretrained(self.model_id)
    self.assertIsNotNone(model)

    cfg_preprocessor = dict(
        type=Preprocessors.wav_to_lists, workspace=self.workspace)
    preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
    self.assertIsNotNone(preprocessor)

    # Customized keywords; 'threshold' is optional for each entry.
    # Full settings example:
    # keywords = [
    #     {'keyword':'你好电视', 'threshold': 0.008},
    #     {'keyword':'播放音乐', 'threshold': 0.008}
    # ]
    keywords = [{'keyword': '播放音乐'}]

    kwsbp_16k_pipeline = pipeline(
        pipeline_name=Pipelines.kws_kwsbp,
        model=model,
        preprocessor=preprocessor,
        keywords=keywords)
    self.assertIsNotNone(kwsbp_16k_pipeline)

    kws_result = kwsbp_16k_pipeline(
        kws_type=kws_set, wav_path=[wav_file_path, None])
    self.assertIn('detected', kws_result)

    # kws result json format example:
    #     {
    #         'wav_count': 1,
    #         'kws_set': 'wav',
    #         'wav_time': 9.132938,
    #         'keywords': ['播放音乐'],
    #         'detected': True,
    #         'confidence': 0.660368
    #     }
    if 'keywords' in kws_result:
        print('test_run_with_wav_by_customized_keywords keywords: ',
              kws_result['keywords'])
        print('test_run_with_wav_by_customized_keywords confidence: ',
              kws_result['confidence'])
    print('test_run_with_wav_by_customized_keywords detected result: ',
          kws_result['detected'])
    print('test_run_with_wav_by_customized_keywords wave time(seconds): ',
          kws_result['wav_time'])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self): def test_run_with_pos_testsets(self):
# wav, neg_testsets, pos_testsets, roc # wav, neg_testsets, pos_testsets, roc


Loading…
Cancel
Save