Browse Source

[to #42322933] add some code check

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9122842

    * [Add] add KWS code

* [Update] check code linters and formatter

* [Update] update kws code

* Merge branch 'master' into dev/kws

* [Fix] fix kws warning

* [Add] add ROC for KWS

* [Update] add some code check

* feat: Fix confilct, auto commit by WebIDE

* feat: Fix confilct, auto commit by WebIDE

* Merge branch 'master' into dev/kws

* [Update] refactor kws code

* [Update] refactor kws code

* [Update] refactor kws code, bug fix

* [Update] refactor kws code, bug fix
master
shichen.fsc huangjun.hj 3 years ago
parent
commit
5386748bc4
10 changed files with 1074 additions and 0 deletions
  1. +3
    -0
      modelscope/metainfo.py
  2. +1
    -0
      modelscope/models/__init__.py
  3. +1
    -0
      modelscope/models/audio/kws/__init__.py
  4. +30
    -0
      modelscope/models/audio/kws/generic_key_word_spotting.py
  5. +1
    -0
      modelscope/pipelines/audio/__init__.py
  6. +449
    -0
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  7. +1
    -0
      modelscope/preprocessors/__init__.py
  8. +253
    -0
      modelscope/preprocessors/kws.py
  9. +1
    -0
      modelscope/utils/constant.py
  10. +334
    -0
      tests/pipelines/test_key_word_spotting.py

+ 3
- 0
modelscope/metainfo.py View File

@@ -21,6 +21,7 @@ class Models(object):
sambert_hifi_16k = 'sambert-hifi-16k'
generic_tts_frontend = 'generic-tts-frontend'
hifigan16k = 'hifigan16k'
kws_kwsbp = 'kws-kwsbp'

# multi-modal models
ofa = 'ofa'
@@ -53,6 +54,7 @@ class Pipelines(object):
# audio tasks
sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
kws_kwsbp = 'kws-kwsbp'

# multi-modal tasks
image_caption = 'image-caption'
@@ -94,6 +96,7 @@ class Preprocessors(object):
# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'
text_to_tacotron_symbols = 'text-to-tacotron-symbols'
wav_to_lists = 'wav-to-lists'

# multi-modal
ofa_image_caption = 'ofa-image-caption'

+ 1
- 0
modelscope/models/__init__.py View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .audio.kws import GenericKeyWordSpotting
from .audio.tts.am import SambertNetHifi16k
from .audio.tts.vocoder import Hifigan16k
from .base import Model


+ 1
- 0
modelscope/models/audio/kws/__init__.py View File

@@ -0,0 +1 @@
from .generic_key_word_spotting import * # noqa F403

+ 30
- 0
modelscope/models/audio/kws/generic_key_word_spotting.py View File

@@ -0,0 +1,30 @@
import os
from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['GenericKeyWordSpotting']


@MODELS.register_module(Tasks.key_word_spotting, module_name=Models.kws_kwsbp)
class GenericKeyWordSpotting(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the info of model.

Args:
model_dir (str): the model path.
"""

self.model_cfg = {
'model_workspace': model_dir,
'config_path': os.path.join(model_dir, 'config.yaml')
}

def forward(self) -> Dict[str, Any]:
"""return the info of the model
"""
return self.model_cfg

+ 1
- 0
modelscope/pipelines/audio/__init__.py View File

@@ -1,2 +1,3 @@
from .kws_kwsbp_pipeline import * # noqa F403
from .linear_aec_pipeline import LinearAECPipeline
from .text_to_speech_pipeline import * # noqa F403

+ 449
- 0
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -0,0 +1,449 @@
import io
import os
import shutil
import stat
import subprocess
from typing import Any, Dict, List

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToLists
from modelscope.utils.constant import Tasks

__all__ = ['KeyWordSpottingKwsbpPipeline']


@PIPELINES.register_module(
Tasks.key_word_spotting, module_name=Pipelines.kws_kwsbp)
class KeyWordSpottingKwsbpPipeline(Pipeline):
"""KWS Pipeline - key word spotting decoding
"""

def __init__(self,
config_file: str = None,
model: Model = None,
preprocessor: WavToLists = None,
**kwargs):
"""use `model` and `preprocessor` to create a kws pipeline for prediction
"""

super().__init__(
config_file=config_file,
model=model,
preprocessor=preprocessor,
**kwargs)
assert model is not None, 'kws model should be provided'
assert preprocessor is not None, 'preprocessor is none'

self._preprocessor = preprocessor
self._model = model

def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]:
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
'roc'], f'kws_type {kws_type} is invalid'
output = self._preprocessor.forward(self._model.forward(), kws_type,
wav_path)
output = self.forward(output)
rst = self.postprocess(output)
return rst

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Decoding
"""

# will generate kws result into dump/dump.JOB.log
out = self._run_with_kwsbp(inputs)

return out

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""process the kws results
"""

pos_result_json = {}
neg_result_json = {}

if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
self._parse_dump_log(pos_result_json, inputs['pos_dump_path'])
if inputs['kws_set'] in ['neg_testsets', 'roc']:
self._parse_dump_log(neg_result_json, inputs['neg_dump_path'])
"""
result_json format example:
{
"wav_count": 450,
"keywords": ["小云小云"],
"wav_time": 3560.999999,
"detected": [
{
"xxx.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
{
"yyy.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
......
],
"detected_count": 429,
"rejected_count": 21,
"rejected": [
"yyy.wav",
"zzz.wav",
......
]
}
"""

rst_dict = {'kws_set': inputs['kws_set']}

# parsing the result of wav
if inputs['kws_set'] == 'wav':
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[
'pos_wav_count']
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6)
if pos_result_json['detected_count'] == 1:
rst_dict['keywords'] = pos_result_json['keywords']
rst_dict['detected'] = True
wav_file_name = os.path.basename(inputs['pos_wav_path'])
rst_dict['confidence'] = float(pos_result_json['detected'][0]
[wav_file_name]['confidence'])
else:
rst_dict['detected'] = False

# parsing the result of pos_tests
elif inputs['kws_set'] == 'pos_testsets':
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[
'pos_wav_count']
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6)
if pos_result_json.__contains__('keywords'):
rst_dict['keywords'] = pos_result_json['keywords']

rst_dict['recall'] = round(
pos_result_json['detected_count'] / rst_dict['wav_count'], 6)

if pos_result_json.__contains__('detected_count'):
rst_dict['detected_count'] = pos_result_json['detected_count']
if pos_result_json.__contains__('rejected_count'):
rst_dict['rejected_count'] = pos_result_json['rejected_count']
if pos_result_json.__contains__('rejected'):
rst_dict['rejected'] = pos_result_json['rejected']

# parsing the result of neg_tests
elif inputs['kws_set'] == 'neg_testsets':
rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[
'neg_wav_count']
rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6)
if neg_result_json.__contains__('keywords'):
rst_dict['keywords'] = neg_result_json['keywords']

rst_dict['fa_rate'] = 0.0
rst_dict['fa_per_hour'] = 0.0

if neg_result_json.__contains__('detected_count'):
rst_dict['detected_count'] = neg_result_json['detected_count']
rst_dict['fa_rate'] = round(
neg_result_json['detected_count'] / rst_dict['wav_count'],
6)
if neg_result_json.__contains__('wav_time'):
rst_dict['fa_per_hour'] = round(
neg_result_json['detected_count']
/ float(neg_result_json['wav_time'] / 3600), 6)

if neg_result_json.__contains__('rejected_count'):
rst_dict['rejected_count'] = neg_result_json['rejected_count']

if neg_result_json.__contains__('detected'):
rst_dict['detected'] = neg_result_json['detected']

# parsing the result of roc
elif inputs['kws_set'] == 'roc':
threshold_start = 0.000
threshold_step = 0.001
threshold_end = 1.000

pos_keywords_list = []
neg_keywords_list = []
if pos_result_json.__contains__('keywords'):
pos_keywords_list = pos_result_json['keywords']
if neg_result_json.__contains__('keywords'):
neg_keywords_list = neg_result_json['keywords']

keywords_list = list(set(pos_keywords_list + neg_keywords_list))

pos_result_json['wav_count'] = inputs['pos_wav_count']
neg_result_json['wav_count'] = inputs['neg_wav_count']

if len(keywords_list) > 0:
rst_dict['keywords'] = keywords_list

for index in range(len(rst_dict['keywords'])):
cur_keyword = rst_dict['keywords'][index]
output_list = self._generate_roc_list(
start=threshold_start,
step=threshold_step,
end=threshold_end,
keyword=cur_keyword,
pos_inputs=pos_result_json,
neg_inputs=neg_result_json)

rst_dict[cur_keyword] = output_list

return rst_dict

def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

if inputs['kws_set'] == 'roc':
inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], 'keywords_roc.json')

if inputs['kws_set'] == 'wav':
dump_log_path: str = os.path.join(inputs['pos_dump_path'],
'dump.log')
kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
os.system(kws_cmd)

if inputs['kws_set'] in ['pos_testsets', 'roc']:
data_dir: str = os.listdir(inputs['pos_data_path'])
wav_list = []
for i in data_dir:
suffix = os.path.splitext(os.path.basename(i))[1]
if suffix == '.list':
wav_list.append(os.path.join(inputs['pos_data_path'], i))

j: int = 0
process = []
while j < inputs['pos_num_thread']:
wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str(
j) + '.list'
dump_log_path: str = inputs['pos_dump_path'] + '/dump.' + str(
j) + '.log'

kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1

k: int = 0
while k < len(process):
process[k].wait()
k += 1

if inputs['kws_set'] in ['neg_testsets', 'roc']:
data_dir: str = os.listdir(inputs['neg_data_path'])
wav_list = []
for i in data_dir:
suffix = os.path.splitext(os.path.basename(i))[1]
if suffix == '.list':
wav_list.append(os.path.join(inputs['neg_data_path'], i))

j: int = 0
process = []
while j < inputs['neg_num_thread']:
wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str(
j) + '.list'
dump_log_path: str = inputs['neg_dump_path'] + '/dump.' + str(
j) + '.log'

kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1

k: int = 0
while k < len(process):
process[k].wait()
k += 1

return inputs

def _parse_dump_log(self, result_json: Dict[str, Any],
dump_path: str) -> Dict[str, Any]:
dump_dir = os.listdir(dump_path)
for i in dump_dir:
basename = os.path.splitext(os.path.basename(i))[0]
# find dump.JOB.log
if 'dump' in basename:
with open(
os.path.join(dump_path, i), mode='r',
encoding='utf-8') as file:
while 1:
line = file.readline()
if not line:
break
else:
result_json = self._parse_result_log(
line, result_json)

def _parse_result_log(self, line: str,
result_json: Dict[str, Any]) -> Dict[str, Any]:
# valid info
if '[rejected]' in line or '[detected]' in line:
detected_count = 0
rejected_count = 0

if result_json.__contains__('detected_count'):
detected_count = result_json['detected_count']
if result_json.__contains__('rejected_count'):
rejected_count = result_json['rejected_count']

if '[detected]' in line:
# [detected], fname:/xxx/.tmp_pos_testsets/pos_testsets/33.wav,
# kw:小云小云, confidence:0.965155, time:[4.62-5.10], threshold:0.00,
detected_count += 1
content_list = line.split(', ')
file_name = os.path.basename(content_list[1].split(':')[1])
keyword = content_list[2].split(':')[1]
confidence = content_list[3].split(':')[1]

keywords_list = []
if result_json.__contains__('keywords'):
keywords_list = result_json['keywords']

if keyword not in keywords_list:
keywords_list.append(keyword)
result_json['keywords'] = keywords_list

keyword_item = {}
keyword_item['confidence'] = confidence
keyword_item['keyword'] = keyword
item = {}
item[file_name] = keyword_item

detected_list = []
if result_json.__contains__('detected'):
detected_list = result_json['detected']

detected_list.append(item)
result_json['detected'] = detected_list

elif '[rejected]' in line:
# [rejected], fname:/xxx/.tmp_pos_testsets/pos_testsets/28.wav
rejected_count += 1
content_list = line.split(', ')
file_name = os.path.basename(content_list[1].split(':')[1])
file_name = file_name.strip().replace('\n',
'').replace('\r', '')

rejected_list = []
if result_json.__contains__('rejected'):
rejected_list = result_json['rejected']

rejected_list.append(file_name)
result_json['rejected'] = rejected_list

result_json['detected_count'] = detected_count
result_json['rejected_count'] = rejected_count

elif 'total_proc_time=' in line and 'wav_time=' in line:
# eg: total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=0.013799
wav_total_time = 0
content_list = line.split('), ')
if result_json.__contains__('wav_time'):
wav_total_time = result_json['wav_time']

wav_time_str = content_list[1].split('=')[1]
wav_time_str = wav_time_str.split('(')[0]
wav_time = float(wav_time_str)
wav_time = round(wav_time, 6)

if isinstance(wav_time, float):
wav_total_time += wav_time

result_json['wav_time'] = wav_total_time

return result_json

def _generate_roc_list(self, start: float, step: float, end: float,
keyword: str, pos_inputs: Dict[str, Any],
neg_inputs: Dict[str, Any]) -> Dict[str, Any]:
pos_wav_count = pos_inputs['wav_count']
neg_wav_time = neg_inputs['wav_time']
det_lists = pos_inputs['detected']
fa_lists = neg_inputs['detected']
threshold_cur = start
"""
input det_lists dict
[
{
"xxx.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
{
"yyy.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
]

output dict
[
{
"threshold": 0.000,
"recall": 0.999888,
"fa_per_hour": 1.999999
},
{
"threshold": 0.001,
"recall": 0.999888,
"fa_per_hour": 1.999999
},
]
"""

output = []
while threshold_cur <= end:
det_count = 0
fa_count = 0
for index in range(len(det_lists)):
det_item = det_lists[index]
det_wav_item = det_item.get(next(iter(det_item)))
if det_wav_item['keyword'] == keyword:
confidence = float(det_wav_item['confidence'])
if confidence >= threshold_cur:
det_count += 1

for index in range(len(fa_lists)):
fa_item = fa_lists[index]
fa_wav_item = fa_item.get(next(iter(fa_item)))
if fa_wav_item['keyword'] == keyword:
confidence = float(fa_wav_item['confidence'])
if confidence >= threshold_cur:
fa_count += 1

output_item = {
'threshold': round(threshold_cur, 3),
'recall': round(float(det_count / pos_wav_count), 6),
'fa_per_hour': round(fa_count / float(neg_wav_time / 3600), 6)
}
output.append(output_item)

threshold_cur += step

return output

+ 1
- 0
modelscope/preprocessors/__init__.py View File

@@ -5,6 +5,7 @@ from .base import Preprocessor
from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose
from .image import LoadImage, load_image
from .kws import WavToLists
from .multi_modal import OfaImageCaptionPreprocessor
from .nlp import * # noqa F403
from .text_to_speech import * # noqa F403

+ 253
- 0
modelscope/preprocessors/kws.py View File

@@ -0,0 +1,253 @@
import os
import shutil
import stat
from pathlib import Path
from typing import Any, Dict, List

import yaml

from modelscope.metainfo import Preprocessors
from modelscope.models.base import Model
from modelscope.utils.constant import Fields
from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = ['WavToLists']


@PREPROCESSORS.register_module(
Fields.audio, module_name=Preprocessors.wav_to_lists)
class WavToLists(Preprocessor):
"""generate audio lists file from wav

Args:
workspace (str): store temporarily kws intermedium and result
"""

def __init__(self, workspace: str = None):
# the workspace path
if len(workspace) == 0:
self._workspace = os.path.join(os.getcwd(), '.tmp')
else:
self._workspace = workspace

if not os.path.exists(self._workspace):
os.mkdir(self._workspace)

def __call__(self,
model: Model = None,
kws_type: str = None,
wav_path: List[str] = None) -> Dict[str, Any]:
"""Call functions to load model and wav.

Args:
model (Model): model should be provided
kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc
wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path
Returns:
Dict[str, Any]: the kws result
"""

assert model is not None, 'preprocess kws model should be provided'
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'
], f'preprocess kws_type {kws_type} is invalid'
assert wav_path[0] is not None or wav_path[
1] is not None, 'preprocess wav_path is invalid'

self._model = model
out = self.forward(self._model.forward(), kws_type, wav_path)
return out

def forward(self, model: Dict[str, Any], kws_type: str,
wav_path: List[str]) -> Dict[str, Any]:
assert len(kws_type) > 0, 'preprocess kws_type is empty'
assert len(
model['config_path']) > 0, 'preprocess model[config_path] is empty'
assert os.path.exists(
model['config_path']), 'model config.yaml is absent'

inputs = model.copy()

inputs['kws_set'] = kws_type
inputs['workspace'] = self._workspace
if wav_path[0] is not None:
inputs['pos_wav_path'] = wav_path[0]
if wav_path[1] is not None:
inputs['neg_wav_path'] = wav_path[1]

out = self._read_config(inputs)
out = self._generate_wav_lists(out)

return out

def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""read and parse config.yaml to get all model files
"""

assert os.path.exists(
inputs['config_path']), 'model config yaml file does not exist'

config_file = open(inputs['config_path'])
root = yaml.full_load(config_file)
config_file.close()

inputs['cfg_file'] = root['cfg_file']
inputs['cfg_file_path'] = os.path.join(inputs['model_workspace'],
root['cfg_file'])
inputs['keyword_grammar'] = root['keyword_grammar']
inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], root['keyword_grammar'])
inputs['sample_rate'] = str(root['sample_rate'])
inputs['kws_tool'] = root['kws_tool']

if os.path.exists(
os.path.join(inputs['workspace'], inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join(inputs['workspace'],
inputs['kws_tool'])
elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join('/usr/bin',
inputs['kws_tool'])
elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool'])

assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp'
os.chmod(inputs['kws_tool_path'],
stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH)

self._config_checking(inputs)
return inputs

def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""assemble wav lists
"""

if inputs['kws_set'] == 'wav':
inputs['pos_num_thread'] = 1
wave_scp_content: str = inputs['pos_wav_path'] + '\n'

with open(os.path.join(inputs['pos_data_path'], 'wave.list'),
'a') as f:
f.write(wave_scp_content)

inputs['pos_wav_count'] = 1

if inputs['kws_set'] in ['pos_testsets', 'roc']:
# find all positive wave
wav_list = []
wav_dir = inputs['pos_wav_path']
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)

list_count: int = len(wav_list)
inputs['pos_wav_count'] = list_count

if list_count <= 128:
inputs['pos_num_thread'] = list_count
j: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
j) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1

else:
inputs['pos_num_thread'] = 128
j: int = 0
k: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
k) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1
k += 1
if k >= 128:
k = 0

if inputs['kws_set'] in ['neg_testsets', 'roc']:
# find all negative wave
wav_list = []
wav_dir = inputs['neg_wav_path']
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)

list_count: int = len(wav_list)
inputs['neg_wav_count'] = list_count

if list_count <= 128:
inputs['neg_num_thread'] = list_count
j: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
j) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1

else:
inputs['neg_num_thread'] = 128
j: int = 0
k: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
k) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1
k += 1
if k >= 128:
k = 0

return inputs

def _recursion_dir_all_wave(self, wav_list,
dir_path: str) -> Dict[str, Any]:
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
wav_list.append(file_path)
elif os.path.isdir(file_path):
self._recursion_dir_all_wave(wav_list, file_path)

return wav_list

def _config_checking(self, inputs: Dict[str, Any]):

if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
inputs['pos_data_path'] = os.path.join(inputs['workspace'],
'pos_data')
if not os.path.exists(inputs['pos_data_path']):
os.mkdir(inputs['pos_data_path'])
else:
shutil.rmtree(inputs['pos_data_path'])
os.mkdir(inputs['pos_data_path'])

inputs['pos_dump_path'] = os.path.join(inputs['workspace'],
'pos_dump')
if not os.path.exists(inputs['pos_dump_path']):
os.mkdir(inputs['pos_dump_path'])
else:
shutil.rmtree(inputs['pos_dump_path'])
os.mkdir(inputs['pos_dump_path'])

if inputs['kws_set'] in ['neg_testsets', 'roc']:
inputs['neg_data_path'] = os.path.join(inputs['workspace'],
'neg_data')
if not os.path.exists(inputs['neg_data_path']):
os.mkdir(inputs['neg_data_path'])
else:
shutil.rmtree(inputs['neg_data_path'])
os.mkdir(inputs['neg_data_path'])

inputs['neg_dump_path'] = os.path.join(inputs['workspace'],
'neg_dump')
if not os.path.exists(inputs['neg_dump_path']):
os.mkdir(inputs['neg_dump_path'])
else:
shutil.rmtree(inputs['neg_dump_path'])
os.mkdir(inputs['neg_dump_path'])

+ 1
- 0
modelscope/utils/constant.py View File

@@ -52,6 +52,7 @@ class Tasks(object):
auto_speech_recognition = 'auto-speech-recognition'
text_to_speech = 'text-to-speech'
speech_signal_process = 'speech-signal-process'
key_word_spotting = 'key-word-spotting'

# multi-modal tasks
image_captioning = 'image-captioning'


+ 334
- 0
tests/pipelines/test_key_word_spotting.py View File

@@ -0,0 +1,334 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tarfile
import unittest

import requests

from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.constant import Fields, InputFields, Tasks
from modelscope.utils.test_utils import test_level

KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp'

POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav'
POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE

POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz'

NEG_TESTSETS_FILE = 'neg_testsets.tar.gz'
NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz'


def un_tar_gz(fname, dirs):
t = tarfile.open(fname)
t.extractall(path=dirs)


class KeyWordSpottingTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun'
self.workspace = os.path.join(os.getcwd(), '.tmp')
if not os.path.exists(self.workspace):
os.mkdir(self.workspace)

def tearDown(self) -> None:
if os.path.exists(self.workspace):
shutil.rmtree(self.workspace)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'

# downloading wav file
wav_file_path = os.path.join(self.workspace, POS_WAV_FILE)
if not os.path.exists(wav_file_path):
r = requests.get(POS_WAV_URL)
with open(wav_file_path, 'wb') as f:
f.write(r.content)

# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
self.assertTrue(kws_result.__contains__('detected'))
"""
kws result json format example:
{
'wav_count': 1,
'kws_set': 'wav',
'wav_time': 9.132938,
'keywords': ['小云小云'],
'detected': True,
'confidence': 0.990368
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_wav keywords: ', kws_result['keywords'])
print('test_run_with_wav detected result: ', kws_result['detected'])
print('test_run_with_wav wave time(seconds): ', kws_result['wav_time'])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'pos_testsets'

# downloading pos_testsets file
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(POS_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(POS_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# wav_file_path = <cwd>/.tmp_pos_testsets/pos_testsets/
wav_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the pos_testsets file
if not os.path.exists(wav_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[wav_file_path, None])
self.assertTrue(kws_result.__contains__('recall'))
"""
kws result json format example:
{
'wav_count': 450,
'kws_set': 'pos_testsets',
'wav_time': 3013.759254,
'keywords': ["小云小云"],
'recall': 0.953333,
'detected_count': 429,
'rejected_count': 21,
'rejected': [
'yyy.wav',
'zzz.wav',
......
]
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_pos_testsets keywords: ',
kws_result['keywords'])
print('test_run_with_pos_testsets recall: ', kws_result['recall'])
print('test_run_with_pos_testsets wave time(seconds): ',
kws_result['wav_time'])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_neg_testsets(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'neg_testsets'

# downloading neg_testsets file
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(NEG_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(NEG_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# wav_file_path = <cwd>/.tmp_neg_testsets/neg_testsets/
wav_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the neg_testsets file
if not os.path.exists(wav_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[None, wav_file_path])
self.assertTrue(kws_result.__contains__('fa_rate'))
"""
kws result json format example:
{
'wav_count': 751,
'kws_set': 'neg_testsets',
'wav_time': 3572.180812,
'keywords': ['小云小云'],
'fa_rate': 0.001332,
'fa_per_hour': 1.007788,
'detected_count': 1,
'rejected_count': 750,
'detected': [
{
'6.wav': {
'confidence': '0.321170'
}
}
]
}
"""
if kws_result.__contains__('keywords'):
print('test_run_with_neg_testsets keywords: ',
kws_result['keywords'])
print('test_run_with_neg_testsets fa rate: ', kws_result['fa_rate'])
print('test_run_with_neg_testsets fa per hour: ',
kws_result['fa_per_hour'])
print('test_run_with_neg_testsets wave time(seconds): ',
kws_result['wav_time'])

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_roc(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'roc'

# downloading neg_testsets file
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(NEG_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(NEG_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# neg_file_path = <workspace>/.tmp_roc/neg_testsets/
neg_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the neg_testsets file
if not os.path.exists(neg_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading pos_testsets file
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(POS_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(POS_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# pos_file_path = <workspace>/.tmp_roc/pos_testsets/
pos_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the pos_testsets file
if not os.path.exists(pos_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
self.assertTrue(kwsbp_16k_pipline is not None)

kws_result = kwsbp_16k_pipline(
kws_type=kws_set, wav_path=[pos_file_path, neg_file_path])
"""
kws result json format example:
{
'kws_set': 'roc',
'keywords': ['小云小云'],
'小云小云': [
{'threshold': 0.0, 'recall': 0.953333, 'fa_per_hour': 1.007788},
{'threshold': 0.001, 'recall': 0.953333, 'fa_per_hour': 1.007788},
......
{'threshold': 0.999, 'recall': 0.004444, 'fa_per_hour': 0.0}
]
}
"""
if kws_result.__contains__('keywords'):
find_keyword = kws_result['keywords'][0]
print('test_run_with_roc keywords: ', find_keyword)
keyword_list = kws_result[find_keyword]
for item in iter(keyword_list):
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
print(' threshold:', threshold, ' recall:', recall,
' fa_per_hour:', fa_per_hour)


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save