Integrate the Palm2.0 model with Chinese and English support, reusing the text-generation pipeline. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9066550master
@@ -7,7 +7,7 @@ from ..builder import MODELS

 __all__ = ['PalmForTextGeneration']

-@MODELS.register_module(Tasks.text_generation, module_name=r'palm')
+@MODELS.register_module(Tasks.text_generation, module_name=r'palm2.0')
 class PalmForTextGeneration(Model):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -18,35 +18,26 @@ class PalmForTextGeneration(Model):
             model_cls (Optional[Any], optional): model loader, if None, use the
                 default loader to load model weights, by default None.
         """
-        from sofa import PalmTokenizer
         super().__init__(model_dir, *args, **kwargs)
         self.model_dir = model_dir
-        from sofa.models.palm import PalmForConditionalGeneration, TextGenerator
-        tokenizer = kwargs.pop('tokenizer',
-                               PalmTokenizer.from_pretrained(model_dir))
+        from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator
         model = PalmForConditionalGeneration.from_pretrained(model_dir)
-        self.generator = TextGenerator(model, tokenizer)
+        self.tokenizer = model.tokenizer
+        self.generator = Translator(model)

     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model

         Args:
-            input (Dict[str, Any]): the preprocessed data
+            input (Dict[str, Tensor]): the preprocessed data
         Returns:
-            Dict[str, np.ndarray]: results
+            Dict[str, Tensor]: results
         Example:
             {
-                'predictions': array([1]), # lable 0-negative 1-positive
-                'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
-                'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
+                'predictions': Tensor([[1377, 4959, 2785, 6392...]]), # token ids, to be decoded by the tokenizer
             }
         """
-        encoder_inputs = [
-            input['input_ids'], input['token_type_ids'],
-            input['attention_mask']
-        ]
-        return self.generator(encoder_inputs)
+        return self.generator(**input)
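For reference, a minimal stand-in for the new calling convention: forward now unpacks the preprocessor's output dict straight into the generator, so the dict keys must match the generator's keyword arguments. The Translator below is a hypothetical mock of sofa.models.palm_v2.Translator that only sketches the expected input/output shapes, not its real implementation:

from typing import Dict

import torch

class Translator:
    """Hypothetical stand-in for sofa.models.palm_v2.Translator."""

    def __call__(self, input_ids: torch.Tensor,
                 attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Real beam-search generation lives in sofa; this mock just echoes
        # token ids in the {'predictions': ...} shape the pipeline expects.
        return {'predictions': input_ids.unsqueeze(0)}

generator = Translator()
inputs = {
    'input_ids': torch.tensor([[101, 1377, 4959, 102]]),
    'attention_mask': torch.tensor([[1, 1, 1, 1]]),
}
outputs = generator(**inputs)  # same as generator(input_ids=..., attention_mask=...)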
@@ -22,7 +22,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
     Tasks.text_classification:
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
-    Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
+    Tasks.text_generation: ('palm2.0',
+                            'damo/nlp_palm2.0_text-generation_chinese-base'),
     Tasks.image_captioning: ('ofa', None),
     Tasks.image_generation:
     ('person-image-cartoon',
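With the default-model table updated, a bare pipeline call now resolves to the palm2.0 Chinese base model. A usage sketch (import paths assumed from the ModelScope package layout; mirrors test_run_with_default_model in the tests below):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# No model argument: falls back to DEFAULT_MODEL_FOR_PIPELINE, i.e.
# ('palm2.0', 'damo/nlp_palm2.0_text-generation_chinese-base').
pipe = pipeline(task=Tasks.text_generation)
print(pipe('本文总结了十个可穿戴产品的设计原则'))  # -> {'text': '...'}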
@@ -10,7 +10,7 @@ from ..builder import PIPELINES

 __all__ = ['TextGenerationPipeline']

-@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm')
+@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm2.0')
 class TextGenerationPipeline(Pipeline):

     def __init__(self,
@@ -23,15 +23,16 @@ class TextGenerationPipeline(Pipeline):
             model (SequenceClassificationModel): a model instance
             preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
         """
-        sc_model = model if isinstance(
+        model = model if isinstance(
             model, PalmForTextGeneration) else Model.from_pretrained(model)
         if preprocessor is None:
             preprocessor = TextGenerationPreprocessor(
-                sc_model.model_dir,
+                model.model_dir,
+                model.tokenizer,
                 first_sequence='sentence',
                 second_sequence=None)
-        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
-        self.tokenizer = preprocessor.tokenizer
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.tokenizer = model.tokenizer

     def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
         """process the prediction results
@@ -42,17 +43,20 @@ class TextGenerationPipeline(Pipeline):
         Returns:
             Dict[str, str]: the prediction results
         """
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
+                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
+                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+        replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'),
+                                  ('<pad>', ''), ('<s>', ''), ('</s>', ''),
+                                  ('<unk>', ' '))
+
-        vocab_size = len(self.tokenizer.vocab)
         pred_list = inputs['predictions']
         pred_ids = pred_list[0][0].cpu().numpy().tolist()
-        for j in range(len(pred_ids)):
-            if pred_ids[j] >= vocab_size:
-                pred_ids[j] = 100
-        pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
-        pred_string = ''.join(pred).replace(
-            '##',
-            '').split('[SEP]')[0].replace('[CLS]',
-                                          '').replace('[SEP]',
-                                                      '').replace('[UNK]', '')
+        pred_string = self.tokenizer.decode(pred_ids)
+        for _old, _new in replace_tokens_bert:
+            pred_string = pred_string.replace(_old, _new)
+        pred_string = pred_string.strip()
+        for _old, _new in replace_tokens_roberta:
+            pred_string = pred_string.replace(_old, _new)
+        pred_string = pred_string.strip()
         return {'text': pred_string}
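A self-contained sketch of the cleanup step above, with two caveats: str.strip() returns a new string and must be reassigned (as done in the hunk), and the r' +' table entry is effectively a no-op under str.replace, which matches the literal two characters ' +' rather than a regex; collapsing runs of spaces needs re.sub:

import re

# BERT-style special tokens to drop after decoding (subset of the table above).
REPLACE_TOKENS_BERT = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]', ''),
                       ('[SEP]', ''), ('[unused2]', ''), ('[CLS]', ''),
                       ('[UNK]', ''))

def clean(pred_string: str) -> str:
    for old, new in REPLACE_TOKENS_BERT:
        pred_string = pred_string.replace(old, new)
    # Collapse whitespace runs with a real regex instead of str.replace(r' +').
    pred_string = re.sub(r' +', ' ', pred_string)
    return pred_string.strip()

print(clean('[CLS] 今日 天气 晴 [SEP] [PAD] [PAD]'))  # -> '今日 天气 晴'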
@@ -115,17 +115,15 @@ class SequenceClassificationPreprocessor(Preprocessor):
         return rst


-@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm')
+@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm2.0')
 class TextGenerationPreprocessor(Preprocessor):

-    def __init__(self, model_dir: str, *args, **kwargs):
+    def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
         """preprocess the data using the vocab.txt from the `model_dir` path

         Args:
             model_dir (str): model path
         """
-        from sofa import PalmTokenizer
         super().__init__(*args, **kwargs)
         self.model_dir: str = model_dir

@@ -134,7 +132,7 @@ class TextGenerationPreprocessor(Preprocessor):
         self.second_sequence: str = kwargs.pop('second_sequence',
                                                'second_sequence')
         self.sequence_length: int = kwargs.pop('sequence_length', 128)
-        self.tokenizer = PalmTokenizer.from_pretrained(model_dir)
+        self.tokenizer = tokenizer

     @type_assert(object, str)
     def __call__(self, data: str) -> Dict[str, Any]:

@@ -153,7 +151,7 @@ class TextGenerationPreprocessor(Preprocessor):
         new_data = {self.first_sequence: data}

         # preprocess the data for the model input
-        rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
+        rst = {'input_ids': [], 'attention_mask': []}

         max_seq_length = self.sequence_length

@@ -168,7 +166,6 @@ class TextGenerationPreprocessor(Preprocessor):
             rst['input_ids'].append(feature['input_ids'])
             rst['attention_mask'].append(feature['attention_mask'])
-            rst['token_type_ids'].append(feature['token_type_ids'])

         return {k: torch.tensor(v) for k, v in rst.items()}
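After this change the preprocessor emits only the two tensors the palm_v2 generator consumes; token_type_ids is gone, and the tokenizer is whichever instance the caller passes in (the tests below pass model.tokenizer). A toy illustration of the final output assembly, with made-up token ids:

import torch

# Mirrors the rst dict built above: per-example feature lists, stacked
# into batch tensors at the end.
rst = {'input_ids': [[101, 1377, 4959, 102]],
       'attention_mask': [[1, 1, 1, 1]]}
batch = {k: torch.tensor(v) for k, v in rst.items()}
print(batch['input_ids'].shape)  # torch.Size([1, 4])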
@@ -1 +1 @@
-https://alinlp.alibaba-inc.com/pypi/sofa-1.0.1.3-py3-none-any.whl
+https://alinlp.alibaba-inc.com/pypi/sofa-1.0.2-py3-none-any.whl
@@ -1,7 +1,7 @@
 addict
 datasets
 easydict
-https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.2.dev0-py3-none-any.whl
+https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.4.dev0-py3-none-any.whl
 numpy
 opencv-python-headless
 Pillow>=6.2.0
@@ -12,43 +12,67 @@ from modelscope.utils.test_utils import test_level

 class TextGenerationTest(unittest.TestCase):
-    model_id = 'damo/nlp_palm_text-generation_chinese'
-    input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'"
-    input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'"
+    model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
+    model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
+    input_zh = """
+    本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
+    1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
+    """
+    input_en = """
+    The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
+    her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
+    54 , sparked outrage last week when she decided the 86-year-old should not face a string of charges
+    of paedophilia against nine children because he has dementia . Today , newly-released documents
+    revealed damning evidence that abuse was covered up by police and social workers for more than 20 years .
+    And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
+    pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
+    """

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):
-        cache_path = snapshot_download(self.model_id)
-        preprocessor = TextGenerationPreprocessor(
-            cache_path, first_sequence='sentence', second_sequence=None)
-        model = PalmForTextGeneration(
-            cache_path, tokenizer=preprocessor.tokenizer)
-        pipeline1 = TextGenerationPipeline(model, preprocessor)
-        pipeline2 = pipeline(
-            Tasks.text_generation, model=model, preprocessor=preprocessor)
-        print(f'input: {self.input1}\npipeline1: {pipeline1(self.input1)}')
-        print()
-        print(f'input: {self.input2}\npipeline2: {pipeline2(self.input2)}')
+        for model_id, input in ((self.model_id_zh, self.input_zh),
+                                (self.model_id_en, self.input_en)):
+            cache_path = snapshot_download(model_id)
+            model = PalmForTextGeneration(cache_path)
+            preprocessor = TextGenerationPreprocessor(
+                cache_path,
+                model.tokenizer,
+                first_sequence='sentence',
+                second_sequence=None)
+            pipeline1 = TextGenerationPipeline(model, preprocessor)
+            pipeline2 = pipeline(
+                Tasks.text_generation, model=model, preprocessor=preprocessor)
+            print(
+                f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
+            )

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained(self.model_id)
-        preprocessor = TextGenerationPreprocessor(
-            model.model_dir, first_sequence='sentence', second_sequence=None)
-        pipeline_ins = pipeline(
-            task=Tasks.text_generation, model=model, preprocessor=preprocessor)
-        print(pipeline_ins(self.input1))
+        for model_id, input in ((self.model_id_zh, self.input_zh),
+                                (self.model_id_en, self.input_en)):
+            model = Model.from_pretrained(model_id)
+            preprocessor = TextGenerationPreprocessor(
+                model.model_dir,
+                model.tokenizer,
+                first_sequence='sentence',
+                second_sequence=None)
+            pipeline_ins = pipeline(
+                task=Tasks.text_generation,
+                model=model,
+                preprocessor=preprocessor)
+            print(pipeline_ins(input))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_name(self):
-        pipeline_ins = pipeline(
-            task=Tasks.text_generation, model=self.model_id)
-        print(pipeline_ins(self.input2))
+        for model_id, input in ((self.model_id_zh, self.input_zh),
+                                (self.model_id_en, self.input_en)):
+            pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id)
+            print(pipeline_ins(input))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_default_model(self):
         pipeline_ins = pipeline(task=Tasks.text_generation)
-        print(pipeline_ins(self.input2))
+        print(pipeline_ins(self.input_zh))


 if __name__ == '__main__':