Browse Source

[to #42322933] add pcm-bytes supported for KWS

kws增加pcm bytes数据类型的支持
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9635439
master
shichen.fsc yingda.chen 3 years ago
parent
commit
c663dd8cf6
5 changed files with 194 additions and 125 deletions
  1. +68
    -31
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  2. +30
    -23
      modelscope/preprocessors/kws.py
  3. +1
    -1
      requirements/audio.txt
  4. +0
    -5
      tests/pipelines/test_automatic_speech_recognition.py
  5. +95
    -65
      tests/pipelines/test_key_word_spotting.py

+ 68
- 31
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -30,7 +30,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
""" """
super().__init__(model=model, preprocessor=preprocessor, **kwargs) super().__init__(model=model, preprocessor=preprocessor, **kwargs)


def __call__(self, wav_path: Union[List[str], str],
def __call__(self, audio_in: Union[List[str], str, bytes],
**kwargs) -> Dict[str, Any]: **kwargs) -> Dict[str, Any]:
if 'keywords' in kwargs.keys(): if 'keywords' in kwargs.keys():
self.keywords = kwargs['keywords'] self.keywords = kwargs['keywords']
@@ -40,7 +40,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
if self.preprocessor is None: if self.preprocessor is None:
self.preprocessor = WavToLists() self.preprocessor = WavToLists()


output = self.preprocessor.forward(self.model.forward(), wav_path)
output = self.preprocessor.forward(self.model.forward(), audio_in)
output = self.forward(output) output = self.forward(output)
rst = self.postprocess(output) rst = self.postprocess(output)
return rst return rst
@@ -49,7 +49,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
"""Decoding """Decoding
""" """


logger.info(f"Decoding with {inputs['kws_set']} mode ...")
logger.info(f"Decoding with {inputs['kws_type']} mode ...")


# will generate kws result # will generate kws result
out = self.run_with_kwsbp(inputs) out = self.run_with_kwsbp(inputs)
@@ -80,60 +80,97 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
pos_kws_list = inputs['pos_kws_list'] pos_kws_list = inputs['pos_kws_list']
if 'neg_kws_list' in inputs: if 'neg_kws_list' in inputs:
neg_kws_list = inputs['neg_kws_list'] neg_kws_list = inputs['neg_kws_list']

rst_dict = kws_util.common.parsing_kws_result( rst_dict = kws_util.common.parsing_kws_result(
kws_type=inputs['kws_set'],
kws_type=inputs['kws_type'],
pos_list=pos_kws_list, pos_list=pos_kws_list,
neg_list=neg_kws_list) neg_list=neg_kws_list)


return rst_dict return rst_dict


def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
cmd = {
'sys_dir': inputs['model_workspace'],
'cfg_file': inputs['cfg_file_path'],
'sample_rate': inputs['sample_rate'],
'keyword_custom': ''
}

import kwsbp import kwsbp
import kws_util.common import kws_util.common
kws_inference = kwsbp.KwsbpEngine() kws_inference = kwsbp.KwsbpEngine()


# setting customized keywords
cmd['customized_keywords'] = kws_util.common.generate_customized_keywords(
self.keywords)
cmd = {
'sys_dir':
inputs['model_workspace'],
'cfg_file':
inputs['cfg_file_path'],
'sample_rate':
inputs['sample_rate'],
'keyword_custom':
'',
'pcm_data':
None,
'pcm_data_len':
0,
'list_flag':
True,
# setting customized keywords
'customized_keywords':
kws_util.common.generate_customized_keywords(self.keywords)
}

if inputs['kws_type'] == 'pcm':
cmd['pcm_data'] = inputs['pos_data']
cmd['pcm_data_len'] = len(inputs['pos_data'])
cmd['list_flag'] = False


if inputs['kws_set'] == 'roc':
if inputs['kws_type'] == 'roc':
inputs['keyword_grammar_path'] = os.path.join( inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], 'keywords_roc.json') inputs['model_workspace'], 'keywords_roc.json')


if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
if inputs['kws_type'] in ['wav', 'pcm', 'pos_testsets', 'roc']:
cmd['wave_scp'] = inputs['pos_wav_list'] cmd['wave_scp'] = inputs['pos_wav_list']
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] cmd['keyword_grammar_path'] = inputs['keyword_grammar_path']
cmd['num_thread'] = inputs['pos_num_thread'] cmd['num_thread'] = inputs['pos_num_thread']


# run and get inference result
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']),
cmd['sample_rate'],
cmd['num_thread'])
if hasattr(kws_inference, 'inference_new'):
# run and get inference result
result = kws_inference.inference_new(
cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']), cmd['pcm_data'],
cmd['pcm_data_len'], cmd['sample_rate'], cmd['num_thread'],
cmd['list_flag'])
else:
# in order to support kwsbp-0.0.1
result = kws_inference.inference(
cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']), cmd['sample_rate'],
cmd['num_thread'])

pos_result = json.loads(result) pos_result = json.loads(result)
inputs['pos_kws_list'] = pos_result['kws_list'] inputs['pos_kws_list'] = pos_result['kws_list']


if inputs['kws_set'] in ['neg_testsets', 'roc']:
if inputs['kws_type'] in ['neg_testsets', 'roc']:
cmd['wave_scp'] = inputs['neg_wav_list'] cmd['wave_scp'] = inputs['neg_wav_list']
cmd['keyword_grammar_path'] = inputs['keyword_grammar_path'] cmd['keyword_grammar_path'] = inputs['keyword_grammar_path']
cmd['num_thread'] = inputs['neg_num_thread'] cmd['num_thread'] = inputs['neg_num_thread']


# run and get inference result
result = kws_inference.inference(cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']),
cmd['sample_rate'],
cmd['num_thread'])
if hasattr(kws_inference, 'inference_new'):
# run and get inference result
result = kws_inference.inference_new(
cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']), cmd['pcm_data'],
cmd['pcm_data_len'], cmd['sample_rate'], cmd['num_thread'],
cmd['list_flag'])
else:
# in order to support kwsbp-0.0.1
result = kws_inference.inference(
cmd['sys_dir'], cmd['cfg_file'],
cmd['keyword_grammar_path'],
str(json.dumps(cmd['wave_scp'])),
str(cmd['customized_keywords']), cmd['sample_rate'],
cmd['num_thread'])

neg_result = json.loads(result) neg_result = json.loads(result)
inputs['neg_kws_list'] = neg_result['kws_list'] inputs['neg_kws_list'] = neg_result['kws_list']




+ 30
- 23
modelscope/preprocessors/kws.py View File

@@ -21,23 +21,26 @@ class WavToLists(Preprocessor):
def __init__(self): def __init__(self):
pass pass


def __call__(self, model: Model, wav_path: Union[List[str],
str]) -> Dict[str, Any]:
def __call__(self, model: Model, audio_in: Union[List[str], str,
bytes]) -> Dict[str, Any]:
"""Call functions to load model and wav. """Call functions to load model and wav.


Args: Args:
model (Model): model should be provided model (Model): model should be provided
wav_path (Union[List[str], str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path
audio_in (Union[List[str], str, bytes]):
audio_in[0] is positive wav path, audio_in[1] is negative wav path;
audio_in (str) is positive wav path;
audio_in (bytes) is audio pcm data;
Returns: Returns:
Dict[str, Any]: the kws result Dict[str, Any]: the kws result
""" """


self.model = model self.model = model
out = self.forward(self.model.forward(), wav_path)
out = self.forward(self.model.forward(), audio_in)
return out return out


def forward(self, model: Dict[str, Any], def forward(self, model: Dict[str, Any],
wav_path: Union[List[str], str]) -> Dict[str, Any]:
audio_in: Union[List[str], str, bytes]) -> Dict[str, Any]:
assert len( assert len(
model['config_path']) > 0, 'preprocess model[config_path] is empty' model['config_path']) > 0, 'preprocess model[config_path] is empty'
assert os.path.exists( assert os.path.exists(
@@ -45,22 +48,21 @@ class WavToLists(Preprocessor):


inputs = model.copy() inputs = model.copy()


wav_list = [None, None]
if isinstance(wav_path, str):
wav_list[0] = wav_path
else:
wav_list = wav_path

import kws_util.common import kws_util.common
kws_type = kws_util.common.type_checking(wav_list)
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'
], f'preprocess kws_type {kws_type} is invalid'

inputs['kws_set'] = kws_type
if wav_list[0] is not None:
inputs['pos_wav_path'] = wav_list[0]
if wav_list[1] is not None:
inputs['neg_wav_path'] = wav_list[1]
kws_type = kws_util.common.type_checking(audio_in)
assert kws_type in [
'wav', 'pcm', 'pos_testsets', 'neg_testsets', 'roc'
], f'kws_type {kws_type} is invalid, please check audio data'

inputs['kws_type'] = kws_type
if kws_type == 'wav':
inputs['pos_wav_path'] = audio_in
elif kws_type == 'pcm':
inputs['pos_data'] = audio_in
if kws_type in ['pos_testsets', 'roc']:
inputs['pos_wav_path'] = audio_in[0]
if kws_type in ['neg_testsets', 'roc']:
inputs['neg_wav_path'] = audio_in[1]


out = self.read_config(inputs) out = self.read_config(inputs)
out = self.generate_wav_lists(out) out = self.generate_wav_lists(out)
@@ -93,7 +95,7 @@ class WavToLists(Preprocessor):
""" """
import kws_util.common import kws_util.common


if inputs['kws_set'] == 'wav':
if inputs['kws_type'] == 'wav':
wav_list = [] wav_list = []
wave_scp_content: str = inputs['pos_wav_path'] wave_scp_content: str = inputs['pos_wav_path']
wav_list.append(wave_scp_content) wav_list.append(wave_scp_content)
@@ -101,7 +103,12 @@ class WavToLists(Preprocessor):
inputs['pos_wav_count'] = 1 inputs['pos_wav_count'] = 1
inputs['pos_num_thread'] = 1 inputs['pos_num_thread'] = 1


if inputs['kws_set'] in ['pos_testsets', 'roc']:
if inputs['kws_type'] == 'pcm':
inputs['pos_wav_list'] = ['pcm_data']
inputs['pos_wav_count'] = 1
inputs['pos_num_thread'] = 1

if inputs['kws_type'] in ['pos_testsets', 'roc']:
# find all positive wave # find all positive wave
wav_list = [] wav_list = []
wav_dir = inputs['pos_wav_path'] wav_dir = inputs['pos_wav_path']
@@ -116,7 +123,7 @@ class WavToLists(Preprocessor):
else: else:
inputs['pos_num_thread'] = 128 inputs['pos_num_thread'] = 128


if inputs['kws_set'] in ['neg_testsets', 'roc']:
if inputs['kws_type'] in ['neg_testsets', 'roc']:
# find all negative wave # find all negative wave
wav_list = [] wav_list = []
wav_dir = inputs['neg_wav_path'] wav_dir = inputs['neg_wav_path']


+ 1
- 1
requirements/audio.txt View File

@@ -4,7 +4,7 @@ espnet>=202204
h5py h5py
inflect inflect
keras keras
kwsbp
kwsbp>=0.0.2
librosa librosa
lxml lxml
matplotlib matplotlib


+ 0
- 5
tests/pipelines/test_automatic_speech_recognition.py View File

@@ -30,11 +30,6 @@ TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz' TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'




def un_tar_gz(fname, dirs):
t = tarfile.open(fname)
t.extractall(path=dirs)


class AutomaticSpeechRecognitionTest(unittest.TestCase): class AutomaticSpeechRecognitionTest(unittest.TestCase):
action_info = { action_info = {
'test_run_with_wav_pytorch': { 'test_run_with_wav_pytorch': {


+ 95
- 65
tests/pipelines/test_key_word_spotting.py View File

@@ -5,8 +5,11 @@ import tarfile
import unittest import unittest
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union


import numpy as np
import requests import requests
import soundfile


from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from modelscope.utils.constant import ColorCodes, Tasks from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
@@ -27,12 +30,12 @@ NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/n
class KeyWordSpottingTest(unittest.TestCase): class KeyWordSpottingTest(unittest.TestCase):
action_info = { action_info = {
'test_run_with_wav': { 'test_run_with_wav': {
'checking_item': 'kws_list',
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
'checking_value': '小云小云', 'checking_value': '小云小云',
'example': { 'example': {
'wav_count': 'wav_count':
1, 1,
'kws_set':
'kws_type':
'wav', 'wav',
'kws_list': [{ 'kws_list': [{
'keyword': '小云小云', 'keyword': '小云小云',
@@ -42,13 +45,29 @@ class KeyWordSpottingTest(unittest.TestCase):
}] }]
} }
}, },
'test_run_with_pcm': {
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
'checking_value': '小云小云',
'example': {
'wav_count':
1,
'kws_type':
'pcm',
'kws_list': [{
'keyword': '小云小云',
'offset': 5.76,
'length': 9.132938,
'confidence': 0.990368
}]
}
},
'test_run_with_wav_by_customized_keywords': { 'test_run_with_wav_by_customized_keywords': {
'checking_item': 'kws_list',
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
'checking_value': '播放音乐', 'checking_value': '播放音乐',
'example': { 'example': {
'wav_count': 'wav_count':
1, 1,
'kws_set':
'kws_type':
'wav', 'wav',
'kws_list': [{ 'kws_list': [{
'keyword': '播放音乐', 'keyword': '播放音乐',
@@ -59,10 +78,10 @@ class KeyWordSpottingTest(unittest.TestCase):
} }
}, },
'test_run_with_pos_testsets': { 'test_run_with_pos_testsets': {
'checking_item': 'recall',
'checking_item': ['recall'],
'example': { 'example': {
'wav_count': 450, 'wav_count': 450,
'kws_set': 'pos_testsets',
'kws_type': 'pos_testsets',
'wav_time': 3013.75925, 'wav_time': 3013.75925,
'keywords': ['小云小云'], 'keywords': ['小云小云'],
'recall': 0.953333, 'recall': 0.953333,
@@ -72,11 +91,11 @@ class KeyWordSpottingTest(unittest.TestCase):
} }
}, },
'test_run_with_neg_testsets': { 'test_run_with_neg_testsets': {
'checking_item': 'fa_rate',
'checking_item': ['fa_rate'],
'example': { 'example': {
'wav_count': 'wav_count':
751, 751,
'kws_set':
'kws_type':
'neg_testsets', 'neg_testsets',
'wav_time': 'wav_time':
3572.180813, 3572.180813,
@@ -98,10 +117,10 @@ class KeyWordSpottingTest(unittest.TestCase):
} }
}, },
'test_run_with_roc': { 'test_run_with_roc': {
'checking_item': 'keywords',
'checking_item': ['keywords', 0],
'checking_value': '小云小云', 'checking_value': '小云小云',
'example': { 'example': {
'kws_set':
'kws_type':
'roc', 'roc',
'keywords': ['小云小云'], 'keywords': ['小云小云'],
'小云小云': [{ '小云小云': [{
@@ -129,21 +148,20 @@ class KeyWordSpottingTest(unittest.TestCase):


def tearDown(self) -> None: def tearDown(self) -> None:
# remove workspace dir (.tmp) # remove workspace dir (.tmp)
if os.path.exists(self.workspace):
shutil.rmtree(self.workspace, ignore_errors=True)
shutil.rmtree(self.workspace, ignore_errors=True)


def run_pipeline(self, def run_pipeline(self,
model_id: str, model_id: str,
wav_path: Union[List[str], str],
audio_in: Union[List[str], str, bytes],
keywords: List[str] = None) -> Dict[str, Any]: keywords: List[str] = None) -> Dict[str, Any]:
kwsbp_16k_pipline = pipeline( kwsbp_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=model_id) task=Tasks.auto_speech_recognition, model=model_id)


kws_result = kwsbp_16k_pipline(wav_path=wav_path, keywords=keywords)
kws_result = kwsbp_16k_pipline(audio_in=audio_in, keywords=keywords)


return kws_result return kws_result


def print_error(self, functions: str, result: Dict[str, Any]) -> None:
def log_error(self, functions: str, result: Dict[str, Any]) -> None:
logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
+ ColorCodes.END) + ColorCodes.END)
logger.error(ColorCodes.MAGENTA + functions logger.error(ColorCodes.MAGENTA + functions
@@ -153,49 +171,61 @@ class KeyWordSpottingTest(unittest.TestCase):


raise ValueError('kws result is mismatched') raise ValueError('kws result is mismatched')


def check_and_print_result(self, functions: str,
result: Dict[str, Any]) -> None:
if result.__contains__(self.action_info[functions]['checking_item']):
checking_item = result[self.action_info[functions]
['checking_item']]
if functions == 'test_run_with_roc':
if checking_item[0] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

elif functions == 'test_run_with_wav':
if checking_item[0]['keyword'] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

elif functions == 'test_run_with_wav_by_customized_keywords':
if checking_item[0]['keyword'] != self.action_info[functions][
'checking_value']:
self.print_error(functions, result)

logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
+ ColorCodes.END)
if functions == 'test_run_with_roc':
find_keyword = result['keywords'][0]
keyword_list = result[find_keyword]
for item in iter(keyword_list):
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
logger.info(ColorCodes.YELLOW + ' threshold:'
+ str(threshold) + ' recall:' + str(recall)
+ ' fa_per_hour:' + str(fa_per_hour)
+ ColorCodes.END)
else:
logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)
def check_result(self, functions: str, result: Dict[str, Any]) -> None:
result_item = result
check_list = self.action_info[functions]['checking_item']
for check_item in check_list:
result_item = result_item[check_item]
if result_item is None or result_item == 'None':
self.log_error(functions, result)

if self.action_info[functions].__contains__('checking_value'):
check_value = self.action_info[functions]['checking_value']
if result_item != check_value:
self.log_error(functions, result)

logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
+ ColorCodes.END)
if functions == 'test_run_with_roc':
find_keyword = result['keywords'][0]
keyword_list = result[find_keyword]
for item in iter(keyword_list):
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
logger.info(ColorCodes.YELLOW + ' threshold:' + str(threshold)
+ ' recall:' + str(recall) + ' fa_per_hour:'
+ str(fa_per_hour) + ColorCodes.END)
else: else:
self.print_error(functions, result)
logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)

def wav2bytes(self, wav_file) -> bytes:
audio, fs = soundfile.read(wav_file)

# float32 -> int16
audio = np.asarray(audio)
dtype = np.dtype('int16')
i = np.iinfo(dtype)
abs_max = 2**(i.bits - 1)
offset = i.min + abs_max
audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)

# int16(PCM_16) -> byte
audio = audio.tobytes()
return audio


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self): def test_run_with_wav(self):
kws_result = self.run_pipeline( kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=POS_WAV_FILE)
self.check_and_print_result('test_run_with_wav', kws_result)
model_id=self.model_id, audio_in=POS_WAV_FILE)
self.check_result('test_run_with_wav', kws_result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_pcm(self):
audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE))

kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio)
self.check_result('test_run_with_pcm', kws_result)


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav_by_customized_keywords(self): def test_run_with_wav_by_customized_keywords(self):
@@ -203,32 +233,32 @@ class KeyWordSpottingTest(unittest.TestCase):


kws_result = self.run_pipeline( kws_result = self.run_pipeline(
model_id=self.model_id, model_id=self.model_id,
wav_path=BOFANGYINYUE_WAV_FILE,
audio_in=BOFANGYINYUE_WAV_FILE,
keywords=keywords) keywords=keywords)
self.check_and_print_result('test_run_with_wav_by_customized_keywords',
kws_result)
self.check_result('test_run_with_wav_by_customized_keywords',
kws_result)


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self): def test_run_with_pos_testsets(self):
wav_file_path = download_and_untar( wav_file_path = download_and_untar(
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
self.workspace) self.workspace)
wav_path = [wav_file_path, None]
audio_list = [wav_file_path, None]


kws_result = self.run_pipeline( kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_pos_testsets', kws_result)
model_id=self.model_id, audio_in=audio_list)
self.check_result('test_run_with_pos_testsets', kws_result)


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_neg_testsets(self): def test_run_with_neg_testsets(self):
wav_file_path = download_and_untar( wav_file_path = download_and_untar(
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
self.workspace) self.workspace)
wav_path = [None, wav_file_path]
audio_list = [None, wav_file_path]


kws_result = self.run_pipeline( kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_neg_testsets', kws_result)
model_id=self.model_id, audio_in=audio_list)
self.check_result('test_run_with_neg_testsets', kws_result)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_roc(self): def test_run_with_roc(self):
@@ -238,11 +268,11 @@ class KeyWordSpottingTest(unittest.TestCase):
neg_file_path = download_and_untar( neg_file_path = download_and_untar(
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
self.workspace) self.workspace)
wav_path = [pos_file_path, neg_file_path]
audio_list = [pos_file_path, neg_file_path]


kws_result = self.run_pipeline( kws_result = self.run_pipeline(
model_id=self.model_id, wav_path=wav_path)
self.check_and_print_result('test_run_with_roc', kws_result)
model_id=self.model_id, audio_in=audio_list)
self.check_result('test_run_with_roc', kws_result)




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save