diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index af89cf33..680fe2e8 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -21,6 +21,7 @@ class Models(object): sambert_hifi_16k = 'sambert-hifi-16k' generic_tts_frontend = 'generic-tts-frontend' hifigan16k = 'hifigan16k' + kws_kwsbp = 'kws-kwsbp' # multi-modal models ofa = 'ofa' @@ -53,6 +54,7 @@ class Pipelines(object): # audio tasks sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' + kws_kwsbp = 'kws-kwsbp' # multi-modal tasks image_caption = 'image-caption' @@ -94,6 +96,7 @@ class Preprocessors(object): # audio preprocessor linear_aec_fbank = 'linear-aec-fbank' text_to_tacotron_symbols = 'text-to-tacotron-symbols' + wav_to_lists = 'wav-to-lists' # multi-modal ofa_image_caption = 'ofa-image-caption' diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py index 06380035..ebf81c32 100644 --- a/modelscope/models/__init__.py +++ b/modelscope/models/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .audio.kws import GenericKeyWordSpotting from .audio.tts.am import SambertNetHifi16k from .audio.tts.vocoder import Hifigan16k from .base import Model diff --git a/modelscope/models/audio/kws/__init__.py b/modelscope/models/audio/kws/__init__.py new file mode 100644 index 00000000..d7e163a9 --- /dev/null +++ b/modelscope/models/audio/kws/__init__.py @@ -0,0 +1 @@ +from .generic_key_word_spotting import * # noqa F403 diff --git a/modelscope/models/audio/kws/generic_key_word_spotting.py b/modelscope/models/audio/kws/generic_key_word_spotting.py new file mode 100644 index 00000000..7a738d5b --- /dev/null +++ b/modelscope/models/audio/kws/generic_key_word_spotting.py @@ -0,0 +1,30 @@ +import os +from typing import Any, Dict + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + +__all__ = ['GenericKeyWordSpotting'] + + +@MODELS.register_module(Tasks.key_word_spotting, module_name=Models.kws_kwsbp) +class GenericKeyWordSpotting(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the info of model. + + Args: + model_dir (str): the model path. + """ + + self.model_cfg = { + 'model_workspace': model_dir, + 'config_path': os.path.join(model_dir, 'config.yaml') + } + + def forward(self) -> Dict[str, Any]: + """return the info of the model + """ + return self.model_cfg diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py index 20c7710a..87ccd49a 100644 --- a/modelscope/pipelines/audio/__init__.py +++ b/modelscope/pipelines/audio/__init__.py @@ -1,2 +1,3 @@ +from .kws_kwsbp_pipeline import * # noqa F403 from .linear_aec_pipeline import LinearAECPipeline from .text_to_speech_pipeline import * # noqa F403 diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py new file mode 100644 index 00000000..4a69976a --- /dev/null +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -0,0 +1,449 @@ +import io +import os +import shutil +import stat +import subprocess +from typing import Any, Dict, List + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import WavToLists +from modelscope.utils.constant import Tasks + +__all__ = ['KeyWordSpottingKwsbpPipeline'] + + +@PIPELINES.register_module( + Tasks.key_word_spotting, module_name=Pipelines.kws_kwsbp) +class KeyWordSpottingKwsbpPipeline(Pipeline): + """KWS Pipeline - key word spotting decoding + """ + + def __init__(self, + config_file: str = None, + model: Model = None, + preprocessor: WavToLists = None, + **kwargs): + """use `model` and `preprocessor` to create a kws pipeline for prediction + """ + + super().__init__( + config_file=config_file, + model=model, + preprocessor=preprocessor, + **kwargs) + assert model is not None, 'kws model should be provided' + assert preprocessor is not None, 'preprocessor is none' + + self._preprocessor = preprocessor + self._model = model + + def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: + assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', + 'roc'], f'kws_type {kws_type} is invalid' + output = self._preprocessor.forward(self._model.forward(), kws_type, + wav_path) + output = self.forward(output) + rst = self.postprocess(output) + return rst + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Decoding + """ + + # will generate kws result into dump/dump.JOB.log + out = self._run_with_kwsbp(inputs) + + return out + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """process the kws results + """ + + pos_result_json = {} + neg_result_json = {} + + if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: + self._parse_dump_log(pos_result_json, inputs['pos_dump_path']) + if inputs['kws_set'] in ['neg_testsets', 'roc']: + self._parse_dump_log(neg_result_json, inputs['neg_dump_path']) + """ + result_json format example: + { + "wav_count": 450, + "keywords": ["小云小云"], + "wav_time": 3560.999999, + "detected": [ + { + "xxx.wav": { + "confidence": "0.990368", + "keyword": "小云小云" + } + }, + { + "yyy.wav": { + "confidence": "0.990368", + "keyword": "小云小云" + } + }, + ...... + ], + "detected_count": 429, + "rejected_count": 21, + "rejected": [ + "yyy.wav", + "zzz.wav", + ...... + ] + } + """ + + rst_dict = {'kws_set': inputs['kws_set']} + + # parsing the result of wav + if inputs['kws_set'] == 'wav': + rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ + 'pos_wav_count'] + rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) + if pos_result_json['detected_count'] == 1: + rst_dict['keywords'] = pos_result_json['keywords'] + rst_dict['detected'] = True + wav_file_name = os.path.basename(inputs['pos_wav_path']) + rst_dict['confidence'] = float(pos_result_json['detected'][0] + [wav_file_name]['confidence']) + else: + rst_dict['detected'] = False + + # parsing the result of pos_tests + elif inputs['kws_set'] == 'pos_testsets': + rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ + 'pos_wav_count'] + rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) + if pos_result_json.__contains__('keywords'): + rst_dict['keywords'] = pos_result_json['keywords'] + + rst_dict['recall'] = round( + pos_result_json['detected_count'] / rst_dict['wav_count'], 6) + + if pos_result_json.__contains__('detected_count'): + rst_dict['detected_count'] = pos_result_json['detected_count'] + if pos_result_json.__contains__('rejected_count'): + rst_dict['rejected_count'] = pos_result_json['rejected_count'] + if pos_result_json.__contains__('rejected'): + rst_dict['rejected'] = pos_result_json['rejected'] + + # parsing the result of neg_tests + elif inputs['kws_set'] == 'neg_testsets': + rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[ + 'neg_wav_count'] + rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6) + if neg_result_json.__contains__('keywords'): + rst_dict['keywords'] = neg_result_json['keywords'] + + rst_dict['fa_rate'] = 0.0 + rst_dict['fa_per_hour'] = 0.0 + + if neg_result_json.__contains__('detected_count'): + rst_dict['detected_count'] = neg_result_json['detected_count'] + rst_dict['fa_rate'] = round( + neg_result_json['detected_count'] / rst_dict['wav_count'], + 6) + if neg_result_json.__contains__('wav_time'): + rst_dict['fa_per_hour'] = round( + neg_result_json['detected_count'] + / float(neg_result_json['wav_time'] / 3600), 6) + + if neg_result_json.__contains__('rejected_count'): + rst_dict['rejected_count'] = neg_result_json['rejected_count'] + + if neg_result_json.__contains__('detected'): + rst_dict['detected'] = neg_result_json['detected'] + + # parsing the result of roc + elif inputs['kws_set'] == 'roc': + threshold_start = 0.000 + threshold_step = 0.001 + threshold_end = 1.000 + + pos_keywords_list = [] + neg_keywords_list = [] + if pos_result_json.__contains__('keywords'): + pos_keywords_list = pos_result_json['keywords'] + if neg_result_json.__contains__('keywords'): + neg_keywords_list = neg_result_json['keywords'] + + keywords_list = list(set(pos_keywords_list + neg_keywords_list)) + + pos_result_json['wav_count'] = inputs['pos_wav_count'] + neg_result_json['wav_count'] = inputs['neg_wav_count'] + + if len(keywords_list) > 0: + rst_dict['keywords'] = keywords_list + + for index in range(len(rst_dict['keywords'])): + cur_keyword = rst_dict['keywords'][index] + output_list = self._generate_roc_list( + start=threshold_start, + step=threshold_step, + end=threshold_end, + keyword=cur_keyword, + pos_inputs=pos_result_json, + neg_inputs=neg_result_json) + + rst_dict[cur_keyword] = output_list + + return rst_dict + + def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + if inputs['kws_set'] == 'roc': + inputs['keyword_grammar_path'] = os.path.join( + inputs['model_workspace'], 'keywords_roc.json') + + if inputs['kws_set'] == 'wav': + dump_log_path: str = os.path.join(inputs['pos_dump_path'], + 'dump.log') + kws_cmd: str = inputs['kws_tool_path'] + \ + ' --sys-dir=' + inputs['model_workspace'] + \ + ' --cfg-file=' + inputs['cfg_file_path'] + \ + ' --sample-rate=' + inputs['sample_rate'] + \ + ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ + ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ + ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + os.system(kws_cmd) + + if inputs['kws_set'] in ['pos_testsets', 'roc']: + data_dir: str = os.listdir(inputs['pos_data_path']) + wav_list = [] + for i in data_dir: + suffix = os.path.splitext(os.path.basename(i))[1] + if suffix == '.list': + wav_list.append(os.path.join(inputs['pos_data_path'], i)) + + j: int = 0 + process = [] + while j < inputs['pos_num_thread']: + wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str( + j) + '.list' + dump_log_path: str = inputs['pos_dump_path'] + '/dump.' + str( + j) + '.log' + + kws_cmd: str = inputs['kws_tool_path'] + \ + ' --sys-dir=' + inputs['model_workspace'] + \ + ' --cfg-file=' + inputs['cfg_file_path'] + \ + ' --sample-rate=' + inputs['sample_rate'] + \ + ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ + ' --wave-scp=' + wav_list_path + \ + ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + p = subprocess.Popen(kws_cmd, shell=True) + process.append(p) + j += 1 + + k: int = 0 + while k < len(process): + process[k].wait() + k += 1 + + if inputs['kws_set'] in ['neg_testsets', 'roc']: + data_dir: str = os.listdir(inputs['neg_data_path']) + wav_list = [] + for i in data_dir: + suffix = os.path.splitext(os.path.basename(i))[1] + if suffix == '.list': + wav_list.append(os.path.join(inputs['neg_data_path'], i)) + + j: int = 0 + process = [] + while j < inputs['neg_num_thread']: + wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str( + j) + '.list' + dump_log_path: str = inputs['neg_dump_path'] + '/dump.' + str( + j) + '.log' + + kws_cmd: str = inputs['kws_tool_path'] + \ + ' --sys-dir=' + inputs['model_workspace'] + \ + ' --cfg-file=' + inputs['cfg_file_path'] + \ + ' --sample-rate=' + inputs['sample_rate'] + \ + ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ + ' --wave-scp=' + wav_list_path + \ + ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + p = subprocess.Popen(kws_cmd, shell=True) + process.append(p) + j += 1 + + k: int = 0 + while k < len(process): + process[k].wait() + k += 1 + + return inputs + + def _parse_dump_log(self, result_json: Dict[str, Any], + dump_path: str) -> Dict[str, Any]: + dump_dir = os.listdir(dump_path) + for i in dump_dir: + basename = os.path.splitext(os.path.basename(i))[0] + # find dump.JOB.log + if 'dump' in basename: + with open( + os.path.join(dump_path, i), mode='r', + encoding='utf-8') as file: + while 1: + line = file.readline() + if not line: + break + else: + result_json = self._parse_result_log( + line, result_json) + + def _parse_result_log(self, line: str, + result_json: Dict[str, Any]) -> Dict[str, Any]: + # valid info + if '[rejected]' in line or '[detected]' in line: + detected_count = 0 + rejected_count = 0 + + if result_json.__contains__('detected_count'): + detected_count = result_json['detected_count'] + if result_json.__contains__('rejected_count'): + rejected_count = result_json['rejected_count'] + + if '[detected]' in line: + # [detected], fname:/xxx/.tmp_pos_testsets/pos_testsets/33.wav, + # kw:小云小云, confidence:0.965155, time:[4.62-5.10], threshold:0.00, + detected_count += 1 + content_list = line.split(', ') + file_name = os.path.basename(content_list[1].split(':')[1]) + keyword = content_list[2].split(':')[1] + confidence = content_list[3].split(':')[1] + + keywords_list = [] + if result_json.__contains__('keywords'): + keywords_list = result_json['keywords'] + + if keyword not in keywords_list: + keywords_list.append(keyword) + result_json['keywords'] = keywords_list + + keyword_item = {} + keyword_item['confidence'] = confidence + keyword_item['keyword'] = keyword + item = {} + item[file_name] = keyword_item + + detected_list = [] + if result_json.__contains__('detected'): + detected_list = result_json['detected'] + + detected_list.append(item) + result_json['detected'] = detected_list + + elif '[rejected]' in line: + # [rejected], fname:/xxx/.tmp_pos_testsets/pos_testsets/28.wav + rejected_count += 1 + content_list = line.split(', ') + file_name = os.path.basename(content_list[1].split(':')[1]) + file_name = file_name.strip().replace('\n', + '').replace('\r', '') + + rejected_list = [] + if result_json.__contains__('rejected'): + rejected_list = result_json['rejected'] + + rejected_list.append(file_name) + result_json['rejected'] = rejected_list + + result_json['detected_count'] = detected_count + result_json['rejected_count'] = rejected_count + + elif 'total_proc_time=' in line and 'wav_time=' in line: + # eg: total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=0.013799 + wav_total_time = 0 + content_list = line.split('), ') + if result_json.__contains__('wav_time'): + wav_total_time = result_json['wav_time'] + + wav_time_str = content_list[1].split('=')[1] + wav_time_str = wav_time_str.split('(')[0] + wav_time = float(wav_time_str) + wav_time = round(wav_time, 6) + + if isinstance(wav_time, float): + wav_total_time += wav_time + + result_json['wav_time'] = wav_total_time + + return result_json + + def _generate_roc_list(self, start: float, step: float, end: float, + keyword: str, pos_inputs: Dict[str, Any], + neg_inputs: Dict[str, Any]) -> Dict[str, Any]: + pos_wav_count = pos_inputs['wav_count'] + neg_wav_time = neg_inputs['wav_time'] + det_lists = pos_inputs['detected'] + fa_lists = neg_inputs['detected'] + threshold_cur = start + """ + input det_lists dict + [ + { + "xxx.wav": { + "confidence": "0.990368", + "keyword": "小云小云" + } + }, + { + "yyy.wav": { + "confidence": "0.990368", + "keyword": "小云小云" + } + }, + ] + + output dict + [ + { + "threshold": 0.000, + "recall": 0.999888, + "fa_per_hour": 1.999999 + }, + { + "threshold": 0.001, + "recall": 0.999888, + "fa_per_hour": 1.999999 + }, + ] + """ + + output = [] + while threshold_cur <= end: + det_count = 0 + fa_count = 0 + for index in range(len(det_lists)): + det_item = det_lists[index] + det_wav_item = det_item.get(next(iter(det_item))) + if det_wav_item['keyword'] == keyword: + confidence = float(det_wav_item['confidence']) + if confidence >= threshold_cur: + det_count += 1 + + for index in range(len(fa_lists)): + fa_item = fa_lists[index] + fa_wav_item = fa_item.get(next(iter(fa_item))) + if fa_wav_item['keyword'] == keyword: + confidence = float(fa_wav_item['confidence']) + if confidence >= threshold_cur: + fa_count += 1 + + output_item = { + 'threshold': round(threshold_cur, 3), + 'recall': round(float(det_count / pos_wav_count), 6), + 'fa_per_hour': round(fa_count / float(neg_wav_time / 3600), 6) + } + output.append(output_item) + + threshold_cur += step + + return output diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 942d17c3..1bc06ce3 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -5,6 +5,7 @@ from .base import Preprocessor from .builder import PREPROCESSORS, build_preprocessor from .common import Compose from .image import LoadImage, load_image +from .kws import WavToLists from .multi_modal import OfaImageCaptionPreprocessor from .nlp import * # noqa F403 from .text_to_speech import * # noqa F403 diff --git a/modelscope/preprocessors/kws.py b/modelscope/preprocessors/kws.py new file mode 100644 index 00000000..d69e8283 --- /dev/null +++ b/modelscope/preprocessors/kws.py @@ -0,0 +1,253 @@ +import os +import shutil +import stat +from pathlib import Path +from typing import Any, Dict, List + +import yaml + +from modelscope.metainfo import Preprocessors +from modelscope.models.base import Model +from modelscope.utils.constant import Fields +from .base import Preprocessor +from .builder import PREPROCESSORS + +__all__ = ['WavToLists'] + + +@PREPROCESSORS.register_module( + Fields.audio, module_name=Preprocessors.wav_to_lists) +class WavToLists(Preprocessor): + """generate audio lists file from wav + + Args: + workspace (str): store temporarily kws intermedium and result + """ + + def __init__(self, workspace: str = None): + # the workspace path + if len(workspace) == 0: + self._workspace = os.path.join(os.getcwd(), '.tmp') + else: + self._workspace = workspace + + if not os.path.exists(self._workspace): + os.mkdir(self._workspace) + + def __call__(self, + model: Model = None, + kws_type: str = None, + wav_path: List[str] = None) -> Dict[str, Any]: + """Call functions to load model and wav. + + Args: + model (Model): model should be provided + kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc + wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path + Returns: + Dict[str, Any]: the kws result + """ + + assert model is not None, 'preprocess kws model should be provided' + assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' + ], f'preprocess kws_type {kws_type} is invalid' + assert wav_path[0] is not None or wav_path[ + 1] is not None, 'preprocess wav_path is invalid' + + self._model = model + out = self.forward(self._model.forward(), kws_type, wav_path) + return out + + def forward(self, model: Dict[str, Any], kws_type: str, + wav_path: List[str]) -> Dict[str, Any]: + assert len(kws_type) > 0, 'preprocess kws_type is empty' + assert len( + model['config_path']) > 0, 'preprocess model[config_path] is empty' + assert os.path.exists( + model['config_path']), 'model config.yaml is absent' + + inputs = model.copy() + + inputs['kws_set'] = kws_type + inputs['workspace'] = self._workspace + if wav_path[0] is not None: + inputs['pos_wav_path'] = wav_path[0] + if wav_path[1] is not None: + inputs['neg_wav_path'] = wav_path[1] + + out = self._read_config(inputs) + out = self._generate_wav_lists(out) + + return out + + def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """read and parse config.yaml to get all model files + """ + + assert os.path.exists( + inputs['config_path']), 'model config yaml file does not exist' + + config_file = open(inputs['config_path']) + root = yaml.full_load(config_file) + config_file.close() + + inputs['cfg_file'] = root['cfg_file'] + inputs['cfg_file_path'] = os.path.join(inputs['model_workspace'], + root['cfg_file']) + inputs['keyword_grammar'] = root['keyword_grammar'] + inputs['keyword_grammar_path'] = os.path.join( + inputs['model_workspace'], root['keyword_grammar']) + inputs['sample_rate'] = str(root['sample_rate']) + inputs['kws_tool'] = root['kws_tool'] + + if os.path.exists( + os.path.join(inputs['workspace'], inputs['kws_tool'])): + inputs['kws_tool_path'] = os.path.join(inputs['workspace'], + inputs['kws_tool']) + elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])): + inputs['kws_tool_path'] = os.path.join('/usr/bin', + inputs['kws_tool']) + elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])): + inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool']) + + assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp' + os.chmod(inputs['kws_tool_path'], + stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH) + + self._config_checking(inputs) + return inputs + + def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """assemble wav lists + """ + + if inputs['kws_set'] == 'wav': + inputs['pos_num_thread'] = 1 + wave_scp_content: str = inputs['pos_wav_path'] + '\n' + + with open(os.path.join(inputs['pos_data_path'], 'wave.list'), + 'a') as f: + f.write(wave_scp_content) + + inputs['pos_wav_count'] = 1 + + if inputs['kws_set'] in ['pos_testsets', 'roc']: + # find all positive wave + wav_list = [] + wav_dir = inputs['pos_wav_path'] + wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) + + list_count: int = len(wav_list) + inputs['pos_wav_count'] = list_count + + if list_count <= 128: + inputs['pos_num_thread'] = list_count + j: int = 0 + while j < list_count: + wave_scp_content: str = wav_list[j] + '\n' + wav_list_path = inputs['pos_data_path'] + '/wave.' + str( + j) + '.list' + with open(wav_list_path, 'a') as f: + f.write(wave_scp_content) + j += 1 + + else: + inputs['pos_num_thread'] = 128 + j: int = 0 + k: int = 0 + while j < list_count: + wave_scp_content: str = wav_list[j] + '\n' + wav_list_path = inputs['pos_data_path'] + '/wave.' + str( + k) + '.list' + with open(wav_list_path, 'a') as f: + f.write(wave_scp_content) + j += 1 + k += 1 + if k >= 128: + k = 0 + + if inputs['kws_set'] in ['neg_testsets', 'roc']: + # find all negative wave + wav_list = [] + wav_dir = inputs['neg_wav_path'] + wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) + + list_count: int = len(wav_list) + inputs['neg_wav_count'] = list_count + + if list_count <= 128: + inputs['neg_num_thread'] = list_count + j: int = 0 + while j < list_count: + wave_scp_content: str = wav_list[j] + '\n' + wav_list_path = inputs['neg_data_path'] + '/wave.' + str( + j) + '.list' + with open(wav_list_path, 'a') as f: + f.write(wave_scp_content) + j += 1 + + else: + inputs['neg_num_thread'] = 128 + j: int = 0 + k: int = 0 + while j < list_count: + wave_scp_content: str = wav_list[j] + '\n' + wav_list_path = inputs['neg_data_path'] + '/wave.' + str( + k) + '.list' + with open(wav_list_path, 'a') as f: + f.write(wave_scp_content) + j += 1 + k += 1 + if k >= 128: + k = 0 + + return inputs + + def _recursion_dir_all_wave(self, wav_list, + dir_path: str) -> Dict[str, Any]: + dir_files = os.listdir(dir_path) + for file in dir_files: + file_path = os.path.join(dir_path, file) + if os.path.isfile(file_path): + if file_path.endswith('.wav') or file_path.endswith('.WAV'): + wav_list.append(file_path) + elif os.path.isdir(file_path): + self._recursion_dir_all_wave(wav_list, file_path) + + return wav_list + + def _config_checking(self, inputs: Dict[str, Any]): + + if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: + inputs['pos_data_path'] = os.path.join(inputs['workspace'], + 'pos_data') + if not os.path.exists(inputs['pos_data_path']): + os.mkdir(inputs['pos_data_path']) + else: + shutil.rmtree(inputs['pos_data_path']) + os.mkdir(inputs['pos_data_path']) + + inputs['pos_dump_path'] = os.path.join(inputs['workspace'], + 'pos_dump') + if not os.path.exists(inputs['pos_dump_path']): + os.mkdir(inputs['pos_dump_path']) + else: + shutil.rmtree(inputs['pos_dump_path']) + os.mkdir(inputs['pos_dump_path']) + + if inputs['kws_set'] in ['neg_testsets', 'roc']: + inputs['neg_data_path'] = os.path.join(inputs['workspace'], + 'neg_data') + if not os.path.exists(inputs['neg_data_path']): + os.mkdir(inputs['neg_data_path']) + else: + shutil.rmtree(inputs['neg_data_path']) + os.mkdir(inputs['neg_data_path']) + + inputs['neg_dump_path'] = os.path.join(inputs['workspace'], + 'neg_dump') + if not os.path.exists(inputs['neg_dump_path']): + os.mkdir(inputs['neg_dump_path']) + else: + shutil.rmtree(inputs['neg_dump_path']) + os.mkdir(inputs['neg_dump_path']) diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 2045efb6..f2215359 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -52,6 +52,7 @@ class Tasks(object): auto_speech_recognition = 'auto-speech-recognition' text_to_speech = 'text-to-speech' speech_signal_process = 'speech-signal-process' + key_word_spotting = 'key-word-spotting' # multi-modal tasks image_captioning = 'image-captioning' diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py new file mode 100644 index 00000000..e82a4211 --- /dev/null +++ b/tests/pipelines/test_key_word_spotting.py @@ -0,0 +1,334 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tarfile +import unittest + +import requests + +from modelscope.metainfo import Pipelines, Preprocessors +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.preprocessors import build_preprocessor +from modelscope.utils.constant import Fields, InputFields, Tasks +from modelscope.utils.test_utils import test_level + +KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' + +POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav' +POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE + +POS_TESTSETS_FILE = 'pos_testsets.tar.gz' +POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' + +NEG_TESTSETS_FILE = 'neg_testsets.tar.gz' +NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz' + + +def un_tar_gz(fname, dirs): + t = tarfile.open(fname) + t.extractall(path=dirs) + + +class KeyWordSpottingTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' + self.workspace = os.path.join(os.getcwd(), '.tmp') + if not os.path.exists(self.workspace): + os.mkdir(self.workspace) + + def tearDown(self) -> None: + if os.path.exists(self.workspace): + shutil.rmtree(self.workspace) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'wav' + + # downloading wav file + wav_file_path = os.path.join(self.workspace, POS_WAV_FILE) + if not os.path.exists(wav_file_path): + r = requests.get(POS_WAV_URL) + with open(wav_file_path, 'wb') as f: + f.write(r.content) + + # downloading kwsbp + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + kwsbp_16k_pipline = pipeline( + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[wav_file_path, None]) + self.assertTrue(kws_result.__contains__('detected')) + """ + kws result json format example: + { + 'wav_count': 1, + 'kws_set': 'wav', + 'wav_time': 9.132938, + 'keywords': ['小云小云'], + 'detected': True, + 'confidence': 0.990368 + } + """ + if kws_result.__contains__('keywords'): + print('test_run_with_wav keywords: ', kws_result['keywords']) + print('test_run_with_wav detected result: ', kws_result['detected']) + print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_pos_testsets(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'pos_testsets' + + # downloading pos_testsets file + testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(POS_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + testsets_dir_name = os.path.splitext( + os.path.basename(POS_TESTSETS_FILE))[0] + testsets_dir_name = os.path.splitext( + os.path.basename(testsets_dir_name))[0] + # wav_file_path = /.tmp_pos_testsets/pos_testsets/ + wav_file_path = os.path.join(self.workspace, testsets_dir_name) + + # untar the pos_testsets file + if not os.path.exists(wav_file_path): + un_tar_gz(testsets_file_path, self.workspace) + + # downloading kwsbp -- a kws batch processing tool + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + kwsbp_16k_pipline = pipeline( + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[wav_file_path, None]) + self.assertTrue(kws_result.__contains__('recall')) + """ + kws result json format example: + { + 'wav_count': 450, + 'kws_set': 'pos_testsets', + 'wav_time': 3013.759254, + 'keywords': ["小云小云"], + 'recall': 0.953333, + 'detected_count': 429, + 'rejected_count': 21, + 'rejected': [ + 'yyy.wav', + 'zzz.wav', + ...... + ] + } + """ + if kws_result.__contains__('keywords'): + print('test_run_with_pos_testsets keywords: ', + kws_result['keywords']) + print('test_run_with_pos_testsets recall: ', kws_result['recall']) + print('test_run_with_pos_testsets wave time(seconds): ', + kws_result['wav_time']) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_neg_testsets(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'neg_testsets' + + # downloading neg_testsets file + testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(NEG_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + testsets_dir_name = os.path.splitext( + os.path.basename(NEG_TESTSETS_FILE))[0] + testsets_dir_name = os.path.splitext( + os.path.basename(testsets_dir_name))[0] + # wav_file_path = /.tmp_neg_testsets/neg_testsets/ + wav_file_path = os.path.join(self.workspace, testsets_dir_name) + + # untar the neg_testsets file + if not os.path.exists(wav_file_path): + un_tar_gz(testsets_file_path, self.workspace) + + # downloading kwsbp -- a kws batch processing tool + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + kwsbp_16k_pipline = pipeline( + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[None, wav_file_path]) + self.assertTrue(kws_result.__contains__('fa_rate')) + """ + kws result json format example: + { + 'wav_count': 751, + 'kws_set': 'neg_testsets', + 'wav_time': 3572.180812, + 'keywords': ['小云小云'], + 'fa_rate': 0.001332, + 'fa_per_hour': 1.007788, + 'detected_count': 1, + 'rejected_count': 750, + 'detected': [ + { + '6.wav': { + 'confidence': '0.321170' + } + } + ] + } + """ + if kws_result.__contains__('keywords'): + print('test_run_with_neg_testsets keywords: ', + kws_result['keywords']) + print('test_run_with_neg_testsets fa rate: ', kws_result['fa_rate']) + print('test_run_with_neg_testsets fa per hour: ', + kws_result['fa_per_hour']) + print('test_run_with_neg_testsets wave time(seconds): ', + kws_result['wav_time']) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_roc(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'roc' + + # downloading neg_testsets file + testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(NEG_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + testsets_dir_name = os.path.splitext( + os.path.basename(NEG_TESTSETS_FILE))[0] + testsets_dir_name = os.path.splitext( + os.path.basename(testsets_dir_name))[0] + # neg_file_path = /.tmp_roc/neg_testsets/ + neg_file_path = os.path.join(self.workspace, testsets_dir_name) + + # untar the neg_testsets file + if not os.path.exists(neg_file_path): + un_tar_gz(testsets_file_path, self.workspace) + + # downloading pos_testsets file + testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(POS_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + testsets_dir_name = os.path.splitext( + os.path.basename(POS_TESTSETS_FILE))[0] + testsets_dir_name = os.path.splitext( + os.path.basename(testsets_dir_name))[0] + # pos_file_path = /.tmp_roc/pos_testsets/ + pos_file_path = os.path.join(self.workspace, testsets_dir_name) + + # untar the pos_testsets file + if not os.path.exists(pos_file_path): + un_tar_gz(testsets_file_path, self.workspace) + + # downloading kwsbp -- a kws batch processing tool + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + kwsbp_16k_pipline = pipeline( + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[pos_file_path, neg_file_path]) + """ + kws result json format example: + { + 'kws_set': 'roc', + 'keywords': ['小云小云'], + '小云小云': [ + {'threshold': 0.0, 'recall': 0.953333, 'fa_per_hour': 1.007788}, + {'threshold': 0.001, 'recall': 0.953333, 'fa_per_hour': 1.007788}, + ...... + {'threshold': 0.999, 'recall': 0.004444, 'fa_per_hour': 0.0} + ] + } + """ + if kws_result.__contains__('keywords'): + find_keyword = kws_result['keywords'][0] + print('test_run_with_roc keywords: ', find_keyword) + keyword_list = kws_result[find_keyword] + for item in iter(keyword_list): + threshold: float = item['threshold'] + recall: float = item['recall'] + fa_per_hour: float = item['fa_per_hour'] + print(' threshold:', threshold, ' recall:', recall, + ' fa_per_hour:', fa_per_hour) + + +if __name__ == '__main__': + unittest.main()