|
|
|
@@ -26,6 +26,12 @@ class PalmForTextGeneration(Model): |
|
|
|
self.tokenizer = model.tokenizer |
|
|
|
self.generator = Translator(model) |
|
|
|
|
|
|
|
def train(self): |
|
|
|
return self.generator.train() |
|
|
|
|
|
|
|
def eval(self): |
|
|
|
return self.generator.eval() |
|
|
|
|
|
|
|
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: |
|
|
|
"""return the result by the model |
|
|
|
|
|
|
|
|