|
|
|
@@ -32,9 +32,9 @@ class PalmForTextGeneration(TorchModel): |
|
|
|
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), |
|
|
|
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), |
|
|
|
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) |
|
|
|
replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'), ('<pad>', |
|
|
|
''), |
|
|
|
('<s>', ''), ('</s>', ''), ('<unk>', ' ')) |
|
|
|
replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '), |
|
|
|
('<pad>', ''), ('<s>', ''), ('</s>', ''), |
|
|
|
('<unk>', ' '), ('<q>', '. ')) |
|
|
|
|
|
|
|
strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list] |
|
|
|
for _old, _new in replace_tokens_bert: |
|
|
|
|