From 8184c86c5f6003439120764cb9e1d9249febc4ba Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Fri, 2 Dec 2022 17:52:19 +0800 Subject: [PATCH] [to #42322933] Fix bug for text generation task model Fixed the bug for generate method in TaskModelForTextGeneration, which was unavailable due to the upgrade of the transformers library to version 4.24.0 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10791805 --- modelscope/models/nlp/task_models/text_generation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modelscope/models/nlp/task_models/text_generation.py b/modelscope/models/nlp/task_models/text_generation.py index b886f124..cd8e20cf 100644 --- a/modelscope/models/nlp/task_models/text_generation.py +++ b/modelscope/models/nlp/task_models/text_generation.py @@ -2,7 +2,7 @@ from typing import Any, Dict import numpy as np -from transformers.modeling_utils import GenerationMixin +from transformers.modeling_utils import PreTrainedModel from modelscope.metainfo import TaskModels from modelscope.models.builder import MODELS @@ -17,8 +17,7 @@ __all__ = ['TaskModelForTextGeneration'] @MODELS.register_module( Tasks.text_generation, module_name=TaskModels.text_generation) -class TaskModelForTextGeneration(SingleBackboneTaskModelBase, GenerationMixin): - main_input_name = 'input_ids' +class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): def __init__(self, model_dir: str, *args, **kwargs): """initialize the text generation model from the `model_dir` path.