Integrate the Palm Chinese model into MaaS and add a text generation pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8934393
* add text_generation model and pipeline
* fix bug
* fix bug
* add TextGenerator in pipeline
* fix bug
* update checkpoint and test inputs
* remove magic number..
* fix bug
* adjust code with AutoModel
* clear comments and tidy up the code
* move model.eval() into generator
* update master interface and lint code
* replace 'palm-text-generation' with 'palm'
* add text_generation model and pipeline
* fix bug
* fix bug
* add TextGenerator in pipeline
* fix bug
* fix conflict of pipeline.txt
* remove magic number..
* fix bug
* adjust code with AutoModel
* clear comments and tidy up the code
* move model.eval() into generator
* fix conflict
* replace 'palm-text-generation' with 'palm'
* fix conflict
* add test_run_modelhub
* update sofa version
* modify sofa version
* add test_run_with_model_name
* fix bug
Target branch: master
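
For quick review context, a minimal usage sketch assembled from the tests added at the bottom of this change; the model id and the `Tasks.text_generation` registration come from this diff, everything else is illustrative:

# Sketch only: mirrors test_run_with_model_name in the new test file below.
from maas_lib.pipelines import pipeline
from maas_lib.utils.constant import Tasks

text_generator = pipeline(
    task=Tasks.text_generation, model='damo/nlp_palm_text-generation_chinese')
result = text_generator("今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'")
print(result)  # {'pred_string': '...'} per TextGenerationPipeline.postprocess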
@@ -1 +1,2 @@
 from .sequence_classification_model import * # noqa F403
+from .text_generation_model import * # noqa F403
@@ -0,0 +1,52 @@
+from typing import Any, Dict
+
+from maas_lib.utils.constant import Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+
+__all__ = ['PalmForTextGenerationModel']
+
+
+@MODELS.register_module(Tasks.text_generation, module_name=r'palm')
+class PalmForTextGenerationModel(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the text generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+            tokenizer (Optional[PalmTokenizer], optional): a tokenizer
+                instance; if None, a PalmTokenizer is loaded from
+                `model_dir`. By default None.
+        """
+        from sofa import PalmTokenizer
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+        from sofa.models.palm import PalmForConditionalGeneration, TextGenerator
+        tokenizer = kwargs.pop('tokenizer',
+                               PalmTokenizer.from_pretrained(model_dir))
+        model = PalmForConditionalGeneration.from_pretrained(model_dir)
+        self.generator = TextGenerator(model, tokenizer)
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """Return the generation result produced by the model.
+
+        Args:
+            input (Dict[str, Tensor]): the preprocessed data
+
+        Returns:
+            Dict[str, Tensor]: results, e.g.
+            {
+                'predictions': ...  # generated token id tensors, decoded
+                                    # in TextGenerationPipeline.postprocess
+            }
+        """
+        encoder_inputs = [
+            input['input_ids'], input['token_type_ids'],
+            input['attention_mask']
+        ]
+        return self.generator(encoder_inputs)
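
A sketch of driving this model class directly, mirroring test_run at the bottom of this review; the local checkpoint path is hypothetical:

# Sketch: construct the model with a shared tokenizer and run one forward pass.
# `model_dir` is a hypothetical local checkpoint directory.
from maas_lib.models.nlp import PalmForTextGenerationModel
from maas_lib.preprocessors import TextGenerationPreprocessor

model_dir = '/path/to/palm_checkpoint'
preprocessor = TextGenerationPreprocessor(
    model_dir, first_sequence='sentence', second_sequence=None)
model = PalmForTextGenerationModel(model_dir, tokenizer=preprocessor.tokenizer)
outputs = model.forward(preprocessor("今日天气类型='晴'&最高气温='31℃'"))
# outputs['predictions'] holds generated token ids (see postprocess below)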
@@ -1 +1,2 @@
 from .sequence_classification_pipeline import * # noqa F403
+from .text_generation_pipeline import * # noqa F403
@@ -0,0 +1,59 @@
+from typing import Dict, Optional, Union
+
+from maas_lib.models import Model
+from maas_lib.models.nlp import PalmForTextGenerationModel
+from maas_lib.preprocessors import TextGenerationPreprocessor
+from maas_lib.utils.constant import Tasks
+from ..base import Pipeline, Tensor
+from ..builder import PIPELINES
+
+__all__ = ['TextGenerationPipeline']
+
+
+@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm')
+class TextGenerationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[PalmForTextGenerationModel, str],
+                 preprocessor: Optional[TextGenerationPreprocessor] = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an NLP text generation
+        pipeline for prediction.
+
+        Args:
+            model (PalmForTextGenerationModel or str): a model instance, or
+                a model id/path to load one from.
+            preprocessor (TextGenerationPreprocessor): a preprocessor
+                instance; if None, one is built from the model directory.
+        """
+        tg_model = model if isinstance(
+            model,
+            PalmForTextGenerationModel) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = TextGenerationPreprocessor(
+                tg_model.model_dir,
+                first_sequence='sentence',
+                second_sequence=None)
+        super().__init__(model=tg_model, preprocessor=preprocessor, **kwargs)
+        self.tokenizer = preprocessor.tokenizer
+
+    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
+        """Decode the generated token ids into a string.
+
+        Args:
+            inputs (Dict[str, Tensor]): the model outputs; `predictions`
+                holds the generated token id tensors.
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        vocab_size = len(self.tokenizer.vocab)
+        pred_list = inputs['predictions']
+        # take the first beam of the first sample
+        pred_ids = pred_list[0][0].cpu().numpy().tolist()
+        for j in range(len(pred_ids)):
+            if pred_ids[j] >= vocab_size:
+                # map out-of-vocab ids to 100, the [UNK] id in
+                # BERT-style vocabularies
+                pred_ids[j] = 100
+        pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
+        # strip wordpiece markers and special tokens, keep text before [SEP]
+        pred_string = ''.join(pred).replace('##', '').split('[SEP]')[0] \
+            .replace('[CLS]', '').replace('[SEP]', '').replace('[UNK]', '')
+        return {'pred_string': pred_string}
@@ -5,3 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import * # noqa F403
+from .nlp import TextGenerationPreprocessor
@@ -89,3 +89,61 @@ class SequenceClassificationPreprocessor(Preprocessor):
         rst['token_type_ids'].append(feature['token_type_ids'])

         return rst
+
+
+@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm')
+class TextGenerationPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data using the vocab.txt from the `model_dir` path.
+
+        Args:
+            model_dir (str): model path
+        """
+        from sofa import PalmTokenizer
+        super().__init__(*args, **kwargs)
+        self.model_dir: str = model_dir
+        self.first_sequence: str = kwargs.pop('first_sequence',
+                                              'first_sequence')
+        self.second_sequence: str = kwargs.pop('second_sequence',
+                                               'second_sequence')
+        self.sequence_length: int = kwargs.pop('sequence_length', 128)
+        self.tokenizer = PalmTokenizer.from_pretrained(model_dir)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        import torch
+
+        new_data = {self.first_sequence: data}
+        # preprocess the data for the model input
+        rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
+        max_seq_length = self.sequence_length
+        text_a = new_data.get(self.first_sequence, None)
+        text_b = new_data.get(self.second_sequence, None)
+        feature = self.tokenizer(
+            text_a,
+            text_b,
+            padding='max_length',
+            truncation=True,
+            max_length=max_seq_length)
+
+        rst['input_ids'].append(feature['input_ids'])
+        rst['attention_mask'].append(feature['attention_mask'])
+        rst['token_type_ids'].append(feature['token_type_ids'])
+        # wrap the single-example lists into tensors of shape [1, seq_len]
+        return {k: torch.tensor(v) for k, v in rst.items()}
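
For reference, what this preprocessor hands to the model, assuming the default sequence_length of 128; the checkpoint path is hypothetical and the shapes are inferred from the code above:

# Sketch: the preprocessor returns batched tensors of shape [1, 128]
# (sequence_length defaults to 128 and padding='max_length').
preprocessor = TextGenerationPreprocessor(
    '/path/to/palm_checkpoint',  # hypothetical checkpoint directory
    first_sequence='sentence', second_sequence=None)
features = preprocessor('你好')
print({k: v.shape for k, v in features.items()})
# {'input_ids': torch.Size([1, 128]), 'attention_mask': torch.Size([1, 128]),
#  'token_type_ids': torch.Size([1, 128])}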
@@ -1,2 +1,3 @@
 -r requirements/runtime.txt
 -r requirements/pipeline.txt
+-r requirements/nlp.txt
@@ -0,0 +1 @@
+https://alinlp.alibaba-inc.com/pypi/sofa-1.0.1.3-py3-none-any.whl
@@ -0,0 +1,46 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from maas_hub.snapshot_download import snapshot_download
+
+from maas_lib.models import Model
+from maas_lib.models.nlp import PalmForTextGenerationModel
+from maas_lib.pipelines import TextGenerationPipeline, pipeline
+from maas_lib.preprocessors import TextGenerationPreprocessor
+from maas_lib.utils.constant import Tasks
+
+
+class TextGenerationTest(unittest.TestCase):
+    model_id = 'damo/nlp_palm_text-generation_chinese'
+    input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'"
+    input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'"
+
+    def test_run(self):
+        cache_path = snapshot_download(self.model_id)
+        preprocessor = TextGenerationPreprocessor(
+            cache_path, first_sequence='sentence', second_sequence=None)
+        model = PalmForTextGenerationModel(
+            cache_path, tokenizer=preprocessor.tokenizer)
+        pipeline1 = TextGenerationPipeline(model, preprocessor)
+        pipeline2 = pipeline(
+            Tasks.text_generation, model=model, preprocessor=preprocessor)
+        print(f'input: {self.input1}\npipeline1: {pipeline1(self.input1)}')
+        print()
+        print(f'input: {self.input2}\npipeline2: {pipeline2(self.input2)}')
+
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        preprocessor = TextGenerationPreprocessor(
+            model.model_dir, first_sequence='sentence', second_sequence=None)
+        pipeline_ins = pipeline(
+            task=Tasks.text_generation, model=model, preprocessor=preprocessor)
+        print(pipeline_ins(self.input1))
+
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.text_generation, model=self.model_id)
+        print(pipeline_ins(self.input2))
+
+
+if __name__ == '__main__':
+    unittest.main()