diff --git a/modelscope/models/nlp/text_generation_model.py b/modelscope/models/nlp/text_generation_model.py index ebefc8d1..8feac691 100644 --- a/modelscope/models/nlp/text_generation_model.py +++ b/modelscope/models/nlp/text_generation_model.py @@ -4,11 +4,11 @@ from modelscope.utils.constant import Tasks from ..base import Model, Tensor from ..builder import MODELS -__all__ = ['PalmForTextGenerationModel'] +__all__ = ['PalmForTextGeneration'] @MODELS.register_module(Tasks.text_generation, module_name=r'palm') -class PalmForTextGenerationModel(Model): +class PalmForTextGeneration(Model): def __init__(self, model_dir: str, *args, **kwargs): """initialize the text generation model from the `model_dir` path. diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index d4ad0c3f..83d1641e 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -16,7 +16,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.sentence_similarity: ('sbert-base-chinese-sentence-similarity', 'damo/nlp_structbert_sentence-similarity_chinese-base'), - Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'), + Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'), Tasks.text_classification: ('bert-sentiment-analysis', 'damo/bert-base-sst2'), Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index ea30a115..8b6bf8a9 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -1,7 +1,7 @@ from typing import Dict, Optional, Union from modelscope.models import Model -from modelscope.models.nlp import PalmForTextGenerationModel +from modelscope.models.nlp import PalmForTextGeneration from modelscope.preprocessors import TextGenerationPreprocessor from modelscope.utils.constant import Tasks from ..base import Pipeline, Tensor @@ -14,7 +14,7 @@ __all__ = ['TextGenerationPipeline'] class TextGenerationPipeline(Pipeline): def __init__(self, - model: Union[PalmForTextGenerationModel, str], + model: Union[PalmForTextGeneration, str], preprocessor: Optional[TextGenerationPreprocessor] = None, **kwargs): """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction @@ -24,8 +24,7 @@ class TextGenerationPipeline(Pipeline): preprocessor (SequenceClassificationPreprocessor): a preprocessor instance """ sc_model = model if isinstance( - model, - PalmForTextGenerationModel) else Model.from_pretrained(model) + model, PalmForTextGeneration) else Model.from_pretrained(model) if preprocessor is None: preprocessor = TextGenerationPreprocessor( sc_model.model_dir, diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index ba5d05ad..676153bf 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -17,7 +17,7 @@ from modelscope.utils.test_utils import test_level class ImageMattingTest(unittest.TestCase): def setUp(self) -> None: - self.model_id = 'damo/cv_unet_image-matting_damo' + self.model_id = 'damo/cv_unet_image-matting' # switch to False if downloading everytime is not desired purge_cache = True if purge_cache: diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index f98e135d..39d57ff7 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -4,7 +4,7 @@ import unittest from maas_hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import PalmForTextGenerationModel +from modelscope.models.nlp import PalmForTextGeneration from modelscope.pipelines import TextGenerationPipeline, pipeline from modelscope.preprocessors import TextGenerationPreprocessor from modelscope.utils.constant import Tasks @@ -21,7 +21,7 @@ class TextGenerationTest(unittest.TestCase): cache_path = snapshot_download(self.model_id) preprocessor = TextGenerationPreprocessor( cache_path, first_sequence='sentence', second_sequence=None) - model = PalmForTextGenerationModel( + model = PalmForTextGeneration( cache_path, tokenizer=preprocessor.tokenizer) pipeline1 = TextGenerationPipeline(model, preprocessor) pipeline2 = pipeline(