diff --git a/modelscope/trainers/nlp/sequence_classification_trainer.py b/modelscope/trainers/nlp/sequence_classification_trainer.py index b2b759fa..7ae5576f 100644 --- a/modelscope/trainers/nlp/sequence_classification_trainer.py +++ b/modelscope/trainers/nlp/sequence_classification_trainer.py @@ -14,8 +14,7 @@ PATH = None logger = get_logger(PATH) -@TRAINERS.register_module( - Tasks.text_classification, module_name=r'bert-sentiment-analysis') +@TRAINERS.register_module(module_name=r'bert-sentiment-analysis') class SequenceClassificationTrainer(BaseTrainer): def __init__(self, cfg_file: str, *args, **kwargs): diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py index 9b37252b..2e1f8672 100644 --- a/modelscope/utils/registry.py +++ b/modelscope/utils/registry.py @@ -77,19 +77,6 @@ class Registry(object): self._modules[group_key][module_name] = module_cls module_cls.group_key = group_key - if module_name in self._modules[default_group]: - if id(self._modules[default_group][module_name]) == id(module_cls): - return - else: - logger.warning(f'{module_name} is already registered in ' - f'{self._name}[{default_group}] and will ' - 'be overwritten') - logger.warning(f'{self._modules[default_group][module_name]}' - f'to {module_cls}') - # also register module in the default group for faster access - # only by module name - self._modules[default_group][module_name] = module_cls - def register_module(self, group_key: str = default_group, module_name: str = None, diff --git a/tests/pipelines/test_base.py b/tests/pipelines/test_base.py index c642ed4b..446c6082 100644 --- a/tests/pipelines/test_base.py +++ b/tests/pipelines/test_base.py @@ -74,9 +74,9 @@ class CustomPipelineTest(unittest.TestCase): def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs - self.assertTrue('custom-image' in PIPELINES.modules[default_group]) + self.assertTrue('custom-image' in PIPELINES.modules[dummy_task]) add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True) - pipe = pipeline(pipeline_name='custom-image') + pipe = pipeline(task=dummy_task, pipeline_name='custom-image') pipe2 = pipeline(dummy_task) self.assertTrue(type(pipe) is type(pipe2)) diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index d0f62461..b01e3f21 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -66,6 +66,7 @@ class KeyWordSpottingTest(unittest.TestCase): self.assertTrue(preprocessor is not None) kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, pipeline_name=Pipelines.kws_kwsbp, model=model, preprocessor=preprocessor) @@ -123,6 +124,7 @@ class KeyWordSpottingTest(unittest.TestCase): keywords = [{'keyword': '播放音乐'}] kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, pipeline_name=Pipelines.kws_kwsbp, model=model, preprocessor=preprocessor, @@ -192,6 +194,7 @@ class KeyWordSpottingTest(unittest.TestCase): self.assertTrue(preprocessor is not None) kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, pipeline_name=Pipelines.kws_kwsbp, model=model, preprocessor=preprocessor) @@ -263,6 +266,7 @@ class KeyWordSpottingTest(unittest.TestCase): self.assertTrue(preprocessor is not None) kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, pipeline_name=Pipelines.kws_kwsbp, model=model, preprocessor=preprocessor) @@ -357,6 +361,7 @@ class KeyWordSpottingTest(unittest.TestCase): self.assertTrue(preprocessor is not None) kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, pipeline_name=Pipelines.kws_kwsbp, model=model, preprocessor=preprocessor) diff --git a/tests/pipelines/test_text_to_speech.py b/tests/pipelines/test_text_to_speech.py index e92047d6..c371d80a 100644 --- a/tests/pipelines/test_text_to_speech.py +++ b/tests/pipelines/test_text_to_speech.py @@ -12,7 +12,7 @@ from modelscope.metainfo import Pipelines, Preprocessors from modelscope.models import Model from modelscope.pipelines import pipeline from modelscope.preprocessors import build_preprocessor -from modelscope.utils.constant import Fields +from modelscope.utils.constant import Fields, Tasks from modelscope.utils.logger import get_logger from modelscope.utils.test_utils import test_level @@ -43,6 +43,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): self.assertTrue(voc is not None) sambert_tts = pipeline( + task=Tasks.text_to_speech, pipeline_name=Pipelines.sambert_hifigan_16k_tts, config_file='', model=[am, voc],