|
|
|
@@ -2,7 +2,7 @@ |
|
|
|
from typing import Any, Dict |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from transformers.modeling_utils import GenerationMixin |
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
|
|
|
|
from modelscope.metainfo import TaskModels |
|
|
|
from modelscope.models.builder import MODELS |
|
|
|
@@ -17,8 +17,7 @@ __all__ = ['TaskModelForTextGeneration'] |
|
|
|
|
|
|
|
@MODELS.register_module( |
|
|
|
Tasks.text_generation, module_name=TaskModels.text_generation) |
|
|
|
class TaskModelForTextGeneration(SingleBackboneTaskModelBase, GenerationMixin): |
|
|
|
main_input_name = 'input_ids' |
|
|
|
class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): |
|
|
|
|
|
|
|
def __init__(self, model_dir: str, *args, **kwargs): |
|
|
|
"""initialize the text generation model from the `model_dir` path. |
|
|
|
|