添加 bloom 模型
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10509187
master
| @@ -80,6 +80,7 @@ class Models(object): | |||||
| bert_for_ds = 'bert-for-document-segmentation' | bert_for_ds = 'bert-for-document-segmentation' | ||||
| ponet = 'ponet' | ponet = 'ponet' | ||||
| T5 = 'T5' | T5 = 'T5' | ||||
| bloom = 'bloom' | |||||
| # audio models | # audio models | ||||
| sambert_hifigan = 'sambert-hifigan' | sambert_hifigan = 'sambert-hifigan' | ||||
| @@ -0,0 +1,15 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from transformers import BloomConfig | |||||
| from transformers import BloomModel as BloomModelTransform | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.builder import BACKBONES | |||||
| from modelscope.utils.constant import Fields | |||||
@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.bloom)
class BloomModel(BloomModelTransform):
    """BLOOM backbone registered in the ModelScope NLP backbone registry.

    Thin wrapper around ``transformers.BloomModel`` that builds the
    Hugging Face ``BloomConfig`` from plain keyword arguments, so the
    backbone can be instantiated directly from a ModelScope
    configuration dict instead of requiring a pre-built config object.
    """

    def __init__(self, **kwargs):
        # Forward all kwargs to BloomConfig; unknown keys are handled by
        # the HF config machinery (stored as extra attributes).
        config = BloomConfig(**kwargs)
        super().__init__(config)
| @@ -51,12 +51,9 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): | |||||
| return addict.Dict(outputs) | return addict.Dict(outputs) | ||||
| def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): | ||||
| token_type_ids = kwargs.get('token_type_ids', None) | |||||
| # only last token for inputs_ids if past is defined in kwargs | # only last token for inputs_ids if past is defined in kwargs | ||||
| if past: | if past: | ||||
| input_ids = input_ids[:, -1].unsqueeze(-1) | input_ids = input_ids[:, -1].unsqueeze(-1) | ||||
| if token_type_ids is not None: | |||||
| token_type_ids = token_type_ids[:, -1].unsqueeze(-1) | |||||
| attention_mask = kwargs.get('attention_mask', None) | attention_mask = kwargs.get('attention_mask', None) | ||||
| position_ids = kwargs.get('position_ids', None) | position_ids = kwargs.get('position_ids', None) | ||||
| @@ -75,5 +72,8 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): | |||||
| 'use_cache': kwargs.get('use_cache'), | 'use_cache': kwargs.get('use_cache'), | ||||
| 'position_ids': position_ids, | 'position_ids': position_ids, | ||||
| 'attention_mask': attention_mask, | 'attention_mask': attention_mask, | ||||
| 'token_type_ids': token_type_ids, | |||||
| } | } | ||||
| def generate(self, inputs, *args, **kwargs): | |||||
| input_ids = inputs['input_ids'] if isinstance(inputs, Dict) else inputs | |||||
| return super().generate(input_ids, *args, **kwargs) | |||||
| @@ -79,7 +79,7 @@ class TextGenerationPipeline(Pipeline): | |||||
| return self.model.generate(inputs, **forward_params) | return self.model.generate(inputs, **forward_params) | ||||
| def sentence_piece(self, inputs) -> Dict[str, Tensor]: | def sentence_piece(self, inputs) -> Dict[str, Tensor]: | ||||
| return self.preprocessor.tokenizer.decode(inputs.tolist())[0] | |||||
| return self.preprocessor.tokenizer.decode(inputs.tolist()[0]) | |||||
| def postprocess(self, inputs: Dict[str, Tensor], | def postprocess(self, inputs: Dict[str, Tensor], | ||||
| **postprocess_params) -> Dict[str, str]: | **postprocess_params) -> Dict[str, str]: | ||||
| @@ -182,6 +182,12 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| max_length=20, | max_length=20, | ||||
| repetition_penalty=0.5)) | repetition_penalty=0.5)) | ||||
| @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub") | |||||
| def test_bloom(self): | |||||
| pipe = pipeline( | |||||
| task=Tasks.text_generation, model='Langboat/bloom-1b4-zh') | |||||
| print(pipe('中国的首都是')) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||