diff --git a/modelscope/models/audio/kws/generic_key_word_spotting.py b/modelscope/models/audio/kws/generic_key_word_spotting.py index 861cca20..c1b7a0e4 100644 --- a/modelscope/models/audio/kws/generic_key_word_spotting.py +++ b/modelscope/models/audio/kws/generic_key_word_spotting.py @@ -9,8 +9,7 @@ from modelscope.utils.constant import Tasks __all__ = ['GenericKeyWordSpotting'] -@MODELS.register_module( - Tasks.auto_speech_recognition, module_name=Models.kws_kwsbp) +@MODELS.register_module(Tasks.keyword_spotting, module_name=Models.kws_kwsbp) class GenericKeyWordSpotting(Model): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 5d51593e..1f31766a 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -17,7 +17,7 @@ __all__ = ['KeyWordSpottingKwsbpPipeline'] @PIPELINES.register_module( - Tasks.auto_speech_recognition, module_name=Pipelines.kws_kwsbp) + Tasks.keyword_spotting, module_name=Pipelines.kws_kwsbp) class KeyWordSpottingKwsbpPipeline(Pipeline): """KWS Pipeline - key word spotting decoding """ diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 4b49efdc..538aa3db 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -98,6 +98,7 @@ class AudioTasks(object): speech_signal_process = 'speech-signal-process' acoustic_echo_cancellation = 'acoustic-echo-cancellation' acoustic_noise_suppression = 'acoustic-noise-suppression' + keyword_spotting = 'keyword-spotting' class MultiModalTasks(object): diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index 17640934..32a853af 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -139,7 +139,7 @@ class KeyWordSpottingTest(unittest.TestCase): } def setUp(self) -> None: - self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' + self.model_id = 'damo/speech_charctc_kws_phone-xiaoyun' self.workspace = os.path.join(os.getcwd(), '.tmp') if not os.path.exists(self.workspace): os.mkdir(self.workspace) @@ -153,7 +153,7 @@ class KeyWordSpottingTest(unittest.TestCase): audio_in: Union[List[str], str, bytes], keywords: List[str] = None) -> Dict[str, Any]: kwsbp_16k_pipline = pipeline( - task=Tasks.auto_speech_recognition, model=model_id) + task=Tasks.keyword_spotting, model=model_id) kws_result = kwsbp_16k_pipline(audio_in=audio_in, keywords=keywords)