diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 45184ad7..bdefdd48 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -3,7 +3,7 @@ import os import shutil import stat import subprocess -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import json @@ -25,19 +25,21 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): def __init__(self, config_file: str = None, - model: Model = None, + model: Union[Model, str] = None, preprocessor: WavToLists = None, **kwargs): """use `model` and `preprocessor` to create a kws pipeline for prediction """ + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + super().__init__( config_file=config_file, model=model, preprocessor=preprocessor, **kwargs) assert model is not None, 'kws model should be provided' - assert preprocessor is not None, 'preprocessor is none' self._preprocessor = preprocessor self._model = model @@ -45,12 +47,17 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): if 'keywords' in kwargs.keys(): self._keywords = kwargs['keywords'] - print('self._keywords len: ', len(self._keywords)) - print('self._keywords: ', self._keywords) - def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: + def __call__(self, + kws_type: str, + wav_path: List[str], + workspace: str = None) -> Dict[str, Any]: assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'], f'kws_type {kws_type} is invalid' + + if self._preprocessor is None: + self._preprocessor = WavToLists(workspace=workspace) + output = self._preprocessor.forward(self._model.forward(), kws_type, wav_path) output = self.forward(output) diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index b01e3f21..91acaa66 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -40,7 +40,7 @@ class KeyWordSpottingTest(unittest.TestCase): def tearDown(self) -> None: if os.path.exists(self.workspace): - shutil.rmtree(self.workspace) + shutil.rmtree(os.path.join(self.workspace), ignore_errors=True) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_wav(self): @@ -57,23 +57,14 @@ class KeyWordSpottingTest(unittest.TestCase): with open(kwsbp_file_path, 'wb') as f: f.write(r.content) - model = Model.from_pretrained(self.model_id) - self.assertTrue(model is not None) - - cfg_preprocessor = dict( - type=Preprocessors.wav_to_lists, workspace=self.workspace) - preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) - self.assertTrue(preprocessor is not None) - kwsbp_16k_pipline = pipeline( - task=Tasks.key_word_spotting, - pipeline_name=Pipelines.kws_kwsbp, - model=model, - preprocessor=preprocessor) + task=Tasks.key_word_spotting, model=self.model_id) self.assertTrue(kwsbp_16k_pipline is not None) kws_result = kwsbp_16k_pipline( - kws_type=kws_set, wav_path=[wav_file_path, None]) + kws_type=kws_set, + wav_path=[wav_file_path, None], + workspace=self.workspace) self.assertTrue(kws_result.__contains__('detected')) """ kws result json format example: @@ -107,14 +98,6 @@ class KeyWordSpottingTest(unittest.TestCase): with open(kwsbp_file_path, 'wb') as f: f.write(r.content) - model = Model.from_pretrained(self.model_id) - self.assertTrue(model is not None) - - cfg_preprocessor = dict( - type=Preprocessors.wav_to_lists, workspace=self.workspace) - preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) - self.assertTrue(preprocessor is not None) - # customized keyword if you need. # full settings eg. # keywords = [ @@ -125,14 +108,14 @@ class KeyWordSpottingTest(unittest.TestCase): kwsbp_16k_pipline = pipeline( task=Tasks.key_word_spotting, - pipeline_name=Pipelines.kws_kwsbp, - model=model, - preprocessor=preprocessor, + model=self.model_id, keywords=keywords) self.assertTrue(kwsbp_16k_pipline is not None) kws_result = kwsbp_16k_pipline( - kws_type=kws_set, wav_path=[wav_file_path, None]) + kws_type=kws_set, + wav_path=[wav_file_path, None], + workspace=self.workspace) self.assertTrue(kws_result.__contains__('detected')) """ kws result json format example: @@ -185,23 +168,14 @@ class KeyWordSpottingTest(unittest.TestCase): with open(kwsbp_file_path, 'wb') as f: f.write(r.content) - model = Model.from_pretrained(self.model_id) - self.assertTrue(model is not None) - - cfg_preprocessor = dict( - type=Preprocessors.wav_to_lists, workspace=self.workspace) - preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) - self.assertTrue(preprocessor is not None) - kwsbp_16k_pipline = pipeline( - task=Tasks.key_word_spotting, - pipeline_name=Pipelines.kws_kwsbp, - model=model, - preprocessor=preprocessor) + task=Tasks.key_word_spotting, model=self.model_id) self.assertTrue(kwsbp_16k_pipline is not None) kws_result = kwsbp_16k_pipline( - kws_type=kws_set, wav_path=[wav_file_path, None]) + kws_type=kws_set, + wav_path=[wav_file_path, None], + workspace=self.workspace) self.assertTrue(kws_result.__contains__('recall')) """ kws result json format example: @@ -257,23 +231,14 @@ class KeyWordSpottingTest(unittest.TestCase): with open(kwsbp_file_path, 'wb') as f: f.write(r.content) - model = Model.from_pretrained(self.model_id) - self.assertTrue(model is not None) - - cfg_preprocessor = dict( - type=Preprocessors.wav_to_lists, workspace=self.workspace) - preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) - self.assertTrue(preprocessor is not None) - kwsbp_16k_pipline = pipeline( - task=Tasks.key_word_spotting, - pipeline_name=Pipelines.kws_kwsbp, - model=model, - preprocessor=preprocessor) + task=Tasks.key_word_spotting, model=self.model_id) self.assertTrue(kwsbp_16k_pipline is not None) kws_result = kwsbp_16k_pipline( - kws_type=kws_set, wav_path=[None, wav_file_path]) + kws_type=kws_set, + wav_path=[None, wav_file_path], + workspace=self.workspace) self.assertTrue(kws_result.__contains__('fa_rate')) """ kws result json format example: @@ -352,23 +317,14 @@ class KeyWordSpottingTest(unittest.TestCase): with open(kwsbp_file_path, 'wb') as f: f.write(r.content) - model = Model.from_pretrained(self.model_id) - self.assertTrue(model is not None) - - cfg_preprocessor = dict( - type=Preprocessors.wav_to_lists, workspace=self.workspace) - preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) - self.assertTrue(preprocessor is not None) - kwsbp_16k_pipline = pipeline( - task=Tasks.key_word_spotting, - pipeline_name=Pipelines.kws_kwsbp, - model=model, - preprocessor=preprocessor) + task=Tasks.key_word_spotting, model=self.model_id) self.assertTrue(kwsbp_16k_pipline is not None) kws_result = kwsbp_16k_pipline( - kws_type=kws_set, wav_path=[pos_file_path, neg_file_path]) + kws_type=kws_set, + wav_path=[pos_file_path, neg_file_path], + workspace=self.workspace) """ kws result json format example: {