|
|
|
@@ -1,7 +1,7 @@ |
|
|
|
from typing import Dict, Optional, Union |
|
|
|
|
|
|
|
from modelscope.models import Model |
|
|
|
from modelscope.models.nlp import PalmForTextGenerationModel |
|
|
|
from modelscope.models.nlp import PalmForTextGeneration |
|
|
|
from modelscope.preprocessors import TextGenerationPreprocessor |
|
|
|
from modelscope.utils.constant import Tasks |
|
|
|
from ..base import Pipeline, Tensor |
|
|
|
@@ -14,7 +14,7 @@ __all__ = ['TextGenerationPipeline'] |
|
|
|
class TextGenerationPipeline(Pipeline): |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
model: Union[PalmForTextGenerationModel, str], |
|
|
|
model: Union[PalmForTextGeneration, str], |
|
|
|
preprocessor: Optional[TextGenerationPreprocessor] = None, |
|
|
|
**kwargs): |
|
|
|
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction |
|
|
|
@@ -24,8 +24,7 @@ class TextGenerationPipeline(Pipeline): |
|
|
|
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance |
|
|
|
""" |
|
|
|
sc_model = model if isinstance( |
|
|
|
model, |
|
|
|
PalmForTextGenerationModel) else Model.from_pretrained(model) |
|
|
|
model, PalmForTextGeneration) else Model.from_pretrained(model) |
|
|
|
if preprocessor is None: |
|
|
|
preprocessor = TextGenerationPreprocessor( |
|
|
|
sc_model.model_dir, |
|
|
|
|