Browse Source

[to #42322933] Fix bug for palm model postprecessor

1. 修复 palm 模型后处理不完善的问题,使输出结果中不再有无意义的字符
2. 修复 palm 模型对短文本生成摘要过长的问题,修改了 modelhub 中 config.json 的 min_length 参数
3. 去除 generate 过程中无意义的 log
master
hemu.zp 3 years ago
parent
commit
06486d0027
2 changed files with 3 additions and 4 deletions
  1. +0
    -1
      modelscope/models/nlp/palm_v2/modeling_palm.py
  2. +3
    -3
      modelscope/models/nlp/palm_v2/palm_for_text_generation.py

+ 0
- 1
modelscope/models/nlp/palm_v2/modeling_palm.py View File

@@ -1170,7 +1170,6 @@ class Translator(nn.Module):
results['batch'] = batch

for step in range(max_length):
self.logger.info(f'step: {step + 1} / {max_length}')
decoder_input = alive_seq[:, -1].view(1, -1)

# Decoder forward.


+ 3
- 3
modelscope/models/nlp/palm_v2/palm_for_text_generation.py View File

@@ -32,9 +32,9 @@ class PalmForTextGeneration(TorchModel):
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'), ('<pad>',
''),
('<s>', ''), ('</s>', ''), ('<unk>', ' '))
replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '),
('<pad>', ''), ('<s>', ''), ('</s>', ''),
('<unk>', ' '), ('<q>', '. '))

strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list]
for _old, _new in replace_tokens_bert:


Loading…
Cancel
Save