|
|
|
@@ -5,8 +5,11 @@ import tarfile |
|
|
|
import unittest |
|
|
|
from typing import Any, Dict, List, Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import requests |
|
|
|
import soundfile |
|
|
|
|
|
|
|
from modelscope.outputs import OutputKeys |
|
|
|
from modelscope.pipelines import pipeline |
|
|
|
from modelscope.utils.constant import ColorCodes, Tasks |
|
|
|
from modelscope.utils.logger import get_logger |
|
|
|
@@ -27,12 +30,12 @@ NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/n |
|
|
|
class KeyWordSpottingTest(unittest.TestCase): |
|
|
|
action_info = { |
|
|
|
'test_run_with_wav': { |
|
|
|
'checking_item': 'kws_list', |
|
|
|
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], |
|
|
|
'checking_value': '小云小云', |
|
|
|
'example': { |
|
|
|
'wav_count': |
|
|
|
1, |
|
|
|
'kws_set': |
|
|
|
'kws_type': |
|
|
|
'wav', |
|
|
|
'kws_list': [{ |
|
|
|
'keyword': '小云小云', |
|
|
|
@@ -42,13 +45,29 @@ class KeyWordSpottingTest(unittest.TestCase): |
|
|
|
}] |
|
|
|
} |
|
|
|
}, |
|
|
|
'test_run_with_pcm': { |
|
|
|
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], |
|
|
|
'checking_value': '小云小云', |
|
|
|
'example': { |
|
|
|
'wav_count': |
|
|
|
1, |
|
|
|
'kws_type': |
|
|
|
'pcm', |
|
|
|
'kws_list': [{ |
|
|
|
'keyword': '小云小云', |
|
|
|
'offset': 5.76, |
|
|
|
'length': 9.132938, |
|
|
|
'confidence': 0.990368 |
|
|
|
}] |
|
|
|
} |
|
|
|
}, |
|
|
|
'test_run_with_wav_by_customized_keywords': { |
|
|
|
'checking_item': 'kws_list', |
|
|
|
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], |
|
|
|
'checking_value': '播放音乐', |
|
|
|
'example': { |
|
|
|
'wav_count': |
|
|
|
1, |
|
|
|
'kws_set': |
|
|
|
'kws_type': |
|
|
|
'wav', |
|
|
|
'kws_list': [{ |
|
|
|
'keyword': '播放音乐', |
|
|
|
@@ -59,10 +78,10 @@ class KeyWordSpottingTest(unittest.TestCase): |
|
|
|
} |
|
|
|
}, |
|
|
|
'test_run_with_pos_testsets': { |
|
|
|
'checking_item': 'recall', |
|
|
|
'checking_item': ['recall'], |
|
|
|
'example': { |
|
|
|
'wav_count': 450, |
|
|
|
'kws_set': 'pos_testsets', |
|
|
|
'kws_type': 'pos_testsets', |
|
|
|
'wav_time': 3013.75925, |
|
|
|
'keywords': ['小云小云'], |
|
|
|
'recall': 0.953333, |
|
|
|
@@ -72,11 +91,11 @@ class KeyWordSpottingTest(unittest.TestCase): |
|
|
|
} |
|
|
|
}, |
|
|
|
'test_run_with_neg_testsets': { |
|
|
|
'checking_item': 'fa_rate', |
|
|
|
'checking_item': ['fa_rate'], |
|
|
|
'example': { |
|
|
|
'wav_count': |
|
|
|
751, |
|
|
|
'kws_set': |
|
|
|
'kws_type': |
|
|
|
'neg_testsets', |
|
|
|
'wav_time': |
|
|
|
3572.180813, |
|
|
|
@@ -98,10 +117,10 @@ class KeyWordSpottingTest(unittest.TestCase): |
|
|
|
} |
|
|
|
}, |
|
|
|
'test_run_with_roc': { |
|
|
|
'checking_item': 'keywords', |
|
|
|
'checking_item': ['keywords', 0], |
|
|
|
'checking_value': '小云小云', |
|
|
|
'example': { |
|
|
|
'kws_set': |
|
|
|
'kws_type': |
|
|
|
'roc', |
|
|
|
'keywords': ['小云小云'], |
|
|
|
'小云小云': [{ |
|
|
|
@@ -129,21 +148,20 @@ class KeyWordSpottingTest(unittest.TestCase): |
|
|
|
|
|
|
|
def tearDown(self) -> None: |
|
|
|
# remove workspace dir (.tmp) |
|
|
|
if os.path.exists(self.workspace): |
|
|
|
shutil.rmtree(self.workspace, ignore_errors=True) |
|
|
|
shutil.rmtree(self.workspace, ignore_errors=True) |
|
|
|
|
|
|
|
def run_pipeline(self, |
|
|
|
model_id: str, |
|
|
|
wav_path: Union[List[str], str], |
|
|
|
audio_in: Union[List[str], str, bytes], |
|
|
|
keywords: List[str] = None) -> Dict[str, Any]: |
|
|
|
kwsbp_16k_pipline = pipeline( |
|
|
|
task=Tasks.auto_speech_recognition, model=model_id) |
|
|
|
|
|
|
|
kws_result = kwsbp_16k_pipline(wav_path=wav_path, keywords=keywords) |
|
|
|
kws_result = kwsbp_16k_pipline(audio_in=audio_in, keywords=keywords) |
|
|
|
|
|
|
|
return kws_result |
|
|
|
|
|
|
|
def print_error(self, functions: str, result: Dict[str, Any]) -> None: |
|
|
|
def log_error(self, functions: str, result: Dict[str, Any]) -> None: |
|
|
|
logger.error(ColorCodes.MAGENTA + functions + ': FAILED.' |
|
|
|
+ ColorCodes.END) |
|
|
|
logger.error(ColorCodes.MAGENTA + functions |
|
|
|
@@ -153,49 +171,61 @@ class KeyWordSpottingTest(unittest.TestCase): |
|
|
|
|
|
|
|
raise ValueError('kws result is mismatched') |
|
|
|
|
|
|
|
def check_and_print_result(self, functions: str, |
|
|
|
result: Dict[str, Any]) -> None: |
|
|
|
if result.__contains__(self.action_info[functions]['checking_item']): |
|
|
|
checking_item = result[self.action_info[functions] |
|
|
|
['checking_item']] |
|
|
|
if functions == 'test_run_with_roc': |
|
|
|
if checking_item[0] != self.action_info[functions][ |
|
|
|
'checking_value']: |
|
|
|
self.print_error(functions, result) |
|
|
|
|
|
|
|
elif functions == 'test_run_with_wav': |
|
|
|
if checking_item[0]['keyword'] != self.action_info[functions][ |
|
|
|
'checking_value']: |
|
|
|
self.print_error(functions, result) |
|
|
|
|
|
|
|
elif functions == 'test_run_with_wav_by_customized_keywords': |
|
|
|
if checking_item[0]['keyword'] != self.action_info[functions][ |
|
|
|
'checking_value']: |
|
|
|
self.print_error(functions, result) |
|
|
|
|
|
|
|
logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' |
|
|
|
+ ColorCodes.END) |
|
|
|
if functions == 'test_run_with_roc': |
|
|
|
find_keyword = result['keywords'][0] |
|
|
|
keyword_list = result[find_keyword] |
|
|
|
for item in iter(keyword_list): |
|
|
|
threshold: float = item['threshold'] |
|
|
|
recall: float = item['recall'] |
|
|
|
fa_per_hour: float = item['fa_per_hour'] |
|
|
|
logger.info(ColorCodes.YELLOW + ' threshold:' |
|
|
|
+ str(threshold) + ' recall:' + str(recall) |
|
|
|
+ ' fa_per_hour:' + str(fa_per_hour) |
|
|
|
+ ColorCodes.END) |
|
|
|
else: |
|
|
|
logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END) |
|
|
|
def check_result(self, functions: str, result: Dict[str, Any]) -> None: |
|
|
|
result_item = result |
|
|
|
check_list = self.action_info[functions]['checking_item'] |
|
|
|
for check_item in check_list: |
|
|
|
result_item = result_item[check_item] |
|
|
|
if result_item is None or result_item == 'None': |
|
|
|
self.log_error(functions, result) |
|
|
|
|
|
|
|
if self.action_info[functions].__contains__('checking_value'): |
|
|
|
check_value = self.action_info[functions]['checking_value'] |
|
|
|
if result_item != check_value: |
|
|
|
self.log_error(functions, result) |
|
|
|
|
|
|
|
logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.' |
|
|
|
+ ColorCodes.END) |
|
|
|
if functions == 'test_run_with_roc': |
|
|
|
find_keyword = result['keywords'][0] |
|
|
|
keyword_list = result[find_keyword] |
|
|
|
for item in iter(keyword_list): |
|
|
|
threshold: float = item['threshold'] |
|
|
|
recall: float = item['recall'] |
|
|
|
fa_per_hour: float = item['fa_per_hour'] |
|
|
|
logger.info(ColorCodes.YELLOW + ' threshold:' + str(threshold) |
|
|
|
+ ' recall:' + str(recall) + ' fa_per_hour:' |
|
|
|
+ str(fa_per_hour) + ColorCodes.END) |
|
|
|
else: |
|
|
|
self.print_error(functions, result) |
|
|
|
logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END) |
|
|
|
|
|
|
|
def wav2bytes(self, wav_file) -> bytes: |
|
|
|
audio, fs = soundfile.read(wav_file) |
|
|
|
|
|
|
|
# float32 -> int16 |
|
|
|
audio = np.asarray(audio) |
|
|
|
dtype = np.dtype('int16') |
|
|
|
i = np.iinfo(dtype) |
|
|
|
abs_max = 2**(i.bits - 1) |
|
|
|
offset = i.min + abs_max |
|
|
|
audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype) |
|
|
|
|
|
|
|
# int16(PCM_16) -> byte |
|
|
|
audio = audio.tobytes() |
|
|
|
return audio |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
def test_run_with_wav(self): |
|
|
|
kws_result = self.run_pipeline( |
|
|
|
model_id=self.model_id, wav_path=POS_WAV_FILE) |
|
|
|
self.check_and_print_result('test_run_with_wav', kws_result) |
|
|
|
model_id=self.model_id, audio_in=POS_WAV_FILE) |
|
|
|
self.check_result('test_run_with_wav', kws_result) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
def test_run_with_pcm(self): |
|
|
|
audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE)) |
|
|
|
|
|
|
|
kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio) |
|
|
|
self.check_result('test_run_with_pcm', kws_result) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
def test_run_with_wav_by_customized_keywords(self): |
|
|
|
@@ -203,32 +233,32 @@ class KeyWordSpottingTest(unittest.TestCase): |
|
|
|
|
|
|
|
kws_result = self.run_pipeline( |
|
|
|
model_id=self.model_id, |
|
|
|
wav_path=BOFANGYINYUE_WAV_FILE, |
|
|
|
audio_in=BOFANGYINYUE_WAV_FILE, |
|
|
|
keywords=keywords) |
|
|
|
self.check_and_print_result('test_run_with_wav_by_customized_keywords', |
|
|
|
kws_result) |
|
|
|
self.check_result('test_run_with_wav_by_customized_keywords', |
|
|
|
kws_result) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') |
|
|
|
def test_run_with_pos_testsets(self): |
|
|
|
wav_file_path = download_and_untar( |
|
|
|
os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL, |
|
|
|
self.workspace) |
|
|
|
wav_path = [wav_file_path, None] |
|
|
|
audio_list = [wav_file_path, None] |
|
|
|
|
|
|
|
kws_result = self.run_pipeline( |
|
|
|
model_id=self.model_id, wav_path=wav_path) |
|
|
|
self.check_and_print_result('test_run_with_pos_testsets', kws_result) |
|
|
|
model_id=self.model_id, audio_in=audio_list) |
|
|
|
self.check_result('test_run_with_pos_testsets', kws_result) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') |
|
|
|
def test_run_with_neg_testsets(self): |
|
|
|
wav_file_path = download_and_untar( |
|
|
|
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, |
|
|
|
self.workspace) |
|
|
|
wav_path = [None, wav_file_path] |
|
|
|
audio_list = [None, wav_file_path] |
|
|
|
|
|
|
|
kws_result = self.run_pipeline( |
|
|
|
model_id=self.model_id, wav_path=wav_path) |
|
|
|
self.check_and_print_result('test_run_with_neg_testsets', kws_result) |
|
|
|
model_id=self.model_id, audio_in=audio_list) |
|
|
|
self.check_result('test_run_with_neg_testsets', kws_result) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_run_with_roc(self): |
|
|
|
@@ -238,11 +268,11 @@ class KeyWordSpottingTest(unittest.TestCase): |
|
|
|
neg_file_path = download_and_untar( |
|
|
|
os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL, |
|
|
|
self.workspace) |
|
|
|
wav_path = [pos_file_path, neg_file_path] |
|
|
|
audio_list = [pos_file_path, neg_file_path] |
|
|
|
|
|
|
|
kws_result = self.run_pipeline( |
|
|
|
model_id=self.model_id, wav_path=wav_path) |
|
|
|
self.check_and_print_result('test_run_with_roc', kws_result) |
|
|
|
model_id=self.model_id, audio_in=audio_list) |
|
|
|
self.check_result('test_run_with_roc', kws_result) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|