|
|
|
@@ -38,3 +38,14 @@ class SentencePiecePreprocessor(Preprocessor): |
|
|
|
|
|
|
|
def __call__(self, data: str) -> torch.Tensor: |
|
|
|
return torch.tensor(self.tokenizer.encode([data]), dtype=torch.long) |
|
|
|
|
|
|
|
def decode(self, tokens, **kwargs): |
|
|
|
"""Decode the tokens to real text. |
|
|
|
|
|
|
|
Args: |
|
|
|
tokens: The output tokens from model's `forward` and `generate` |
|
|
|
|
|
|
|
Returns: |
|
|
|
The actual text. |
|
|
|
""" |
|
|
|
return self.tokenizer.decode(tokens) |