yingda.chen 3 years ago
parent
commit
ad8e080e37
5 changed files with 9 additions and 10 deletions
  1. +2
    -2
      modelscope/models/nlp/text_generation_model.py
  2. +1
    -1
      modelscope/pipelines/builder.py
  3. +3
    -4
      modelscope/pipelines/nlp/text_generation_pipeline.py
  4. +1
    -1
      tests/pipelines/test_image_matting.py
  5. +2
    -2
      tests/pipelines/test_text_generation.py

+ 2
- 2
modelscope/models/nlp/text_generation_model.py View File

@@ -4,11 +4,11 @@ from modelscope.utils.constant import Tasks
from ..base import Model, Tensor from ..base import Model, Tensor
from ..builder import MODELS from ..builder import MODELS


__all__ = ['PalmForTextGenerationModel']
__all__ = ['PalmForTextGeneration']




@MODELS.register_module(Tasks.text_generation, module_name=r'palm') @MODELS.register_module(Tasks.text_generation, module_name=r'palm')
class PalmForTextGenerationModel(Model):
class PalmForTextGeneration(Model):


def __init__(self, model_dir: str, *args, **kwargs): def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the text generation model from the `model_dir` path. """initialize the text generation model from the `model_dir` path.


+ 1
- 1
modelscope/pipelines/builder.py View File

@@ -16,7 +16,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.sentence_similarity: Tasks.sentence_similarity:
('sbert-base-chinese-sentence-similarity', ('sbert-base-chinese-sentence-similarity',
'damo/nlp_structbert_sentence-similarity_chinese-base'), 'damo/nlp_structbert_sentence-similarity_chinese-base'),
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'),
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
Tasks.text_classification: Tasks.text_classification:
('bert-sentiment-analysis', 'damo/bert-base-sst2'), ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),


+ 3
- 4
modelscope/pipelines/nlp/text_generation_pipeline.py View File

@@ -1,7 +1,7 @@
from typing import Dict, Optional, Union from typing import Dict, Optional, Union


from modelscope.models import Model from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGenerationModel
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.preprocessors import TextGenerationPreprocessor from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
from ..base import Pipeline, Tensor from ..base import Pipeline, Tensor
@@ -14,7 +14,7 @@ __all__ = ['TextGenerationPipeline']
class TextGenerationPipeline(Pipeline): class TextGenerationPipeline(Pipeline):


def __init__(self, def __init__(self,
model: Union[PalmForTextGenerationModel, str],
model: Union[PalmForTextGeneration, str],
preprocessor: Optional[TextGenerationPreprocessor] = None, preprocessor: Optional[TextGenerationPreprocessor] = None,
**kwargs): **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
@@ -24,8 +24,7 @@ class TextGenerationPipeline(Pipeline):
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
""" """
sc_model = model if isinstance( sc_model = model if isinstance(
model,
PalmForTextGenerationModel) else Model.from_pretrained(model)
model, PalmForTextGeneration) else Model.from_pretrained(model)
if preprocessor is None: if preprocessor is None:
preprocessor = TextGenerationPreprocessor( preprocessor = TextGenerationPreprocessor(
sc_model.model_dir, sc_model.model_dir,


+ 1
- 1
tests/pipelines/test_image_matting.py View File

@@ -17,7 +17,7 @@ from modelscope.utils.test_utils import test_level
class ImageMattingTest(unittest.TestCase): class ImageMattingTest(unittest.TestCase):


def setUp(self) -> None: def setUp(self) -> None:
self.model_id = 'damo/cv_unet_image-matting_damo'
self.model_id = 'damo/cv_unet_image-matting'
# switch to False if downloading everytime is not desired # switch to False if downloading everytime is not desired
purge_cache = True purge_cache = True
if purge_cache: if purge_cache:


+ 2
- 2
tests/pipelines/test_text_generation.py View File

@@ -4,7 +4,7 @@ import unittest
from maas_hub.snapshot_download import snapshot_download from maas_hub.snapshot_download import snapshot_download


from modelscope.models import Model from modelscope.models import Model
from modelscope.models.nlp import PalmForTextGenerationModel
from modelscope.models.nlp import PalmForTextGeneration
from modelscope.pipelines import TextGenerationPipeline, pipeline from modelscope.pipelines import TextGenerationPipeline, pipeline
from modelscope.preprocessors import TextGenerationPreprocessor from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
@@ -21,7 +21,7 @@ class TextGenerationTest(unittest.TestCase):
cache_path = snapshot_download(self.model_id) cache_path = snapshot_download(self.model_id)
preprocessor = TextGenerationPreprocessor( preprocessor = TextGenerationPreprocessor(
cache_path, first_sequence='sentence', second_sequence=None) cache_path, first_sequence='sentence', second_sequence=None)
model = PalmForTextGenerationModel(
model = PalmForTextGeneration(
cache_path, tokenizer=preprocessor.tokenizer) cache_path, tokenizer=preprocessor.tokenizer)
pipeline1 = TextGenerationPipeline(model, preprocessor) pipeline1 = TextGenerationPipeline(model, preprocessor)
pipeline2 = pipeline( pipeline2 = pipeline(


Loading…
Cancel
Save