From ffd834fc258c02450064f5c8c4df6f0389226a4c Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Tue, 25 Oct 2022 12:58:02 +0800
Subject: [PATCH] [to #42322933] Add bloom model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add bloom model.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10509187
---
 modelscope/metainfo.py                          |  1 +
 modelscope/models/nlp/bloom/backbone.py         | 15 +++++++++++++++
 .../models/nlp/task_models/text_generation.py   |  8 ++++----
 .../pipelines/nlp/text_generation_pipeline.py   |  2 +-
 tests/pipelines/test_text_generation.py         |  6 ++++++
 5 files changed, 27 insertions(+), 5 deletions(-)
 create mode 100644 modelscope/models/nlp/bloom/backbone.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index f77ff299..16190eb8 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -80,6 +80,7 @@ class Models(object):
     bert_for_ds = 'bert-for-document-segmentation'
     ponet = 'ponet'
     T5 = 'T5'
+    bloom = 'bloom'
 
     # audio models
     sambert_hifigan = 'sambert-hifigan'
diff --git a/modelscope/models/nlp/bloom/backbone.py b/modelscope/models/nlp/bloom/backbone.py
new file mode 100644
index 00000000..b6bd315e
--- /dev/null
+++ b/modelscope/models/nlp/bloom/backbone.py
@@ -0,0 +1,15 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from transformers import BloomConfig
+from transformers import BloomModel as BloomModelTransform
+
+from modelscope.metainfo import Models
+from modelscope.models.builder import BACKBONES
+from modelscope.utils.constant import Fields
+
+
+@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.bloom)
+class BloomModel(BloomModelTransform):
+
+    def __init__(self, **kwargs):
+        config = BloomConfig(**kwargs)
+        super().__init__(config)
diff --git a/modelscope/models/nlp/task_models/text_generation.py b/modelscope/models/nlp/task_models/text_generation.py
index 973198ae..f17b0f6b 100644
--- a/modelscope/models/nlp/task_models/text_generation.py
+++ b/modelscope/models/nlp/task_models/text_generation.py
@@ -51,12 +51,9 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
         return addict.Dict(outputs)
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        token_type_ids = kwargs.get('token_type_ids', None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
-            if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
 
         attention_mask = kwargs.get('attention_mask', None)
         position_ids = kwargs.get('position_ids', None)
@@ -75,5 +72,8 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
             'use_cache': kwargs.get('use_cache'),
             'position_ids': position_ids,
             'attention_mask': attention_mask,
-            'token_type_ids': token_type_ids,
         }
+
+    def generate(self, inputs, *args, **kwargs):
+        input_ids = inputs['input_ids'] if isinstance(inputs, Dict) else inputs
+        return super().generate(input_ids, *args, **kwargs)
diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py
index ae92f26a..28acebb4 100644
--- a/modelscope/pipelines/nlp/text_generation_pipeline.py
+++ b/modelscope/pipelines/nlp/text_generation_pipeline.py
@@ -79,7 +79,7 @@ class TextGenerationPipeline(Pipeline):
         return self.model.generate(inputs, **forward_params)
 
     def sentence_piece(self, inputs) -> Dict[str, Tensor]:
-        return self.preprocessor.tokenizer.decode(inputs.tolist())[0]
+        return self.preprocessor.tokenizer.decode(inputs.tolist()[0])
 
     def postprocess(self, inputs: Dict[str, Tensor],
                     **postprocess_params) -> Dict[str, str]:
diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py
index 4b0ebd47..f624f021 100644
--- a/tests/pipelines/test_text_generation.py
+++ b/tests/pipelines/test_text_generation.py
@@ -182,6 +182,12 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
                 max_length=20,
                 repetition_penalty=0.5))
 
+    @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub")
+    def test_bloom(self):
+        pipe = pipeline(
+            task=Tasks.text_generation, model='Langboat/bloom-1b4-zh')
+        print(pipe('中国的首都是'))
+
 
 if __name__ == '__main__':
     unittest.main()
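
For reference, a minimal standalone sketch of how the new bloom integration would be exercised, assuming the Langboat/bloom-1b4-zh checkpoint has been uploaded to the modelhub (until then, test_bloom above stays skipped). Generation arguments such as max_length or repetition_penalty are forwarded to model.generate() through the pipeline's forward_params, as in the existing text-generation tests; nothing here is bloom-specific.

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Text generation backed by the newly registered bloom model.
    # 'Langboat/bloom-1b4-zh' is the model id used by the skipped test above.
    pipe = pipeline(task=Tasks.text_generation, model='Langboat/bloom-1b4-zh')
    print(pipe('中国的首都是'))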