Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9122842 * [Add] add KWS code * [Update] check code linters and formatter * [Update] update kws code * Merge branch 'master' into dev/kws * [Fix] fix kws warning * [Add] add ROC for KWS * [Update] add some code check * feat: Fix confilct, auto commit by WebIDE * feat: Fix confilct, auto commit by WebIDE * Merge branch 'master' into dev/kws * [Update] refactor kws code * [Update] refactor kws code * [Update] refactor kws code, bug fix * [Update] refactor kws code, bug fixmaster
| @@ -21,6 +21,7 @@ class Models(object): | |||||
| sambert_hifi_16k = 'sambert-hifi-16k' | sambert_hifi_16k = 'sambert-hifi-16k' | ||||
| generic_tts_frontend = 'generic-tts-frontend' | generic_tts_frontend = 'generic-tts-frontend' | ||||
| hifigan16k = 'hifigan16k' | hifigan16k = 'hifigan16k' | ||||
| kws_kwsbp = 'kws-kwsbp' | |||||
| # multi-modal models | # multi-modal models | ||||
| ofa = 'ofa' | ofa = 'ofa' | ||||
| @@ -53,6 +54,7 @@ class Pipelines(object): | |||||
| # audio tasks | # audio tasks | ||||
| sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' | sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' | ||||
| speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' | speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' | ||||
| kws_kwsbp = 'kws-kwsbp' | |||||
| # multi-modal tasks | # multi-modal tasks | ||||
| image_caption = 'image-caption' | image_caption = 'image-caption' | ||||
| @@ -94,6 +96,7 @@ class Preprocessors(object): | |||||
| # audio preprocessor | # audio preprocessor | ||||
| linear_aec_fbank = 'linear-aec-fbank' | linear_aec_fbank = 'linear-aec-fbank' | ||||
| text_to_tacotron_symbols = 'text-to-tacotron-symbols' | text_to_tacotron_symbols = 'text-to-tacotron-symbols' | ||||
| wav_to_lists = 'wav-to-lists' | |||||
| # multi-modal | # multi-modal | ||||
| ofa_image_caption = 'ofa-image-caption' | ofa_image_caption = 'ofa-image-caption' | ||||
| @@ -1,5 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from .audio.kws import GenericKeyWordSpotting | |||||
| from .audio.tts.am import SambertNetHifi16k | from .audio.tts.am import SambertNetHifi16k | ||||
| from .audio.tts.vocoder import Hifigan16k | from .audio.tts.vocoder import Hifigan16k | ||||
| from .base import Model | from .base import Model | ||||
| @@ -0,0 +1 @@ | |||||
| from .generic_key_word_spotting import * # noqa F403 | |||||
| @@ -0,0 +1,30 @@ | |||||
| import os | |||||
| from typing import Any, Dict | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import Tasks | |||||
| __all__ = ['GenericKeyWordSpotting'] | |||||
| @MODELS.register_module(Tasks.key_word_spotting, module_name=Models.kws_kwsbp) | |||||
| class GenericKeyWordSpotting(Model): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the info of model. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| self.model_cfg = { | |||||
| 'model_workspace': model_dir, | |||||
| 'config_path': os.path.join(model_dir, 'config.yaml') | |||||
| } | |||||
| def forward(self) -> Dict[str, Any]: | |||||
| """return the info of the model | |||||
| """ | |||||
| return self.model_cfg | |||||
| @@ -1,2 +1,3 @@ | |||||
| from .kws_kwsbp_pipeline import * # noqa F403 | |||||
| from .linear_aec_pipeline import LinearAECPipeline | from .linear_aec_pipeline import LinearAECPipeline | ||||
| from .text_to_speech_pipeline import * # noqa F403 | from .text_to_speech_pipeline import * # noqa F403 | ||||
| @@ -0,0 +1,449 @@ | |||||
| import io | |||||
| import os | |||||
| import shutil | |||||
| import stat | |||||
| import subprocess | |||||
| from typing import Any, Dict, List | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import WavToLists | |||||
| from modelscope.utils.constant import Tasks | |||||
| __all__ = ['KeyWordSpottingKwsbpPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| Tasks.key_word_spotting, module_name=Pipelines.kws_kwsbp) | |||||
| class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| """KWS Pipeline - key word spotting decoding | |||||
| """ | |||||
| def __init__(self, | |||||
| config_file: str = None, | |||||
| model: Model = None, | |||||
| preprocessor: WavToLists = None, | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| """ | |||||
| super().__init__( | |||||
| config_file=config_file, | |||||
| model=model, | |||||
| preprocessor=preprocessor, | |||||
| **kwargs) | |||||
| assert model is not None, 'kws model should be provided' | |||||
| assert preprocessor is not None, 'preprocessor is none' | |||||
| self._preprocessor = preprocessor | |||||
| self._model = model | |||||
| def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: | |||||
| assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', | |||||
| 'roc'], f'kws_type {kws_type} is invalid' | |||||
| output = self._preprocessor.forward(self._model.forward(), kws_type, | |||||
| wav_path) | |||||
| output = self.forward(output) | |||||
| rst = self.postprocess(output) | |||||
| return rst | |||||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """Decoding | |||||
| """ | |||||
| # will generate kws result into dump/dump.JOB.log | |||||
| out = self._run_with_kwsbp(inputs) | |||||
| return out | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """process the kws results | |||||
| """ | |||||
| pos_result_json = {} | |||||
| neg_result_json = {} | |||||
| if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: | |||||
| self._parse_dump_log(pos_result_json, inputs['pos_dump_path']) | |||||
| if inputs['kws_set'] in ['neg_testsets', 'roc']: | |||||
| self._parse_dump_log(neg_result_json, inputs['neg_dump_path']) | |||||
| """ | |||||
| result_json format example: | |||||
| { | |||||
| "wav_count": 450, | |||||
| "keywords": ["小云小云"], | |||||
| "wav_time": 3560.999999, | |||||
| "detected": [ | |||||
| { | |||||
| "xxx.wav": { | |||||
| "confidence": "0.990368", | |||||
| "keyword": "小云小云" | |||||
| } | |||||
| }, | |||||
| { | |||||
| "yyy.wav": { | |||||
| "confidence": "0.990368", | |||||
| "keyword": "小云小云" | |||||
| } | |||||
| }, | |||||
| ...... | |||||
| ], | |||||
| "detected_count": 429, | |||||
| "rejected_count": 21, | |||||
| "rejected": [ | |||||
| "yyy.wav", | |||||
| "zzz.wav", | |||||
| ...... | |||||
| ] | |||||
| } | |||||
| """ | |||||
| rst_dict = {'kws_set': inputs['kws_set']} | |||||
| # parsing the result of wav | |||||
| if inputs['kws_set'] == 'wav': | |||||
| rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ | |||||
| 'pos_wav_count'] | |||||
| rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) | |||||
| if pos_result_json['detected_count'] == 1: | |||||
| rst_dict['keywords'] = pos_result_json['keywords'] | |||||
| rst_dict['detected'] = True | |||||
| wav_file_name = os.path.basename(inputs['pos_wav_path']) | |||||
| rst_dict['confidence'] = float(pos_result_json['detected'][0] | |||||
| [wav_file_name]['confidence']) | |||||
| else: | |||||
| rst_dict['detected'] = False | |||||
| # parsing the result of pos_tests | |||||
| elif inputs['kws_set'] == 'pos_testsets': | |||||
| rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ | |||||
| 'pos_wav_count'] | |||||
| rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) | |||||
| if pos_result_json.__contains__('keywords'): | |||||
| rst_dict['keywords'] = pos_result_json['keywords'] | |||||
| rst_dict['recall'] = round( | |||||
| pos_result_json['detected_count'] / rst_dict['wav_count'], 6) | |||||
| if pos_result_json.__contains__('detected_count'): | |||||
| rst_dict['detected_count'] = pos_result_json['detected_count'] | |||||
| if pos_result_json.__contains__('rejected_count'): | |||||
| rst_dict['rejected_count'] = pos_result_json['rejected_count'] | |||||
| if pos_result_json.__contains__('rejected'): | |||||
| rst_dict['rejected'] = pos_result_json['rejected'] | |||||
| # parsing the result of neg_tests | |||||
| elif inputs['kws_set'] == 'neg_testsets': | |||||
| rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[ | |||||
| 'neg_wav_count'] | |||||
| rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6) | |||||
| if neg_result_json.__contains__('keywords'): | |||||
| rst_dict['keywords'] = neg_result_json['keywords'] | |||||
| rst_dict['fa_rate'] = 0.0 | |||||
| rst_dict['fa_per_hour'] = 0.0 | |||||
| if neg_result_json.__contains__('detected_count'): | |||||
| rst_dict['detected_count'] = neg_result_json['detected_count'] | |||||
| rst_dict['fa_rate'] = round( | |||||
| neg_result_json['detected_count'] / rst_dict['wav_count'], | |||||
| 6) | |||||
| if neg_result_json.__contains__('wav_time'): | |||||
| rst_dict['fa_per_hour'] = round( | |||||
| neg_result_json['detected_count'] | |||||
| / float(neg_result_json['wav_time'] / 3600), 6) | |||||
| if neg_result_json.__contains__('rejected_count'): | |||||
| rst_dict['rejected_count'] = neg_result_json['rejected_count'] | |||||
| if neg_result_json.__contains__('detected'): | |||||
| rst_dict['detected'] = neg_result_json['detected'] | |||||
| # parsing the result of roc | |||||
| elif inputs['kws_set'] == 'roc': | |||||
| threshold_start = 0.000 | |||||
| threshold_step = 0.001 | |||||
| threshold_end = 1.000 | |||||
| pos_keywords_list = [] | |||||
| neg_keywords_list = [] | |||||
| if pos_result_json.__contains__('keywords'): | |||||
| pos_keywords_list = pos_result_json['keywords'] | |||||
| if neg_result_json.__contains__('keywords'): | |||||
| neg_keywords_list = neg_result_json['keywords'] | |||||
| keywords_list = list(set(pos_keywords_list + neg_keywords_list)) | |||||
| pos_result_json['wav_count'] = inputs['pos_wav_count'] | |||||
| neg_result_json['wav_count'] = inputs['neg_wav_count'] | |||||
| if len(keywords_list) > 0: | |||||
| rst_dict['keywords'] = keywords_list | |||||
| for index in range(len(rst_dict['keywords'])): | |||||
| cur_keyword = rst_dict['keywords'][index] | |||||
| output_list = self._generate_roc_list( | |||||
| start=threshold_start, | |||||
| step=threshold_step, | |||||
| end=threshold_end, | |||||
| keyword=cur_keyword, | |||||
| pos_inputs=pos_result_json, | |||||
| neg_inputs=neg_result_json) | |||||
| rst_dict[cur_keyword] = output_list | |||||
| return rst_dict | |||||
| def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| if inputs['kws_set'] == 'roc': | |||||
| inputs['keyword_grammar_path'] = os.path.join( | |||||
| inputs['model_workspace'], 'keywords_roc.json') | |||||
| if inputs['kws_set'] == 'wav': | |||||
| dump_log_path: str = os.path.join(inputs['pos_dump_path'], | |||||
| 'dump.log') | |||||
| kws_cmd: str = inputs['kws_tool_path'] + \ | |||||
| ' --sys-dir=' + inputs['model_workspace'] + \ | |||||
| ' --cfg-file=' + inputs['cfg_file_path'] + \ | |||||
| ' --sample-rate=' + inputs['sample_rate'] + \ | |||||
| ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||||
| ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ | |||||
| ' --num-thread=1 > ' + dump_log_path + ' 2>&1' | |||||
| os.system(kws_cmd) | |||||
| if inputs['kws_set'] in ['pos_testsets', 'roc']: | |||||
| data_dir: str = os.listdir(inputs['pos_data_path']) | |||||
| wav_list = [] | |||||
| for i in data_dir: | |||||
| suffix = os.path.splitext(os.path.basename(i))[1] | |||||
| if suffix == '.list': | |||||
| wav_list.append(os.path.join(inputs['pos_data_path'], i)) | |||||
| j: int = 0 | |||||
| process = [] | |||||
| while j < inputs['pos_num_thread']: | |||||
| wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str( | |||||
| j) + '.list' | |||||
| dump_log_path: str = inputs['pos_dump_path'] + '/dump.' + str( | |||||
| j) + '.log' | |||||
| kws_cmd: str = inputs['kws_tool_path'] + \ | |||||
| ' --sys-dir=' + inputs['model_workspace'] + \ | |||||
| ' --cfg-file=' + inputs['cfg_file_path'] + \ | |||||
| ' --sample-rate=' + inputs['sample_rate'] + \ | |||||
| ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||||
| ' --wave-scp=' + wav_list_path + \ | |||||
| ' --num-thread=1 > ' + dump_log_path + ' 2>&1' | |||||
| p = subprocess.Popen(kws_cmd, shell=True) | |||||
| process.append(p) | |||||
| j += 1 | |||||
| k: int = 0 | |||||
| while k < len(process): | |||||
| process[k].wait() | |||||
| k += 1 | |||||
| if inputs['kws_set'] in ['neg_testsets', 'roc']: | |||||
| data_dir: str = os.listdir(inputs['neg_data_path']) | |||||
| wav_list = [] | |||||
| for i in data_dir: | |||||
| suffix = os.path.splitext(os.path.basename(i))[1] | |||||
| if suffix == '.list': | |||||
| wav_list.append(os.path.join(inputs['neg_data_path'], i)) | |||||
| j: int = 0 | |||||
| process = [] | |||||
| while j < inputs['neg_num_thread']: | |||||
| wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str( | |||||
| j) + '.list' | |||||
| dump_log_path: str = inputs['neg_dump_path'] + '/dump.' + str( | |||||
| j) + '.log' | |||||
| kws_cmd: str = inputs['kws_tool_path'] + \ | |||||
| ' --sys-dir=' + inputs['model_workspace'] + \ | |||||
| ' --cfg-file=' + inputs['cfg_file_path'] + \ | |||||
| ' --sample-rate=' + inputs['sample_rate'] + \ | |||||
| ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ | |||||
| ' --wave-scp=' + wav_list_path + \ | |||||
| ' --num-thread=1 > ' + dump_log_path + ' 2>&1' | |||||
| p = subprocess.Popen(kws_cmd, shell=True) | |||||
| process.append(p) | |||||
| j += 1 | |||||
| k: int = 0 | |||||
| while k < len(process): | |||||
| process[k].wait() | |||||
| k += 1 | |||||
| return inputs | |||||
| def _parse_dump_log(self, result_json: Dict[str, Any], | |||||
| dump_path: str) -> Dict[str, Any]: | |||||
| dump_dir = os.listdir(dump_path) | |||||
| for i in dump_dir: | |||||
| basename = os.path.splitext(os.path.basename(i))[0] | |||||
| # find dump.JOB.log | |||||
| if 'dump' in basename: | |||||
| with open( | |||||
| os.path.join(dump_path, i), mode='r', | |||||
| encoding='utf-8') as file: | |||||
| while 1: | |||||
| line = file.readline() | |||||
| if not line: | |||||
| break | |||||
| else: | |||||
| result_json = self._parse_result_log( | |||||
| line, result_json) | |||||
| def _parse_result_log(self, line: str, | |||||
| result_json: Dict[str, Any]) -> Dict[str, Any]: | |||||
| # valid info | |||||
| if '[rejected]' in line or '[detected]' in line: | |||||
| detected_count = 0 | |||||
| rejected_count = 0 | |||||
| if result_json.__contains__('detected_count'): | |||||
| detected_count = result_json['detected_count'] | |||||
| if result_json.__contains__('rejected_count'): | |||||
| rejected_count = result_json['rejected_count'] | |||||
| if '[detected]' in line: | |||||
| # [detected], fname:/xxx/.tmp_pos_testsets/pos_testsets/33.wav, | |||||
| # kw:小云小云, confidence:0.965155, time:[4.62-5.10], threshold:0.00, | |||||
| detected_count += 1 | |||||
| content_list = line.split(', ') | |||||
| file_name = os.path.basename(content_list[1].split(':')[1]) | |||||
| keyword = content_list[2].split(':')[1] | |||||
| confidence = content_list[3].split(':')[1] | |||||
| keywords_list = [] | |||||
| if result_json.__contains__('keywords'): | |||||
| keywords_list = result_json['keywords'] | |||||
| if keyword not in keywords_list: | |||||
| keywords_list.append(keyword) | |||||
| result_json['keywords'] = keywords_list | |||||
| keyword_item = {} | |||||
| keyword_item['confidence'] = confidence | |||||
| keyword_item['keyword'] = keyword | |||||
| item = {} | |||||
| item[file_name] = keyword_item | |||||
| detected_list = [] | |||||
| if result_json.__contains__('detected'): | |||||
| detected_list = result_json['detected'] | |||||
| detected_list.append(item) | |||||
| result_json['detected'] = detected_list | |||||
| elif '[rejected]' in line: | |||||
| # [rejected], fname:/xxx/.tmp_pos_testsets/pos_testsets/28.wav | |||||
| rejected_count += 1 | |||||
| content_list = line.split(', ') | |||||
| file_name = os.path.basename(content_list[1].split(':')[1]) | |||||
| file_name = file_name.strip().replace('\n', | |||||
| '').replace('\r', '') | |||||
| rejected_list = [] | |||||
| if result_json.__contains__('rejected'): | |||||
| rejected_list = result_json['rejected'] | |||||
| rejected_list.append(file_name) | |||||
| result_json['rejected'] = rejected_list | |||||
| result_json['detected_count'] = detected_count | |||||
| result_json['rejected_count'] = rejected_count | |||||
| elif 'total_proc_time=' in line and 'wav_time=' in line: | |||||
| # eg: total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=0.013799 | |||||
| wav_total_time = 0 | |||||
| content_list = line.split('), ') | |||||
| if result_json.__contains__('wav_time'): | |||||
| wav_total_time = result_json['wav_time'] | |||||
| wav_time_str = content_list[1].split('=')[1] | |||||
| wav_time_str = wav_time_str.split('(')[0] | |||||
| wav_time = float(wav_time_str) | |||||
| wav_time = round(wav_time, 6) | |||||
| if isinstance(wav_time, float): | |||||
| wav_total_time += wav_time | |||||
| result_json['wav_time'] = wav_total_time | |||||
| return result_json | |||||
| def _generate_roc_list(self, start: float, step: float, end: float, | |||||
| keyword: str, pos_inputs: Dict[str, Any], | |||||
| neg_inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| pos_wav_count = pos_inputs['wav_count'] | |||||
| neg_wav_time = neg_inputs['wav_time'] | |||||
| det_lists = pos_inputs['detected'] | |||||
| fa_lists = neg_inputs['detected'] | |||||
| threshold_cur = start | |||||
| """ | |||||
| input det_lists dict | |||||
| [ | |||||
| { | |||||
| "xxx.wav": { | |||||
| "confidence": "0.990368", | |||||
| "keyword": "小云小云" | |||||
| } | |||||
| }, | |||||
| { | |||||
| "yyy.wav": { | |||||
| "confidence": "0.990368", | |||||
| "keyword": "小云小云" | |||||
| } | |||||
| }, | |||||
| ] | |||||
| output dict | |||||
| [ | |||||
| { | |||||
| "threshold": 0.000, | |||||
| "recall": 0.999888, | |||||
| "fa_per_hour": 1.999999 | |||||
| }, | |||||
| { | |||||
| "threshold": 0.001, | |||||
| "recall": 0.999888, | |||||
| "fa_per_hour": 1.999999 | |||||
| }, | |||||
| ] | |||||
| """ | |||||
| output = [] | |||||
| while threshold_cur <= end: | |||||
| det_count = 0 | |||||
| fa_count = 0 | |||||
| for index in range(len(det_lists)): | |||||
| det_item = det_lists[index] | |||||
| det_wav_item = det_item.get(next(iter(det_item))) | |||||
| if det_wav_item['keyword'] == keyword: | |||||
| confidence = float(det_wav_item['confidence']) | |||||
| if confidence >= threshold_cur: | |||||
| det_count += 1 | |||||
| for index in range(len(fa_lists)): | |||||
| fa_item = fa_lists[index] | |||||
| fa_wav_item = fa_item.get(next(iter(fa_item))) | |||||
| if fa_wav_item['keyword'] == keyword: | |||||
| confidence = float(fa_wav_item['confidence']) | |||||
| if confidence >= threshold_cur: | |||||
| fa_count += 1 | |||||
| output_item = { | |||||
| 'threshold': round(threshold_cur, 3), | |||||
| 'recall': round(float(det_count / pos_wav_count), 6), | |||||
| 'fa_per_hour': round(fa_count / float(neg_wav_time / 3600), 6) | |||||
| } | |||||
| output.append(output_item) | |||||
| threshold_cur += step | |||||
| return output | |||||
| @@ -5,6 +5,7 @@ from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS, build_preprocessor | from .builder import PREPROCESSORS, build_preprocessor | ||||
| from .common import Compose | from .common import Compose | ||||
| from .image import LoadImage, load_image | from .image import LoadImage, load_image | ||||
| from .kws import WavToLists | |||||
| from .multi_modal import OfaImageCaptionPreprocessor | from .multi_modal import OfaImageCaptionPreprocessor | ||||
| from .nlp import * # noqa F403 | from .nlp import * # noqa F403 | ||||
| from .text_to_speech import * # noqa F403 | from .text_to_speech import * # noqa F403 | ||||
| @@ -0,0 +1,253 @@ | |||||
| import os | |||||
| import shutil | |||||
| import stat | |||||
| from pathlib import Path | |||||
| from typing import Any, Dict, List | |||||
| import yaml | |||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.utils.constant import Fields | |||||
| from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS | |||||
| __all__ = ['WavToLists'] | |||||
| @PREPROCESSORS.register_module( | |||||
| Fields.audio, module_name=Preprocessors.wav_to_lists) | |||||
| class WavToLists(Preprocessor): | |||||
| """generate audio lists file from wav | |||||
| Args: | |||||
| workspace (str): store temporarily kws intermedium and result | |||||
| """ | |||||
| def __init__(self, workspace: str = None): | |||||
| # the workspace path | |||||
| if len(workspace) == 0: | |||||
| self._workspace = os.path.join(os.getcwd(), '.tmp') | |||||
| else: | |||||
| self._workspace = workspace | |||||
| if not os.path.exists(self._workspace): | |||||
| os.mkdir(self._workspace) | |||||
| def __call__(self, | |||||
| model: Model = None, | |||||
| kws_type: str = None, | |||||
| wav_path: List[str] = None) -> Dict[str, Any]: | |||||
| """Call functions to load model and wav. | |||||
| Args: | |||||
| model (Model): model should be provided | |||||
| kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc | |||||
| wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path | |||||
| Returns: | |||||
| Dict[str, Any]: the kws result | |||||
| """ | |||||
| assert model is not None, 'preprocess kws model should be provided' | |||||
| assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' | |||||
| ], f'preprocess kws_type {kws_type} is invalid' | |||||
| assert wav_path[0] is not None or wav_path[ | |||||
| 1] is not None, 'preprocess wav_path is invalid' | |||||
| self._model = model | |||||
| out = self.forward(self._model.forward(), kws_type, wav_path) | |||||
| return out | |||||
| def forward(self, model: Dict[str, Any], kws_type: str, | |||||
| wav_path: List[str]) -> Dict[str, Any]: | |||||
| assert len(kws_type) > 0, 'preprocess kws_type is empty' | |||||
| assert len( | |||||
| model['config_path']) > 0, 'preprocess model[config_path] is empty' | |||||
| assert os.path.exists( | |||||
| model['config_path']), 'model config.yaml is absent' | |||||
| inputs = model.copy() | |||||
| inputs['kws_set'] = kws_type | |||||
| inputs['workspace'] = self._workspace | |||||
| if wav_path[0] is not None: | |||||
| inputs['pos_wav_path'] = wav_path[0] | |||||
| if wav_path[1] is not None: | |||||
| inputs['neg_wav_path'] = wav_path[1] | |||||
| out = self._read_config(inputs) | |||||
| out = self._generate_wav_lists(out) | |||||
| return out | |||||
| def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """read and parse config.yaml to get all model files | |||||
| """ | |||||
| assert os.path.exists( | |||||
| inputs['config_path']), 'model config yaml file does not exist' | |||||
| config_file = open(inputs['config_path']) | |||||
| root = yaml.full_load(config_file) | |||||
| config_file.close() | |||||
| inputs['cfg_file'] = root['cfg_file'] | |||||
| inputs['cfg_file_path'] = os.path.join(inputs['model_workspace'], | |||||
| root['cfg_file']) | |||||
| inputs['keyword_grammar'] = root['keyword_grammar'] | |||||
| inputs['keyword_grammar_path'] = os.path.join( | |||||
| inputs['model_workspace'], root['keyword_grammar']) | |||||
| inputs['sample_rate'] = str(root['sample_rate']) | |||||
| inputs['kws_tool'] = root['kws_tool'] | |||||
| if os.path.exists( | |||||
| os.path.join(inputs['workspace'], inputs['kws_tool'])): | |||||
| inputs['kws_tool_path'] = os.path.join(inputs['workspace'], | |||||
| inputs['kws_tool']) | |||||
| elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])): | |||||
| inputs['kws_tool_path'] = os.path.join('/usr/bin', | |||||
| inputs['kws_tool']) | |||||
| elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])): | |||||
| inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool']) | |||||
| assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp' | |||||
| os.chmod(inputs['kws_tool_path'], | |||||
| stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH) | |||||
| self._config_checking(inputs) | |||||
| return inputs | |||||
| def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """assemble wav lists | |||||
| """ | |||||
| if inputs['kws_set'] == 'wav': | |||||
| inputs['pos_num_thread'] = 1 | |||||
| wave_scp_content: str = inputs['pos_wav_path'] + '\n' | |||||
| with open(os.path.join(inputs['pos_data_path'], 'wave.list'), | |||||
| 'a') as f: | |||||
| f.write(wave_scp_content) | |||||
| inputs['pos_wav_count'] = 1 | |||||
| if inputs['kws_set'] in ['pos_testsets', 'roc']: | |||||
| # find all positive wave | |||||
| wav_list = [] | |||||
| wav_dir = inputs['pos_wav_path'] | |||||
| wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||||
| list_count: int = len(wav_list) | |||||
| inputs['pos_wav_count'] = list_count | |||||
| if list_count <= 128: | |||||
| inputs['pos_num_thread'] = list_count | |||||
| j: int = 0 | |||||
| while j < list_count: | |||||
| wave_scp_content: str = wav_list[j] + '\n' | |||||
| wav_list_path = inputs['pos_data_path'] + '/wave.' + str( | |||||
| j) + '.list' | |||||
| with open(wav_list_path, 'a') as f: | |||||
| f.write(wave_scp_content) | |||||
| j += 1 | |||||
| else: | |||||
| inputs['pos_num_thread'] = 128 | |||||
| j: int = 0 | |||||
| k: int = 0 | |||||
| while j < list_count: | |||||
| wave_scp_content: str = wav_list[j] + '\n' | |||||
| wav_list_path = inputs['pos_data_path'] + '/wave.' + str( | |||||
| k) + '.list' | |||||
| with open(wav_list_path, 'a') as f: | |||||
| f.write(wave_scp_content) | |||||
| j += 1 | |||||
| k += 1 | |||||
| if k >= 128: | |||||
| k = 0 | |||||
| if inputs['kws_set'] in ['neg_testsets', 'roc']: | |||||
| # find all negative wave | |||||
| wav_list = [] | |||||
| wav_dir = inputs['neg_wav_path'] | |||||
| wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) | |||||
| list_count: int = len(wav_list) | |||||
| inputs['neg_wav_count'] = list_count | |||||
| if list_count <= 128: | |||||
| inputs['neg_num_thread'] = list_count | |||||
| j: int = 0 | |||||
| while j < list_count: | |||||
| wave_scp_content: str = wav_list[j] + '\n' | |||||
| wav_list_path = inputs['neg_data_path'] + '/wave.' + str( | |||||
| j) + '.list' | |||||
| with open(wav_list_path, 'a') as f: | |||||
| f.write(wave_scp_content) | |||||
| j += 1 | |||||
| else: | |||||
| inputs['neg_num_thread'] = 128 | |||||
| j: int = 0 | |||||
| k: int = 0 | |||||
| while j < list_count: | |||||
| wave_scp_content: str = wav_list[j] + '\n' | |||||
| wav_list_path = inputs['neg_data_path'] + '/wave.' + str( | |||||
| k) + '.list' | |||||
| with open(wav_list_path, 'a') as f: | |||||
| f.write(wave_scp_content) | |||||
| j += 1 | |||||
| k += 1 | |||||
| if k >= 128: | |||||
| k = 0 | |||||
| return inputs | |||||
| def _recursion_dir_all_wave(self, wav_list, | |||||
| dir_path: str) -> Dict[str, Any]: | |||||
| dir_files = os.listdir(dir_path) | |||||
| for file in dir_files: | |||||
| file_path = os.path.join(dir_path, file) | |||||
| if os.path.isfile(file_path): | |||||
| if file_path.endswith('.wav') or file_path.endswith('.WAV'): | |||||
| wav_list.append(file_path) | |||||
| elif os.path.isdir(file_path): | |||||
| self._recursion_dir_all_wave(wav_list, file_path) | |||||
| return wav_list | |||||
| def _config_checking(self, inputs: Dict[str, Any]): | |||||
| if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: | |||||
| inputs['pos_data_path'] = os.path.join(inputs['workspace'], | |||||
| 'pos_data') | |||||
| if not os.path.exists(inputs['pos_data_path']): | |||||
| os.mkdir(inputs['pos_data_path']) | |||||
| else: | |||||
| shutil.rmtree(inputs['pos_data_path']) | |||||
| os.mkdir(inputs['pos_data_path']) | |||||
| inputs['pos_dump_path'] = os.path.join(inputs['workspace'], | |||||
| 'pos_dump') | |||||
| if not os.path.exists(inputs['pos_dump_path']): | |||||
| os.mkdir(inputs['pos_dump_path']) | |||||
| else: | |||||
| shutil.rmtree(inputs['pos_dump_path']) | |||||
| os.mkdir(inputs['pos_dump_path']) | |||||
| if inputs['kws_set'] in ['neg_testsets', 'roc']: | |||||
| inputs['neg_data_path'] = os.path.join(inputs['workspace'], | |||||
| 'neg_data') | |||||
| if not os.path.exists(inputs['neg_data_path']): | |||||
| os.mkdir(inputs['neg_data_path']) | |||||
| else: | |||||
| shutil.rmtree(inputs['neg_data_path']) | |||||
| os.mkdir(inputs['neg_data_path']) | |||||
| inputs['neg_dump_path'] = os.path.join(inputs['workspace'], | |||||
| 'neg_dump') | |||||
| if not os.path.exists(inputs['neg_dump_path']): | |||||
| os.mkdir(inputs['neg_dump_path']) | |||||
| else: | |||||
| shutil.rmtree(inputs['neg_dump_path']) | |||||
| os.mkdir(inputs['neg_dump_path']) | |||||
| @@ -52,6 +52,7 @@ class Tasks(object): | |||||
| auto_speech_recognition = 'auto-speech-recognition' | auto_speech_recognition = 'auto-speech-recognition' | ||||
| text_to_speech = 'text-to-speech' | text_to_speech = 'text-to-speech' | ||||
| speech_signal_process = 'speech-signal-process' | speech_signal_process = 'speech-signal-process' | ||||
| key_word_spotting = 'key-word-spotting' | |||||
| # multi-modal tasks | # multi-modal tasks | ||||
| image_captioning = 'image-captioning' | image_captioning = 'image-captioning' | ||||
| @@ -0,0 +1,334 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import shutil | |||||
| import tarfile | |||||
| import unittest | |||||
| import requests | |||||
| from modelscope.metainfo import Pipelines, Preprocessors | |||||
| from modelscope.models import Model | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.preprocessors import build_preprocessor | |||||
| from modelscope.utils.constant import Fields, InputFields, Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' | |||||
| POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav' | |||||
| POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE | |||||
| POS_TESTSETS_FILE = 'pos_testsets.tar.gz' | |||||
| POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' | |||||
| NEG_TESTSETS_FILE = 'neg_testsets.tar.gz' | |||||
| NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz' | |||||
| def un_tar_gz(fname, dirs): | |||||
| t = tarfile.open(fname) | |||||
| t.extractall(path=dirs) | |||||
| class KeyWordSpottingTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' | |||||
| self.workspace = os.path.join(os.getcwd(), '.tmp') | |||||
| if not os.path.exists(self.workspace): | |||||
| os.mkdir(self.workspace) | |||||
| def tearDown(self) -> None: | |||||
| if os.path.exists(self.workspace): | |||||
| shutil.rmtree(self.workspace) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_wav(self): | |||||
| # wav, neg_testsets, pos_testsets, roc | |||||
| kws_set = 'wav' | |||||
| # downloading wav file | |||||
| wav_file_path = os.path.join(self.workspace, POS_WAV_FILE) | |||||
| if not os.path.exists(wav_file_path): | |||||
| r = requests.get(POS_WAV_URL) | |||||
| with open(wav_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| # downloading kwsbp | |||||
| kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||||
| if not os.path.exists(kwsbp_file_path): | |||||
| r = requests.get(KWSBP_URL) | |||||
| with open(kwsbp_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| self.assertTrue(model is not None) | |||||
| cfg_preprocessor = dict( | |||||
| type=Preprocessors.wav_to_lists, workspace=self.workspace) | |||||
| preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) | |||||
| self.assertTrue(preprocessor is not None) | |||||
| kwsbp_16k_pipline = pipeline( | |||||
| pipeline_name=Pipelines.kws_kwsbp, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| self.assertTrue(kwsbp_16k_pipline is not None) | |||||
| kws_result = kwsbp_16k_pipline( | |||||
| kws_type=kws_set, wav_path=[wav_file_path, None]) | |||||
| self.assertTrue(kws_result.__contains__('detected')) | |||||
| """ | |||||
| kws result json format example: | |||||
| { | |||||
| 'wav_count': 1, | |||||
| 'kws_set': 'wav', | |||||
| 'wav_time': 9.132938, | |||||
| 'keywords': ['小云小云'], | |||||
| 'detected': True, | |||||
| 'confidence': 0.990368 | |||||
| } | |||||
| """ | |||||
| if kws_result.__contains__('keywords'): | |||||
| print('test_run_with_wav keywords: ', kws_result['keywords']) | |||||
| print('test_run_with_wav detected result: ', kws_result['detected']) | |||||
| print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_with_pos_testsets(self): | |||||
| # wav, neg_testsets, pos_testsets, roc | |||||
| kws_set = 'pos_testsets' | |||||
| # downloading pos_testsets file | |||||
| testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) | |||||
| if not os.path.exists(testsets_file_path): | |||||
| r = requests.get(POS_TESTSETS_URL) | |||||
| with open(testsets_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| testsets_dir_name = os.path.splitext( | |||||
| os.path.basename(POS_TESTSETS_FILE))[0] | |||||
| testsets_dir_name = os.path.splitext( | |||||
| os.path.basename(testsets_dir_name))[0] | |||||
| # wav_file_path = <cwd>/.tmp_pos_testsets/pos_testsets/ | |||||
| wav_file_path = os.path.join(self.workspace, testsets_dir_name) | |||||
| # untar the pos_testsets file | |||||
| if not os.path.exists(wav_file_path): | |||||
| un_tar_gz(testsets_file_path, self.workspace) | |||||
| # downloading kwsbp -- a kws batch processing tool | |||||
| kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||||
| if not os.path.exists(kwsbp_file_path): | |||||
| r = requests.get(KWSBP_URL) | |||||
| with open(kwsbp_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| self.assertTrue(model is not None) | |||||
| cfg_preprocessor = dict( | |||||
| type=Preprocessors.wav_to_lists, workspace=self.workspace) | |||||
| preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) | |||||
| self.assertTrue(preprocessor is not None) | |||||
| kwsbp_16k_pipline = pipeline( | |||||
| pipeline_name=Pipelines.kws_kwsbp, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| self.assertTrue(kwsbp_16k_pipline is not None) | |||||
| kws_result = kwsbp_16k_pipline( | |||||
| kws_type=kws_set, wav_path=[wav_file_path, None]) | |||||
| self.assertTrue(kws_result.__contains__('recall')) | |||||
| """ | |||||
| kws result json format example: | |||||
| { | |||||
| 'wav_count': 450, | |||||
| 'kws_set': 'pos_testsets', | |||||
| 'wav_time': 3013.759254, | |||||
| 'keywords': ["小云小云"], | |||||
| 'recall': 0.953333, | |||||
| 'detected_count': 429, | |||||
| 'rejected_count': 21, | |||||
| 'rejected': [ | |||||
| 'yyy.wav', | |||||
| 'zzz.wav', | |||||
| ...... | |||||
| ] | |||||
| } | |||||
| """ | |||||
| if kws_result.__contains__('keywords'): | |||||
| print('test_run_with_pos_testsets keywords: ', | |||||
| kws_result['keywords']) | |||||
| print('test_run_with_pos_testsets recall: ', kws_result['recall']) | |||||
| print('test_run_with_pos_testsets wave time(seconds): ', | |||||
| kws_result['wav_time']) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_with_neg_testsets(self): | |||||
| # wav, neg_testsets, pos_testsets, roc | |||||
| kws_set = 'neg_testsets' | |||||
| # downloading neg_testsets file | |||||
| testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) | |||||
| if not os.path.exists(testsets_file_path): | |||||
| r = requests.get(NEG_TESTSETS_URL) | |||||
| with open(testsets_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| testsets_dir_name = os.path.splitext( | |||||
| os.path.basename(NEG_TESTSETS_FILE))[0] | |||||
| testsets_dir_name = os.path.splitext( | |||||
| os.path.basename(testsets_dir_name))[0] | |||||
| # wav_file_path = <cwd>/.tmp_neg_testsets/neg_testsets/ | |||||
| wav_file_path = os.path.join(self.workspace, testsets_dir_name) | |||||
| # untar the neg_testsets file | |||||
| if not os.path.exists(wav_file_path): | |||||
| un_tar_gz(testsets_file_path, self.workspace) | |||||
| # downloading kwsbp -- a kws batch processing tool | |||||
| kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||||
| if not os.path.exists(kwsbp_file_path): | |||||
| r = requests.get(KWSBP_URL) | |||||
| with open(kwsbp_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| self.assertTrue(model is not None) | |||||
| cfg_preprocessor = dict( | |||||
| type=Preprocessors.wav_to_lists, workspace=self.workspace) | |||||
| preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) | |||||
| self.assertTrue(preprocessor is not None) | |||||
| kwsbp_16k_pipline = pipeline( | |||||
| pipeline_name=Pipelines.kws_kwsbp, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| self.assertTrue(kwsbp_16k_pipline is not None) | |||||
| kws_result = kwsbp_16k_pipline( | |||||
| kws_type=kws_set, wav_path=[None, wav_file_path]) | |||||
| self.assertTrue(kws_result.__contains__('fa_rate')) | |||||
| """ | |||||
| kws result json format example: | |||||
| { | |||||
| 'wav_count': 751, | |||||
| 'kws_set': 'neg_testsets', | |||||
| 'wav_time': 3572.180812, | |||||
| 'keywords': ['小云小云'], | |||||
| 'fa_rate': 0.001332, | |||||
| 'fa_per_hour': 1.007788, | |||||
| 'detected_count': 1, | |||||
| 'rejected_count': 750, | |||||
| 'detected': [ | |||||
| { | |||||
| '6.wav': { | |||||
| 'confidence': '0.321170' | |||||
| } | |||||
| } | |||||
| ] | |||||
| } | |||||
| """ | |||||
| if kws_result.__contains__('keywords'): | |||||
| print('test_run_with_neg_testsets keywords: ', | |||||
| kws_result['keywords']) | |||||
| print('test_run_with_neg_testsets fa rate: ', kws_result['fa_rate']) | |||||
| print('test_run_with_neg_testsets fa per hour: ', | |||||
| kws_result['fa_per_hour']) | |||||
| print('test_run_with_neg_testsets wave time(seconds): ', | |||||
| kws_result['wav_time']) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_roc(self): | |||||
| # wav, neg_testsets, pos_testsets, roc | |||||
| kws_set = 'roc' | |||||
| # downloading neg_testsets file | |||||
| testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) | |||||
| if not os.path.exists(testsets_file_path): | |||||
| r = requests.get(NEG_TESTSETS_URL) | |||||
| with open(testsets_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| testsets_dir_name = os.path.splitext( | |||||
| os.path.basename(NEG_TESTSETS_FILE))[0] | |||||
| testsets_dir_name = os.path.splitext( | |||||
| os.path.basename(testsets_dir_name))[0] | |||||
| # neg_file_path = <workspace>/.tmp_roc/neg_testsets/ | |||||
| neg_file_path = os.path.join(self.workspace, testsets_dir_name) | |||||
| # untar the neg_testsets file | |||||
| if not os.path.exists(neg_file_path): | |||||
| un_tar_gz(testsets_file_path, self.workspace) | |||||
| # downloading pos_testsets file | |||||
| testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) | |||||
| if not os.path.exists(testsets_file_path): | |||||
| r = requests.get(POS_TESTSETS_URL) | |||||
| with open(testsets_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| testsets_dir_name = os.path.splitext( | |||||
| os.path.basename(POS_TESTSETS_FILE))[0] | |||||
| testsets_dir_name = os.path.splitext( | |||||
| os.path.basename(testsets_dir_name))[0] | |||||
| # pos_file_path = <workspace>/.tmp_roc/pos_testsets/ | |||||
| pos_file_path = os.path.join(self.workspace, testsets_dir_name) | |||||
| # untar the pos_testsets file | |||||
| if not os.path.exists(pos_file_path): | |||||
| un_tar_gz(testsets_file_path, self.workspace) | |||||
| # downloading kwsbp -- a kws batch processing tool | |||||
| kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') | |||||
| if not os.path.exists(kwsbp_file_path): | |||||
| r = requests.get(KWSBP_URL) | |||||
| with open(kwsbp_file_path, 'wb') as f: | |||||
| f.write(r.content) | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| self.assertTrue(model is not None) | |||||
| cfg_preprocessor = dict( | |||||
| type=Preprocessors.wav_to_lists, workspace=self.workspace) | |||||
| preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) | |||||
| self.assertTrue(preprocessor is not None) | |||||
| kwsbp_16k_pipline = pipeline( | |||||
| pipeline_name=Pipelines.kws_kwsbp, | |||||
| model=model, | |||||
| preprocessor=preprocessor) | |||||
| self.assertTrue(kwsbp_16k_pipline is not None) | |||||
| kws_result = kwsbp_16k_pipline( | |||||
| kws_type=kws_set, wav_path=[pos_file_path, neg_file_path]) | |||||
| """ | |||||
| kws result json format example: | |||||
| { | |||||
| 'kws_set': 'roc', | |||||
| 'keywords': ['小云小云'], | |||||
| '小云小云': [ | |||||
| {'threshold': 0.0, 'recall': 0.953333, 'fa_per_hour': 1.007788}, | |||||
| {'threshold': 0.001, 'recall': 0.953333, 'fa_per_hour': 1.007788}, | |||||
| ...... | |||||
| {'threshold': 0.999, 'recall': 0.004444, 'fa_per_hour': 0.0} | |||||
| ] | |||||
| } | |||||
| """ | |||||
| if kws_result.__contains__('keywords'): | |||||
| find_keyword = kws_result['keywords'][0] | |||||
| print('test_run_with_roc keywords: ', find_keyword) | |||||
| keyword_list = kws_result[find_keyword] | |||||
| for item in iter(keyword_list): | |||||
| threshold: float = item['threshold'] | |||||
| recall: float = item['recall'] | |||||
| fa_per_hour: float = item['fa_per_hour'] | |||||
| print(' threshold:', threshold, ' recall:', recall, | |||||
| ' fa_per_hour:', fa_per_hour) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||