Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10412829master
| @@ -1,3 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | import os | ||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| @@ -1,3 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | import os | ||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| @@ -37,6 +37,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| **kwargs) -> Dict[str, Any]: | **kwargs) -> Dict[str, Any]: | ||||
| if 'keywords' in kwargs.keys(): | if 'keywords' in kwargs.keys(): | ||||
| self.keywords = kwargs['keywords'] | self.keywords = kwargs['keywords'] | ||||
| if isinstance(self.keywords, str): | |||||
| word_list = [] | |||||
| word = {} | |||||
| word['keyword'] = self.keywords | |||||
| word_list.append(word) | |||||
| self.keywords = word_list | |||||
| else: | else: | ||||
| self.keywords = None | self.keywords = None | ||||
| @@ -96,6 +102,9 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| pos_list=pos_kws_list, | pos_list=pos_kws_list, | ||||
| neg_list=neg_kws_list) | neg_list=neg_kws_list) | ||||
| if 'kws_list' not in rst_dict: | |||||
| rst_dict['kws_list'] = [] | |||||
| return rst_dict | return rst_dict | ||||
| def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| @@ -1,3 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | import os | ||||
| from typing import Any, Dict, List, Union | from typing import Any, Dict, List, Union | ||||
| @@ -1,3 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | import os | ||||
| from typing import Any, Dict, List, Union | from typing import Any, Dict, List, Union | ||||
| @@ -245,7 +245,7 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_wav_by_customized_keywords(self): | def test_run_with_wav_by_customized_keywords(self): | ||||
| keywords = [{'keyword': '播放音乐'}] | |||||
| keywords = '播放音乐' | |||||
| kws_result = self.run_pipeline( | kws_result = self.run_pipeline( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||