Browse Source

[to #42322933] remove sep token at the end of tokenizer output

generate 时去除 tokenizer 输出结尾的 sep,修复 gpt3 模型目前续写内容与上文无关的 bug
(Translation: during generate, strip the trailing sep token from the tokenizer output; fixes the bug where GPT-3 model continuations are unrelated to the preceding text.)
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9696570
master
hemu.zp yingda.chen 3 years ago
parent
commit
59c5dd8dfe
2 changed files with 3 additions and 1 deletion
  1. +2
    -0
      modelscope/models/nlp/gpt3/gpt3_for_text_generation.py
  2. +1
    -1
      tests/pipelines/test_text_generation.py

+ 2
- 0
modelscope/models/nlp/gpt3/gpt3_for_text_generation.py View File

@@ -48,6 +48,8 @@ class GPT3ForTextGeneration(TorchModel):
attention_mask = input['attention_mask']
input_ids = input_ids[0][attention_mask[0].nonzero()] \
.squeeze().unsqueeze(0)
# remove sep token at the end of tokenizer output
input_ids = input_ids[:, :-1]

gen_params = dict()
gen_params['inputs'] = input_ids


+ 1
- 1
tests/pipelines/test_text_generation.py View File

@@ -32,7 +32,7 @@ class TextGenerationTest(unittest.TestCase):

self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base'
self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large'
self.gpt3_input = '我很好奇'
self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,'

def run_pipeline_with_model_instance(self, model_id, input):
model = Model.from_pretrained(model_id)


Loading…
Cancel
Save