Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9068994master
| @@ -4,11 +4,11 @@ from modelscope.utils.constant import Tasks | |||||
| from ..base import Model, Tensor | from ..base import Model, Tensor | ||||
| from ..builder import MODELS | from ..builder import MODELS | ||||
| __all__ = ['PalmForTextGenerationModel'] | |||||
| __all__ = ['PalmForTextGeneration'] | |||||
| @MODELS.register_module(Tasks.text_generation, module_name=r'palm') | @MODELS.register_module(Tasks.text_generation, module_name=r'palm') | ||||
| class PalmForTextGenerationModel(Model): | |||||
| class PalmForTextGeneration(Model): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| """initialize the text generation model from the `model_dir` path. | """initialize the text generation model from the `model_dir` path. | ||||
| @@ -16,7 +16,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.sentence_similarity: | Tasks.sentence_similarity: | ||||
| ('sbert-base-chinese-sentence-similarity', | ('sbert-base-chinese-sentence-similarity', | ||||
| 'damo/nlp_structbert_sentence-similarity_chinese-base'), | 'damo/nlp_structbert_sentence-similarity_chinese-base'), | ||||
| Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'), | |||||
| Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'), | |||||
| Tasks.text_classification: | Tasks.text_classification: | ||||
| ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | ||||
| Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | ||||
| @@ -1,7 +1,7 @@ | |||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.models.nlp import PalmForTextGenerationModel | |||||
| from modelscope.models.nlp import PalmForTextGeneration | |||||
| from modelscope.preprocessors import TextGenerationPreprocessor | from modelscope.preprocessors import TextGenerationPreprocessor | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from ..base import Pipeline, Tensor | from ..base import Pipeline, Tensor | ||||
| @@ -14,7 +14,7 @@ __all__ = ['TextGenerationPipeline'] | |||||
| class TextGenerationPipeline(Pipeline): | class TextGenerationPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[PalmForTextGenerationModel, str], | |||||
| model: Union[PalmForTextGeneration, str], | |||||
| preprocessor: Optional[TextGenerationPreprocessor] = None, | preprocessor: Optional[TextGenerationPreprocessor] = None, | ||||
| **kwargs): | **kwargs): | ||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | ||||
| @@ -24,8 +24,7 @@ class TextGenerationPipeline(Pipeline): | |||||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | ||||
| """ | """ | ||||
| sc_model = model if isinstance( | sc_model = model if isinstance( | ||||
| model, | |||||
| PalmForTextGenerationModel) else Model.from_pretrained(model) | |||||
| model, PalmForTextGeneration) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = TextGenerationPreprocessor( | preprocessor = TextGenerationPreprocessor( | ||||
| sc_model.model_dir, | sc_model.model_dir, | ||||
| @@ -17,7 +17,7 @@ from modelscope.utils.test_utils import test_level | |||||
| class ImageMattingTest(unittest.TestCase): | class ImageMattingTest(unittest.TestCase): | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/cv_unet_image-matting_damo' | |||||
| self.model_id = 'damo/cv_unet_image-matting' | |||||
| # switch to False if downloading everytime is not desired | # switch to False if downloading everytime is not desired | ||||
| purge_cache = True | purge_cache = True | ||||
| if purge_cache: | if purge_cache: | ||||
| @@ -4,7 +4,7 @@ import unittest | |||||
| from maas_hub.snapshot_download import snapshot_download | from maas_hub.snapshot_download import snapshot_download | ||||
| from modelscope.models import Model | from modelscope.models import Model | ||||
| from modelscope.models.nlp import PalmForTextGenerationModel | |||||
| from modelscope.models.nlp import PalmForTextGeneration | |||||
| from modelscope.pipelines import TextGenerationPipeline, pipeline | from modelscope.pipelines import TextGenerationPipeline, pipeline | ||||
| from modelscope.preprocessors import TextGenerationPreprocessor | from modelscope.preprocessors import TextGenerationPreprocessor | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| @@ -21,7 +21,7 @@ class TextGenerationTest(unittest.TestCase): | |||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| preprocessor = TextGenerationPreprocessor( | preprocessor = TextGenerationPreprocessor( | ||||
| cache_path, first_sequence='sentence', second_sequence=None) | cache_path, first_sequence='sentence', second_sequence=None) | ||||
| model = PalmForTextGenerationModel( | |||||
| model = PalmForTextGeneration( | |||||
| cache_path, tokenizer=preprocessor.tokenizer) | cache_path, tokenizer=preprocessor.tokenizer) | ||||
| pipeline1 = TextGenerationPipeline(model, preprocessor) | pipeline1 = TextGenerationPipeline(model, preprocessor) | ||||
| pipeline2 = pipeline( | pipeline2 = pipeline( | ||||