|
|
@@ -8,7 +8,7 @@ from ..builder import MODELS |
|
|
__all__ = ['PalmForTextGeneration'] |
|
|
__all__ = ['PalmForTextGeneration'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.text_generation, module_name=Models.palm2_0) |
|
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.text_generation, module_name=Models.palm) |
|
|
class PalmForTextGeneration(Model): |
|
|
class PalmForTextGeneration(Model): |
|
|
|
|
|
|
|
|
def __init__(self, model_dir: str, *args, **kwargs): |
|
|
def __init__(self, model_dir: str, *args, **kwargs): |
|
|
|