Browse Source

[to #42322933] Add Palm2.0 model.

接入支持中英文的 Palm2.0 模型,复用 text-generation-pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9066550
master
hemu.zp 3 years ago
parent
commit
99fb503695
7 changed files with 83 additions and 66 deletions
  1. +8
    -17
      modelscope/models/nlp/palm_for_text_generation.py
  2. +2
    -1
      modelscope/pipelines/builder.py
  3. +19
    -15
      modelscope/pipelines/nlp/text_generation_pipeline.py
  4. +4
    -7
      modelscope/preprocessors/nlp.py
  5. +1
    -1
      requirements/nlp.txt
  6. +1
    -1
      requirements/runtime.txt
  7. +48
    -24
      tests/pipelines/test_text_generation.py

+ 8
- 17
modelscope/models/nlp/palm_for_text_generation.py View File

@@ -7,7 +7,7 @@ from ..builder import MODELS
__all__ = ['PalmForTextGeneration']


@MODELS.register_module(Tasks.text_generation, module_name=r'palm')
@MODELS.register_module(Tasks.text_generation, module_name=r'palm2.0')
class PalmForTextGeneration(Model):

def __init__(self, model_dir: str, *args, **kwargs):
@@ -18,35 +18,26 @@ class PalmForTextGeneration(Model):
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""
from sofa import PalmTokenizer

super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir

from sofa.models.palm import PalmForConditionalGeneration, TextGenerator
tokenizer = kwargs.pop('tokenizer',
PalmTokenizer.from_pretrained(model_dir))
from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator
model = PalmForConditionalGeneration.from_pretrained(model_dir)
self.generator = TextGenerator(model, tokenizer)
self.tokenizer = model.tokenizer
self.generator = Translator(model)

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Dict[str, Tensor]: results
Example:
{
'predictions': array([1]), # lable 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
'predictions': Tensor([[1377, 4959, 2785, 6392...])]), # tokens need to be decode by tokenizer
}
"""

encoder_inputs = [
input['input_ids'], input['token_type_ids'],
input['attention_mask']
]
return self.generator(encoder_inputs)
return self.generator(**input)

+ 2
- 1
modelscope/pipelines/builder.py View File

@@ -22,7 +22,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
Tasks.text_classification:
('bert-sentiment-analysis', 'damo/bert-base-sst2'),
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
Tasks.text_generation: ('palm2.0',
'damo/nlp_palm2.0_text-generation_chinese-base'),
Tasks.image_captioning: ('ofa', None),
Tasks.image_generation:
('person-image-cartoon',


+ 19
- 15
modelscope/pipelines/nlp/text_generation_pipeline.py View File

@@ -10,7 +10,7 @@ from ..builder import PIPELINES
__all__ = ['TextGenerationPipeline']


@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm')
@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm2.0')
class TextGenerationPipeline(Pipeline):

def __init__(self,
@@ -23,15 +23,16 @@ class TextGenerationPipeline(Pipeline):
model (SequenceClassificationModel): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
"""
sc_model = model if isinstance(
model = model if isinstance(
model, PalmForTextGeneration) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = TextGenerationPreprocessor(
sc_model.model_dir,
model.model_dir,
model.tokenizer,
first_sequence='sentence',
second_sequence=None)
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
self.tokenizer = preprocessor.tokenizer
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.tokenizer = model.tokenizer

def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
"""process the prediction results
@@ -42,17 +43,20 @@ class TextGenerationPipeline(Pipeline):
Returns:
Dict[str, str]: the prediction results
"""
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'), ('<pad>',
''),
('<s>', ''), ('</s>', ''), ('<unk>', ' '))

vocab_size = len(self.tokenizer.vocab)
pred_list = inputs['predictions']
pred_ids = pred_list[0][0].cpu().numpy().tolist()
for j in range(len(pred_ids)):
if pred_ids[j] >= vocab_size:
pred_ids[j] = 100
pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
pred_string = ''.join(pred).replace(
'##',
'').split('[SEP]')[0].replace('[CLS]',
'').replace('[SEP]',
'').replace('[UNK]', '')
pred_string = self.tokenizer.decode(pred_ids)
for _old, _new in replace_tokens_bert:
pred_string = pred_string.replace(_old, _new)
pred_string.strip()
for _old, _new in replace_tokens_roberta:
pred_string = pred_string.replace(_old, _new)
pred_string.strip()
return {'text': pred_string}

+ 4
- 7
modelscope/preprocessors/nlp.py View File

@@ -115,17 +115,15 @@ class SequenceClassificationPreprocessor(Preprocessor):
return rst


@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm')
@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm2.0')
class TextGenerationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
"""preprocess the data using the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""
from sofa import PalmTokenizer

super().__init__(*args, **kwargs)

self.model_dir: str = model_dir
@@ -134,7 +132,7 @@ class TextGenerationPreprocessor(Preprocessor):
self.second_sequence: str = kwargs.pop('second_sequence',
'second_sequence')
self.sequence_length: int = kwargs.pop('sequence_length', 128)
self.tokenizer = PalmTokenizer.from_pretrained(model_dir)
self.tokenizer = tokenizer

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
@@ -153,7 +151,7 @@ class TextGenerationPreprocessor(Preprocessor):
new_data = {self.first_sequence: data}
# preprocess the data for the model input

rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
rst = {'input_ids': [], 'attention_mask': []}

max_seq_length = self.sequence_length

@@ -168,7 +166,6 @@ class TextGenerationPreprocessor(Preprocessor):

rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])

return {k: torch.tensor(v) for k, v in rst.items()}



+ 1
- 1
requirements/nlp.txt View File

@@ -1 +1 @@
https://alinlp.alibaba-inc.com/pypi/sofa-1.0.1.3-py3-none-any.whl
https://alinlp.alibaba-inc.com/pypi/sofa-1.0.2-py3-none-any.whl

+ 1
- 1
requirements/runtime.txt View File

@@ -1,7 +1,7 @@
addict
datasets
easydict
https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.2.dev0-py3-none-any.whl
https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.4.dev0-py3-none-any.whl
numpy
opencv-python-headless
Pillow>=6.2.0


+ 48
- 24
tests/pipelines/test_text_generation.py View File

@@ -12,43 +12,67 @@ from modelscope.utils.test_utils import test_level


class TextGenerationTest(unittest.TestCase):
model_id = 'damo/nlp_palm_text-generation_chinese'
input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'"
input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'"
model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
input_zh = """
本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
"""
input_en = """
The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges
of paedophilia against nine children because he has dementia . Today , newly-released documents
revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years .
And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
"""

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
cache_path = snapshot_download(self.model_id)
preprocessor = TextGenerationPreprocessor(
cache_path, first_sequence='sentence', second_sequence=None)
model = PalmForTextGeneration(
cache_path, tokenizer=preprocessor.tokenizer)
pipeline1 = TextGenerationPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.text_generation, model=model, preprocessor=preprocessor)
print(f'input: {self.input1}\npipeline1: {pipeline1(self.input1)}')
print()
print(f'input: {self.input2}\npipeline2: {pipeline2(self.input2)}')
for model_id, input in ((self.model_id_zh, self.input_zh),
(self.model_id_en, self.input_en)):
cache_path = snapshot_download(model_id)
model = PalmForTextGeneration(cache_path)
preprocessor = TextGenerationPreprocessor(
cache_path,
model.tokenizer,
first_sequence='sentence',
second_sequence=None)
pipeline1 = TextGenerationPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.text_generation, model=model, preprocessor=preprocessor)
print(
f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
preprocessor = TextGenerationPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None)
pipeline_ins = pipeline(
task=Tasks.text_generation, model=model, preprocessor=preprocessor)
print(pipeline_ins(self.input1))
for model_id, input in ((self.model_id_zh, self.input_zh),
(self.model_id_en, self.input_en)):
model = Model.from_pretrained(model_id)
preprocessor = TextGenerationPreprocessor(
model.model_dir,
model.tokenizer,
first_sequence='sentence',
second_sequence=None)
pipeline_ins = pipeline(
task=Tasks.text_generation,
model=model,
preprocessor=preprocessor)
print(pipeline_ins(input))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.text_generation, model=self.model_id)
print(pipeline_ins(self.input2))
for model_id, input in ((self.model_id_zh, self.input_zh),
(self.model_id_en, self.input_en)):
pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id)
print(pipeline_ins(input))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.text_generation)
print(pipeline_ins(self.input2))
print(pipeline_ins(self.input_zh))


if __name__ == '__main__':


Loading…
Cancel
Save