Browse Source

[to #42322933] add customized keywords setting for KWS

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9200613

    * [Add] add KWS code

    * [Fix] fix kws warning

    * [Add] add ROC for KWS

    * [Update] add some code check

    * [Update] refactor kws code, bug fix

    * [Add] add customized keywords setting for KWS

    * [Add] add data/test/audios for KWS
master
shichen.fsc 3 years ago
parent
commit
40cb1043b8
6 changed files with 121 additions and 15 deletions
  1. +1
    -0
      .gitattributes
  2. +0
    -4
      .gitignore
  3. +3
    -0
      data/test/audios/kws_bofangyinyue.wav
  4. +3
    -0
      data/test/audios/kws_xiaoyunxiaoyun.wav
  5. +47
    -3
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  6. +67
    -8
      tests/pipelines/test_key_word_spotting.py

+ 1
- 0
.gitattributes View File

@@ -1,3 +1,4 @@
*.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text

+ 0
- 4
.gitignore View File

@@ -124,7 +124,3 @@ replace.sh

# Pytorch
*.pth


# audio
*.wav

+ 3
- 0
data/test/audios/kws_bofangyinyue.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a72a7b8d1e8be6ebaa09aeee0d71472569bc62cc4872ecfdbd1651bb3d03eaba
size 69110

+ 3
- 0
data/test/audios/kws_xiaoyunxiaoyun.wav View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6b1671bcfa872278c99490cd1acb08297b8df4dc78f268e4b6a582b4364e4a1
size 297684

+ 47
- 3
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -5,6 +5,8 @@ import stat
import subprocess
from typing import Any, Dict, List

import json

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.pipelines.base import Pipeline
@@ -39,6 +41,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):

self._preprocessor = preprocessor
self._model = model
self._keywords = None

if 'keywords' in kwargs.keys():
self._keywords = kwargs['keywords']
print('self._keywords len: ', len(self._keywords))
print('self._keywords: ', self._keywords)

def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]:
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
@@ -197,6 +205,16 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
return rst_dict

def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
opts: str = ''

# setting customized keywords
keywords_json = self._set_customized_keywords()
if len(keywords_json) > 0:
keywords_json_file = os.path.join(inputs['workspace'],
'keyword_custom.json')
with open(keywords_json_file, 'w') as f:
json.dump(keywords_json, f)
opts = '--keyword-custom ' + keywords_json_file

if inputs['kws_set'] == 'roc':
inputs['keyword_grammar_path'] = os.path.join(
@@ -211,7 +229,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
os.system(kws_cmd)

if inputs['kws_set'] in ['pos_testsets', 'roc']:
@@ -236,7 +254,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1
@@ -268,7 +286,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1
@@ -447,3 +465,29 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
threshold_cur += step

return output

def _set_customized_keywords(self) -> Dict[str, Any]:
if self._keywords is not None:
word_list_inputs = self._keywords
word_list = []
for i in range(len(word_list_inputs)):
key = word_list_inputs[i]
new_item = {}
if key.__contains__('keyword'):
name = key['keyword']
new_name: str = ''
for n in range(0, len(name), 1):
new_name += name[n]
new_name += ' '
new_name = new_name.strip()
new_item['name'] = new_name

if key.__contains__('threshold'):
threshold1: float = key['threshold']
new_item['threshold1'] = threshold1

word_list.append(new_item)
out = {'word_list': word_list}
return out
else:
return ''

+ 67
- 8
tests/pipelines/test_key_word_spotting.py View File

@@ -15,8 +15,8 @@ from modelscope.utils.test_utils import test_level

KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp'

POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav'
POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'

POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz'
@@ -47,12 +47,8 @@ class KeyWordSpottingTest(unittest.TestCase):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'

# downloading wav file
wav_file_path = os.path.join(self.workspace, POS_WAV_FILE)
if not os.path.exists(wav_file_path):
r = requests.get(POS_WAV_URL)
with open(wav_file_path, 'wb') as f:
f.write(r.content)
# get wav file
wav_file_path = POS_WAV_FILE

# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
@@ -91,9 +87,72 @@ class KeyWordSpottingTest(unittest.TestCase):
"""
if kws_result.__contains__('keywords'):
print('test_run_with_wav keywords: ', kws_result['keywords'])
print('test_run_with_wav confidence: ', kws_result['confidence'])
print('test_run_with_wav detected result: ', kws_result['detected'])
print('test_run_with_wav wave time(seconds): ', kws_result['wav_time'])

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'

# get wav file
wav_file_path = BOFANGYINYUE_WAV_FILE

# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

# customized keyword if you need.
# full settings eg.
# keywords = [
# {'keyword':'你好电视', 'threshold': 0.008},
# {'keyword':'播放音乐', 'threshold': 0.008}
# ]
keywords = [{'keyword': '播放音乐'}]

kwsbp_16k_pipline = pipeline(
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor,
keywords=keywords)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
self.assertTrue(kws_result.__contains__('detected'))
"""
kws result json format example:
{
'wav_count': 1,
'kws_set': 'wav',
'wav_time': 9.132938,
'keywords': ['播放音乐'],
'detected': True,
'confidence': 0.660368
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_wav_by_customized_keywords keywords: ',
kws_result['keywords'])
print('test_run_with_wav_by_customized_keywords confidence: ',
kws_result['confidence'])
print('test_run_with_wav_by_customized_keywords detected result: ',
kws_result['detected'])
print('test_run_with_wav_by_customized_keywords wave time(seconds): ',
kws_result['wav_time'])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self):
# wav, neg_testsets, pos_testsets, roc


Loading…
Cancel
Save