From 59c5dd8dfe053c52534b43887b6bee05639995d9 Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Wed, 10 Aug 2022 13:46:23 +0800 Subject: [PATCH] [to #42322933] remove sep token at the end of tokenizer output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit generate 时去除 tokenizer 输出结尾的 sep,修复 gpt3 模型目前续写内容与上文无关的 bug Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9696570 --- modelscope/models/nlp/gpt3/gpt3_for_text_generation.py | 2 ++ tests/pipelines/test_text_generation.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py b/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py index 9f6874c6..7cff9ad4 100644 --- a/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py +++ b/modelscope/models/nlp/gpt3/gpt3_for_text_generation.py @@ -48,6 +48,8 @@ class GPT3ForTextGeneration(TorchModel): attention_mask = input['attention_mask'] input_ids = input_ids[0][attention_mask[0].nonzero()] \ .squeeze().unsqueeze(0) + # remove sep token at the end of tokenizer output + input_ids = input_ids[:, :-1] gen_params = dict() gen_params['inputs'] = input_ids diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index cebc80ae..c08209a4 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -32,7 +32,7 @@ class TextGenerationTest(unittest.TestCase): self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base' self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large' - self.gpt3_input = '我很好奇' + self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,' def run_pipeline_with_model_instance(self, model_id, input): model = Model.from_pretrained(model_id)