diff --git a/maas_lib/models/nlp/__init__.py b/maas_lib/models/nlp/__init__.py
index d85c0ba7..b2a1d43b 100644
--- a/maas_lib/models/nlp/__init__.py
+++ b/maas_lib/models/nlp/__init__.py
@@ -1 +1,2 @@
 from .sequence_classification_model import * # noqa F403
+from .text_generation_model import * # noqa F403
diff --git a/maas_lib/models/nlp/text_generation_model.py b/maas_lib/models/nlp/text_generation_model.py
new file mode 100644
index 00000000..04345d22
--- /dev/null
+++ b/maas_lib/models/nlp/text_generation_model.py
@@ -0,0 +1,52 @@
+from typing import Dict
+
+from maas_lib.utils.constant import Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+
+__all__ = ['PalmForTextGenerationModel']
+
+
+@MODELS.register_module(Tasks.text_generation, module_name=r'palm')
+class PalmForTextGenerationModel(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the text generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+            tokenizer (PalmTokenizer, optional): passed via kwargs; if not
+                given, a tokenizer is loaded from `model_dir`.
+        """
+        from sofa import PalmTokenizer
+
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+
+        from sofa.models.palm import PalmForConditionalGeneration, TextGenerator
+        tokenizer = kwargs.pop('tokenizer',
+                               PalmTokenizer.from_pretrained(model_dir))
+        model = PalmForConditionalGeneration.from_pretrained(model_dir)
+        self.generator = TextGenerator(model, tokenizer)
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """run generation on the preprocessed input and return the result
+
+        Args:
+            input (Dict[str, Tensor]): the preprocessed data
+
+        Returns:
+            Dict[str, Tensor]: results
+                Example:
+                    {
+                        'predictions': Tensor([[...]])  # generated token ids
+                    }
+        """
+
+        encoder_inputs = [
+            input['input_ids'], input['token_type_ids'],
+            input['attention_mask']
+        ]
+        return self.generator(encoder_inputs)
diff --git a/maas_lib/pipelines/nlp/__init__.py b/maas_lib/pipelines/nlp/__init__.py
index f9d874e7..3dbbc1bb 100644
--- a/maas_lib/pipelines/nlp/__init__.py
+++ b/maas_lib/pipelines/nlp/__init__.py
@@ -1 +1,2 @@
 from .sequence_classification_pipeline import * # noqa F403
+from .text_generation_pipeline import * # noqa F403
diff --git a/maas_lib/pipelines/nlp/text_generation_pipeline.py b/maas_lib/pipelines/nlp/text_generation_pipeline.py
new file mode 100644
index 00000000..865557b5
--- /dev/null
+++ b/maas_lib/pipelines/nlp/text_generation_pipeline.py
@@ -0,0 +1,59 @@
+from typing import Dict, Optional, Union
+
+from maas_lib.models import Model
+from maas_lib.models.nlp import PalmForTextGenerationModel
+from maas_lib.preprocessors import TextGenerationPreprocessor
+from maas_lib.utils.constant import Tasks
+from ..base import Pipeline, Tensor
+from ..builder import PIPELINES
+
+__all__ = ['TextGenerationPipeline']
+
+
+@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm')
+class TextGenerationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[PalmForTextGenerationModel, str],
+                 preprocessor: Optional[TextGenerationPreprocessor] = None,
+                 **kwargs):
+        """use `model` and `preprocessor` to create an nlp text generation pipeline for prediction
+
+        Args:
+            model (PalmForTextGenerationModel or str): a model instance or a model id/path
+            preprocessor (TextGenerationPreprocessor): a preprocessor instance
+        """
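+        # `model` may be an already-built PalmForTextGenerationModel or a
+        # model id/path string resolved through Model.from_pretrained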
+        palm_model = model if isinstance(
+            model,
+            PalmForTextGenerationModel) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = TextGenerationPreprocessor(
+                palm_model.model_dir,
+                first_sequence='sentence',
+                second_sequence=None)
+        super().__init__(model=palm_model, preprocessor=preprocessor, **kwargs)
+        self.tokenizer = preprocessor.tokenizer
+
+    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Tensor]): the model output; the generated token
+                ids are stored under the 'predictions' key
+
+        Returns:
+            Dict[str, str]: the decoded string under the 'pred_string' key
+        """
+
+        vocab_size = len(self.tokenizer.vocab)
+        pred_list = inputs['predictions']
+        # first sample in the batch, best generation candidate
+        pred_ids = pred_list[0][0].cpu().numpy().tolist()
+        # ids outside the vocabulary are mapped to 100, the [UNK] token id
+        for j in range(len(pred_ids)):
+            if pred_ids[j] >= vocab_size:
+                pred_ids[j] = 100
+        pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
+        # join wordpieces, keep only the text before the first [SEP] and
+        # strip the remaining special tokens
+        pred_string = ''.join(pred).replace('##', '').split(
+            '[SEP]')[0].replace('[CLS]', '').replace('[UNK]', '')
+        return {'pred_string': pred_string}
diff --git a/maas_lib/preprocessors/__init__.py b/maas_lib/preprocessors/__init__.py
index 81ca1007..518ea977 100644
--- a/maas_lib/preprocessors/__init__.py
+++ b/maas_lib/preprocessors/__init__.py
@@ -5,3 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import * # noqa F403
+from .nlp import TextGenerationPreprocessor
diff --git a/maas_lib/preprocessors/nlp.py b/maas_lib/preprocessors/nlp.py
index bde401c2..176322d4 100644
--- a/maas_lib/preprocessors/nlp.py
+++ b/maas_lib/preprocessors/nlp.py
@@ -89,3 +89,61 @@ class SequenceClassificationPreprocessor(Preprocessor):
         rst['token_type_ids'].append(feature['token_type_ids'])
 
         return rst
+
+
+@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm')
+class TextGenerationPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data using the vocab.txt in the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        from sofa import PalmTokenizer
+
+        super().__init__(*args, **kwargs)
+
+        self.model_dir: str = model_dir
+        self.first_sequence: str = kwargs.pop('first_sequence',
+                                              'first_sequence')
+        self.second_sequence: str = kwargs.pop('second_sequence',
+                                               'second_sequence')
+        self.sequence_length: int = kwargs.pop('sequence_length', 128)
+        self.tokenizer = PalmTokenizer.from_pretrained(model_dir)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data as tensors, under the
+                'input_ids', 'attention_mask' and 'token_type_ids' keys
+        """
+        import torch
+
+        new_data = {self.first_sequence: data}
+        # preprocess the data for the model input
+
+        rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
+
+        max_seq_length = self.sequence_length
+
+        text_a = new_data.get(self.first_sequence, None)
+        text_b = new_data.get(self.second_sequence, None)
+        feature = self.tokenizer(
+            text_a,
+            text_b,
+            padding='max_length',
+            truncation=True,
+            max_length=max_seq_length)
+
+        rst['input_ids'].append(feature['input_ids'])
+        rst['attention_mask'].append(feature['attention_mask'])
+        rst['token_type_ids'].append(feature['token_type_ids'])
+
+        # convert the batched lists to tensors for the model
+        return {k: torch.tensor(v) for k, v in rst.items()}
diff --git a/requirements.txt b/requirements.txt
index 999c567e..3cc6857e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
 -r requirements/runtime.txt
 -r requirements/pipeline.txt
+-r requirements/nlp.txt
diff --git a/requirements/nlp.txt b/requirements/nlp.txt
new file mode 100644
index 00000000..8de83798
--- /dev/null
+++ b/requirements/nlp.txt
@@ -0,0 +1 @@
+https://alinlp.alibaba-inc.com/pypi/sofa-1.0.1.3-py3-none-any.whl
diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py
new file mode 100644
index 00000000..d59fdabb
--- /dev/null
+++ b/tests/pipelines/test_text_generation.py
@@ -0,0 +1,46 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from maas_hub.snapshot_download import snapshot_download
+
+from maas_lib.models import Model
+from maas_lib.models.nlp import PalmForTextGenerationModel
+from maas_lib.pipelines import TextGenerationPipeline, pipeline
+from maas_lib.preprocessors import TextGenerationPreprocessor
+from maas_lib.utils.constant import Tasks
+
+
+class TextGenerationTest(unittest.TestCase):
+    model_id = 'damo/nlp_palm_text-generation_chinese'
+    input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'"
+    input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'"
+
+    def test_run(self):
+        cache_path = snapshot_download(self.model_id)
+        preprocessor = TextGenerationPreprocessor(
+            cache_path, first_sequence='sentence', second_sequence=None)
+        model = PalmForTextGenerationModel(
+            cache_path, tokenizer=preprocessor.tokenizer)
+        pipeline1 = TextGenerationPipeline(model, preprocessor)
+        pipeline2 = pipeline(
+            Tasks.text_generation, model=model, preprocessor=preprocessor)
+        print(f'input: {self.input1}\npipeline1: {pipeline1(self.input1)}')
+        print()
+        print(f'input: {self.input2}\npipeline2: {pipeline2(self.input2)}')
+
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        preprocessor = TextGenerationPreprocessor(
+            model.model_dir, first_sequence='sentence', second_sequence=None)
+        pipeline_ins = pipeline(
+            task=Tasks.text_generation, model=model, preprocessor=preprocessor)
+        print(pipeline_ins(self.input1))
+
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.text_generation, model=self.model_id)
+        print(pipeline_ins(self.input2))
+
+
+if __name__ == '__main__':
+    unittest.main()
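
Usage sketch (illustrative, not part of the patch): a minimal end-to-end call
of the new pipeline, assuming the 'damo/nlp_palm_text-generation_chinese'
weights used by the tests are reachable from the model hub.

    from maas_lib.pipelines import pipeline
    from maas_lib.utils.constant import Tasks

    # resolve the registered palm pipeline by task and model id; the model,
    # tokenizer and TextGenerationPreprocessor are built from the snapshot
    text_generator = pipeline(
        task=Tasks.text_generation,
        model='damo/nlp_palm_text-generation_chinese')

    # the raw string is tokenized, generation runs, and postprocess decodes
    # the token ids into the 'pred_string' field
    result = text_generator("今日天气类型='多云'&体感='舒适'")
    print(result['pred_string'])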