Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9667657 master
| @@ -9,8 +9,7 @@ from modelscope.utils.constant import Tasks | |||||
| __all__ = ['GenericKeyWordSpotting'] | __all__ = ['GenericKeyWordSpotting'] | ||||
| @MODELS.register_module( | |||||
| Tasks.auto_speech_recognition, module_name=Models.kws_kwsbp) | |||||
| @MODELS.register_module(Tasks.keyword_spotting, module_name=Models.kws_kwsbp) | |||||
| class GenericKeyWordSpotting(Model): | class GenericKeyWordSpotting(Model): | ||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| @@ -17,7 +17,7 @@ __all__ = ['KeyWordSpottingKwsbpPipeline'] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.auto_speech_recognition, module_name=Pipelines.kws_kwsbp) | |||||
| Tasks.keyword_spotting, module_name=Pipelines.kws_kwsbp) | |||||
| class KeyWordSpottingKwsbpPipeline(Pipeline): | class KeyWordSpottingKwsbpPipeline(Pipeline): | ||||
| """KWS Pipeline - key word spotting decoding | """KWS Pipeline - key word spotting decoding | ||||
| """ | """ | ||||
| @@ -98,6 +98,7 @@ class AudioTasks(object): | |||||
| speech_signal_process = 'speech-signal-process' | speech_signal_process = 'speech-signal-process' | ||||
| acoustic_echo_cancellation = 'acoustic-echo-cancellation' | acoustic_echo_cancellation = 'acoustic-echo-cancellation' | ||||
| acoustic_noise_suppression = 'acoustic-noise-suppression' | acoustic_noise_suppression = 'acoustic-noise-suppression' | ||||
| keyword_spotting = 'keyword-spotting' | |||||
| class MultiModalTasks(object): | class MultiModalTasks(object): | ||||
| @@ -139,7 +139,7 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
| } | } | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' | |||||
| self.model_id = 'damo/speech_charctc_kws_phone-xiaoyun' | |||||
| self.workspace = os.path.join(os.getcwd(), '.tmp') | self.workspace = os.path.join(os.getcwd(), '.tmp') | ||||
| if not os.path.exists(self.workspace): | if not os.path.exists(self.workspace): | ||||
| os.mkdir(self.workspace) | os.mkdir(self.workspace) | ||||
| @@ -153,7 +153,7 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
| audio_in: Union[List[str], str, bytes], | audio_in: Union[List[str], str, bytes], | ||||
| keywords: List[str] = None) -> Dict[str, Any]: | keywords: List[str] = None) -> Dict[str, Any]: | ||||
| kwsbp_16k_pipline = pipeline( | kwsbp_16k_pipline = pipeline( | ||||
| task=Tasks.auto_speech_recognition, model=model_id) | |||||
| task=Tasks.keyword_spotting, model=model_id) | |||||
| kws_result = kwsbp_16k_pipline(audio_in=audio_in, keywords=keywords) | kws_result = kwsbp_16k_pipline(audio_in=audio_in, keywords=keywords) | ||||