hemu.zp yingda.chen 3 years ago
parent
commit
ffd834fc25
5 changed files with 27 additions and 5 deletions
  1. +1
    -0
      modelscope/metainfo.py
  2. +15
    -0
      modelscope/models/nlp/bloom/backbone.py
  3. +4
    -4
      modelscope/models/nlp/task_models/text_generation.py
  4. +1
    -1
      modelscope/pipelines/nlp/text_generation_pipeline.py
  5. +6
    -0
      tests/pipelines/test_text_generation.py

+ 1
- 0
modelscope/metainfo.py View File

@@ -80,6 +80,7 @@ class Models(object):
# nlp models — string identifiers used to register and look up model classes
bert_for_ds = 'bert-for-document-segmentation'
ponet = 'ponet'
T5 = 'T5'
bloom = 'bloom'  # Bloom backbone; registered in models/nlp/bloom/backbone.py

# audio models
sambert_hifigan = 'sambert-hifigan'


+ 15
- 0
modelscope/models/nlp/bloom/backbone.py View File

@@ -0,0 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from transformers import BloomConfig
from transformers import BloomModel as BloomModelTransform

from modelscope.metainfo import Models
from modelscope.models.builder import BACKBONES
from modelscope.utils.constant import Fields


@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.bloom)
class BloomModel(BloomModelTransform):
    """ModelScope backbone wrapper for the HuggingFace Bloom model.

    Registered under ``Fields.nlp`` / ``Models.bloom`` so the backbone
    builder can construct it from configuration keyword arguments.
    """

    def __init__(self, **kwargs):
        # All keyword arguments are treated as BloomConfig fields; the
        # resulting config initializes the underlying transformers model.
        super().__init__(BloomConfig(**kwargs))

+ 4
- 4
modelscope/models/nlp/task_models/text_generation.py View File

@@ -51,12 +51,9 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
return addict.Dict(outputs)

def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get('token_type_ids', None)
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

attention_mask = kwargs.get('attention_mask', None)
position_ids = kwargs.get('position_ids', None)
@@ -75,5 +72,8 @@ class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
'use_cache': kwargs.get('use_cache'),
'position_ids': position_ids,
'attention_mask': attention_mask,
'token_type_ids': token_type_ids,
}

def generate(self, inputs, *args, **kwargs):
    """Run text generation, accepting either a tensor or a feature dict.

    Args:
        inputs: Either a tensor of input token ids, or a dict (as produced
            by the preprocessor) containing an ``'input_ids'`` entry.
        *args: Positional arguments forwarded to the parent ``generate``.
        **kwargs: Keyword arguments forwarded to the parent ``generate``.

    Returns:
        The result of the parent class's ``generate`` call.
    """
    # Check against the builtin ``dict``: isinstance() with typing aliases
    # such as typing.Dict is deprecated and behaves identically here.
    input_ids = inputs['input_ids'] if isinstance(inputs, dict) else inputs
    return super().generate(input_ids, *args, **kwargs)

+ 1
- 1
modelscope/pipelines/nlp/text_generation_pipeline.py View File

@@ -79,7 +79,7 @@ class TextGenerationPipeline(Pipeline):
return self.model.generate(inputs, **forward_params)

def sentence_piece(self, inputs) -> Dict[str, Tensor]:
    """Decode the first generated sequence back into text.

    Args:
        inputs: Tensor of generated token ids; only the first element of
            ``inputs.tolist()`` (the first sequence) is decoded.

    Returns:
        The decoded text of the first sequence.
        NOTE(review): the annotated return type ``Dict[str, Tensor]``
        appears not to match the actual ``tokenizer.decode`` result
        (a string) — confirm against callers before changing it.
    """
    # Decode the first batch element: ``tolist()[0]`` selects the first
    # sequence before decoding, rather than indexing into the decoded
    # string of the whole batch (``tolist())[0]``), which would return
    # only the first character.
    return self.preprocessor.tokenizer.decode(inputs.tolist()[0])

def postprocess(self, inputs: Dict[str, Tensor],
**postprocess_params) -> Dict[str, str]:


+ 6
- 0
tests/pipelines/test_text_generation.py View File

@@ -182,6 +182,12 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
max_length=20,
repetition_penalty=0.5))

@unittest.skip("Langboat's checkpoint has not been uploaded to modelhub")
def test_bloom(self):
    # Smoke test: build the bloom text-generation pipeline and run a
    # single Chinese prompt through it.
    text_generator = pipeline(
        task=Tasks.text_generation, model='Langboat/bloom-1b4-zh')
    print(text_generator('中国的首都是'))


# Allow running this test module directly: python test_text_generation.py
if __name__ == '__main__':
    unittest.main()

Loading…
Cancel
Save