diff --git a/modelscope/models/nlp/task_models/text_generation.py b/modelscope/models/nlp/task_models/text_generation.py index b886f124..cd8e20cf 100644 --- a/modelscope/models/nlp/task_models/text_generation.py +++ b/modelscope/models/nlp/task_models/text_generation.py @@ -2,7 +2,7 @@ from typing import Any, Dict import numpy as np -from transformers.modeling_utils import GenerationMixin +from transformers.modeling_utils import PreTrainedModel from modelscope.metainfo import TaskModels from modelscope.models.builder import MODELS @@ -17,8 +17,7 @@ __all__ = ['TaskModelForTextGeneration'] @MODELS.register_module( Tasks.text_generation, module_name=TaskModels.text_generation) -class TaskModelForTextGeneration(SingleBackboneTaskModelBase, GenerationMixin): - main_input_name = 'input_ids' +class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): def __init__(self, model_dir: str, *args, **kwargs): """initialize the text generation model from the `model_dir` path.