wenmeng.zwm 3 years ago
parent
commit
bb57020daa
5 changed files with 10 additions and 18 deletions
  1. +1
    -2
      modelscope/trainers/nlp/sequence_classification_trainer.py
  2. +0
    -13
      modelscope/utils/registry.py
  3. +2
    -2
      tests/pipelines/test_base.py
  4. +5
    -0
      tests/pipelines/test_key_word_spotting.py
  5. +2
    -1
      tests/pipelines/test_text_to_speech.py

+ 1
- 2
modelscope/trainers/nlp/sequence_classification_trainer.py View File

@@ -14,8 +14,7 @@ PATH = None
logger = get_logger(PATH)


@TRAINERS.register_module(
Tasks.text_classification, module_name=r'bert-sentiment-analysis')
@TRAINERS.register_module(module_name=r'bert-sentiment-analysis')
class SequenceClassificationTrainer(BaseTrainer):

def __init__(self, cfg_file: str, *args, **kwargs):


+ 0
- 13
modelscope/utils/registry.py View File

@@ -77,19 +77,6 @@ class Registry(object):
self._modules[group_key][module_name] = module_cls
module_cls.group_key = group_key

if module_name in self._modules[default_group]:
if id(self._modules[default_group][module_name]) == id(module_cls):
return
else:
logger.warning(f'{module_name} is already registered in '
f'{self._name}[{default_group}] and will '
'be overwritten')
logger.warning(f'{self._modules[default_group][module_name]}'
f'to {module_cls}')
# also register module in the default group for faster access
# only by module name
self._modules[default_group][module_name] = module_cls

def register_module(self,
group_key: str = default_group,
module_name: str = None,


+ 2
- 2
tests/pipelines/test_base.py View File

@@ -74,9 +74,9 @@ class CustomPipelineTest(unittest.TestCase):
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

self.assertTrue('custom-image' in PIPELINES.modules[default_group])
self.assertTrue('custom-image' in PIPELINES.modules[dummy_task])
add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True)
pipe = pipeline(pipeline_name='custom-image')
pipe = pipeline(task=dummy_task, pipeline_name='custom-image')
pipe2 = pipeline(dummy_task)
self.assertTrue(type(pipe) is type(pipe2))



+ 5
- 0
tests/pipelines/test_key_word_spotting.py View File

@@ -66,6 +66,7 @@ class KeyWordSpottingTest(unittest.TestCase):
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
@@ -123,6 +124,7 @@ class KeyWordSpottingTest(unittest.TestCase):
keywords = [{'keyword': '播放音乐'}]

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor,
@@ -192,6 +194,7 @@ class KeyWordSpottingTest(unittest.TestCase):
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
@@ -263,6 +266,7 @@ class KeyWordSpottingTest(unittest.TestCase):
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)
@@ -357,6 +361,7 @@ class KeyWordSpottingTest(unittest.TestCase):
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipline = pipeline(
task=Tasks.key_word_spotting,
pipeline_name=Pipelines.kws_kwsbp,
model=model,
preprocessor=preprocessor)


+ 2
- 1
tests/pipelines/test_text_to_speech.py View File

@@ -12,7 +12,7 @@ from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.constant import Fields
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

@@ -43,6 +43,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
self.assertTrue(voc is not None)

sambert_tts = pipeline(
task=Tasks.text_to_speech,
pipeline_name=Pipelines.sambert_hifigan_16k_tts,
config_file='',
model=[am, voc],


Loading…
Cancel
Save