 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9229852master
| @@ -14,8 +14,7 @@ PATH = None | |||||
| logger = get_logger(PATH) | logger = get_logger(PATH) | ||||
| @TRAINERS.register_module( | |||||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||||
| @TRAINERS.register_module(module_name=r'bert-sentiment-analysis') | |||||
| class SequenceClassificationTrainer(BaseTrainer): | class SequenceClassificationTrainer(BaseTrainer): | ||||
| def __init__(self, cfg_file: str, *args, **kwargs): | def __init__(self, cfg_file: str, *args, **kwargs): | ||||
| @@ -77,19 +77,6 @@ class Registry(object): | |||||
| self._modules[group_key][module_name] = module_cls | self._modules[group_key][module_name] = module_cls | ||||
| module_cls.group_key = group_key | module_cls.group_key = group_key | ||||
| if module_name in self._modules[default_group]: | |||||
| if id(self._modules[default_group][module_name]) == id(module_cls): | |||||
| return | |||||
| else: | |||||
| logger.warning(f'{module_name} is already registered in ' | |||||
| f'{self._name}[{default_group}] and will ' | |||||
| 'be overwritten') | |||||
| logger.warning(f'{self._modules[default_group][module_name]}' | |||||
| f'to {module_cls}') | |||||
| # also register module in the default group for faster access | |||||
| # only by module name | |||||
| self._modules[default_group][module_name] = module_cls | |||||
| def register_module(self, | def register_module(self, | ||||
| group_key: str = default_group, | group_key: str = default_group, | ||||
| module_name: str = None, | module_name: str = None, | ||||
| @@ -74,9 +74,9 @@ class CustomPipelineTest(unittest.TestCase): | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| return inputs | return inputs | ||||
| self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | |||||
| self.assertTrue('custom-image' in PIPELINES.modules[dummy_task]) | |||||
| add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True) | add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True) | ||||
| pipe = pipeline(pipeline_name='custom-image') | |||||
| pipe = pipeline(task=dummy_task, pipeline_name='custom-image') | |||||
| pipe2 = pipeline(dummy_task) | pipe2 = pipeline(dummy_task) | ||||
| self.assertTrue(type(pipe) is type(pipe2)) | self.assertTrue(type(pipe) is type(pipe2)) | ||||
| @@ -66,6 +66,7 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
| self.assertTrue(preprocessor is not None) | self.assertTrue(preprocessor is not None) | ||||
| kwsbp_16k_pipline = pipeline( | kwsbp_16k_pipline = pipeline( | ||||
| task=Tasks.key_word_spotting, | |||||
| pipeline_name=Pipelines.kws_kwsbp, | pipeline_name=Pipelines.kws_kwsbp, | ||||
| model=model, | model=model, | ||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| @@ -123,6 +124,7 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
| keywords = [{'keyword': '播放音乐'}] | keywords = [{'keyword': '播放音乐'}] | ||||
| kwsbp_16k_pipline = pipeline( | kwsbp_16k_pipline = pipeline( | ||||
| task=Tasks.key_word_spotting, | |||||
| pipeline_name=Pipelines.kws_kwsbp, | pipeline_name=Pipelines.kws_kwsbp, | ||||
| model=model, | model=model, | ||||
| preprocessor=preprocessor, | preprocessor=preprocessor, | ||||
| @@ -192,6 +194,7 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
| self.assertTrue(preprocessor is not None) | self.assertTrue(preprocessor is not None) | ||||
| kwsbp_16k_pipline = pipeline( | kwsbp_16k_pipline = pipeline( | ||||
| task=Tasks.key_word_spotting, | |||||
| pipeline_name=Pipelines.kws_kwsbp, | pipeline_name=Pipelines.kws_kwsbp, | ||||
| model=model, | model=model, | ||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| @@ -263,6 +266,7 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
| self.assertTrue(preprocessor is not None) | self.assertTrue(preprocessor is not None) | ||||
| kwsbp_16k_pipline = pipeline( | kwsbp_16k_pipline = pipeline( | ||||
| task=Tasks.key_word_spotting, | |||||
| pipeline_name=Pipelines.kws_kwsbp, | pipeline_name=Pipelines.kws_kwsbp, | ||||
| model=model, | model=model, | ||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| @@ -357,6 +361,7 @@ class KeyWordSpottingTest(unittest.TestCase): | |||||
| self.assertTrue(preprocessor is not None) | self.assertTrue(preprocessor is not None) | ||||
| kwsbp_16k_pipline = pipeline( | kwsbp_16k_pipline = pipeline( | ||||
| task=Tasks.key_word_spotting, | |||||
| pipeline_name=Pipelines.kws_kwsbp, | pipeline_name=Pipelines.kws_kwsbp, | ||||
| model=model, | model=model, | ||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| @@ -12,7 +12,7 @@ from modelscope.metainfo import Pipelines, Preprocessors | |||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.preprocessors import build_preprocessor | from modelscope.preprocessors import build_preprocessor | ||||
| from modelscope.utils.constant import Fields | |||||
| from modelscope.utils.constant import Fields, Tasks | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| @@ -43,6 +43,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): | |||||
| self.assertTrue(voc is not None) | self.assertTrue(voc is not None) | ||||
| sambert_tts = pipeline( | sambert_tts = pipeline( | ||||
| task=Tasks.text_to_speech, | |||||
| pipeline_name=Pipelines.sambert_hifigan_16k_tts, | pipeline_name=Pipelines.sambert_hifigan_16k_tts, | ||||
| config_file='', | config_file='', | ||||
| model=[am, voc], | model=[am, voc], | ||||