yingda.chen 3 years ago
parent
commit
ad8e080e37
5 changed files with 9 additions and 10 deletions
  1. +2
    -2
      modelscope/models/nlp/text_generation_model.py
  2. +1
    -1
      modelscope/pipelines/builder.py
  3. +3
    -4
      modelscope/pipelines/nlp/text_generation_pipeline.py
  4. +1
    -1
      tests/pipelines/test_image_matting.py
  5. +2
    -2
      tests/pipelines/test_text_generation.py

+ 2
- 2
modelscope/models/nlp/text_generation_model.py View File

@@ -4,11 +4,11 @@ from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['PalmForTextGenerationModel']
__all__ = ['PalmForTextGeneration']


@MODELS.register_module(Tasks.text_generation, module_name=r'palm')
class PalmForTextGenerationModel(Model):
class PalmForTextGeneration(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the text generation model from the `model_dir` path.


+ 1
- 1
modelscope/pipelines/builder.py View File

@@ -16,7 +16,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.sentence_similarity:
('sbert-base-chinese-sentence-similarity',
'damo/nlp_structbert_sentence-similarity_chinese-base'),
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'),
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
Tasks.text_classification:
('bert-sentiment-analysis', 'damo/bert-base-sst2'),
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),


+ 3
- 4
modelscope/pipelines/nlp/text_generation_pipeline.py View File

@@ -1,7 +1,7 @@
from typing import Dict, Optional, Union

from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGenerationModel
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks
from ..base import Pipeline, Tensor
@@ -14,7 +14,7 @@ __all__ = ['TextGenerationPipeline']
class TextGenerationPipeline(Pipeline):

def __init__(self,
model: Union[PalmForTextGenerationModel, str],
model: Union[PalmForTextGeneration, str],
preprocessor: Optional[TextGenerationPreprocessor] = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
@@ -24,8 +24,7 @@ class TextGenerationPipeline(Pipeline):
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
"""
sc_model = model if isinstance(
model,
PalmForTextGenerationModel) else Model.from_pretrained(model)
model, PalmForTextGeneration) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = TextGenerationPreprocessor(
sc_model.model_dir,


+ 1
- 1
tests/pipelines/test_image_matting.py View File

@@ -17,7 +17,7 @@ from modelscope.utils.test_utils import test_level
class ImageMattingTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_unet_image-matting_damo'
self.model_id = 'damo/cv_unet_image-matting'
# switch to False if downloading everytime is not desired
purge_cache = True
if purge_cache:


+ 2
- 2
tests/pipelines/test_text_generation.py View File

@@ -4,7 +4,7 @@ import unittest
from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGenerationModel
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.pipelines import TextGenerationPipeline, pipeline
from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks
@@ -21,7 +21,7 @@ class TextGenerationTest(unittest.TestCase):
cache_path = snapshot_download(self.model_id)
preprocessor = TextGenerationPreprocessor(
cache_path, first_sequence='sentence', second_sequence=None)
model = PalmForTextGenerationModel(
model = PalmForTextGeneration(
cache_path, tokenizer=preprocessor.tokenizer)
pipeline1 = TextGenerationPipeline(model, preprocessor)
pipeline2 = pipeline(


Loading…
Cancel
Save