From 06486d00274c064e3b7005073edc5a40354ef3ec Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Thu, 4 Aug 2022 16:34:34 +0800 Subject: [PATCH] [to #42322933] Fix bug for palm model postprecessor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 修复 palm 模型后处理不完善的问题,使输出结果中不再有无意义的字符 2. 修复 palm 模型对短文本生成摘要过长的问题,修改了 modelhub 中 config.json 的 min_length 参数 3. 去除 generate 过程中无意义的 log --- modelscope/models/nlp/palm_v2/modeling_palm.py | 1 - modelscope/models/nlp/palm_v2/palm_for_text_generation.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/modelscope/models/nlp/palm_v2/modeling_palm.py b/modelscope/models/nlp/palm_v2/modeling_palm.py index c2121cfd..127b5440 100644 --- a/modelscope/models/nlp/palm_v2/modeling_palm.py +++ b/modelscope/models/nlp/palm_v2/modeling_palm.py @@ -1170,7 +1170,6 @@ class Translator(nn.Module): results['batch'] = batch for step in range(max_length): - self.logger.info(f'step: {step + 1} / {max_length}') decoder_input = alive_seq[:, -1].view(1, -1) # Decoder forward. diff --git a/modelscope/models/nlp/palm_v2/palm_for_text_generation.py b/modelscope/models/nlp/palm_v2/palm_for_text_generation.py index 7f8e918b..e432cc58 100644 --- a/modelscope/models/nlp/palm_v2/palm_for_text_generation.py +++ b/modelscope/models/nlp/palm_v2/palm_for_text_generation.py @@ -32,9 +32,9 @@ class PalmForTextGeneration(TorchModel): replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) - replace_tokens_roberta = ((r' +', ' '), ('', ''), ('', - ''), - ('', ''), ('', ''), ('', ' ')) + replace_tokens_roberta = ((r' +', ' '), ('', '. '), + ('', ''), ('', ''), ('', ''), + ('', ' '), ('', '. ')) strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list] for _old, _new in replace_tokens_bert: