1. Add the gpt_neo model. Since the checkpoint belongs to Langboat and has not yet been uploaded to the model hub, testing was completed offline.
2. Add text-generation task models and a head. Text generation models already online, such as gpt3 and palm, will later be unified into task models with a backbone + head structure.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10404249
master
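For reviewers new to the backbone + head layout, the sketch below shows the kind of configuration a unified task model would consume. This is a minimal illustration, not a real configuration.json from any model: the key names mirror what TaskModelForTextGeneration reads in this PR (backbone and head configs plus shared_embedding), and all values are made up.

    # Illustrative configuration fragment (all values invented) for a
    # backbone + head task model as introduced in this PR.
    model_config = {
        'type': 'text-generation',      # TaskModels.text_generation
        'backbone': {
            'type': 'gpt-neo',          # Models.gpt_neo -> GPTNeoModel backbone
            'vocab_size': 50257,
            'hidden_size': 768,
        },
        'head': {
            'type': 'text-generation',  # Heads.text_generation -> TextGenerationHead
            'hidden_size': 768,
            'vocab_size': 50257,
        },
        'shared_embedding': True,       # tie input/output embedding weights
    }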
@@ -71,6 +71,7 @@ class Models(object):
     gcnncrf = 'gcnn-crf'
     bart = 'bart'
     gpt3 = 'gpt3'
+    gpt_neo = 'gpt-neo'
     plug = 'plug'
     bert_for_ds = 'bert-for-document-segmentation'
     ponet = 'ponet'
@@ -101,6 +102,7 @@ class TaskModels(object):
     information_extraction = 'information-extraction'
     fill_mask = 'fill-mask'
     feature_extraction = 'feature-extraction'
+    text_generation = 'text-generation'


 class Heads(object):
@@ -116,6 +118,8 @@ class Heads(object):
     token_classification = 'token-classification'
     # extraction
     information_extraction = 'information-extraction'
+    # text gen
+    text_generation = 'text-generation'


 class Pipelines(object):
@@ -341,6 +345,7 @@ class Preprocessors(object):
     re_tokenizer = 're-tokenizer'
     document_segmentation = 'document-segmentation'
     feature_extraction = 'feature-extraction'
+    sentence_piece = 'sentence-piece'

     # audio preprocessor
     linear_aec_fbank = 'linear-aec-fbank'
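These constants are registry keys: a config's `type` field selects the registered class at build time. A minimal sketch of the lookup side, using the build_preprocessor call this PR adds to the pipeline; the model_dir path is illustrative and a plain dict stands in for the usual Config object:

    from modelscope.preprocessors import build_preprocessor
    from modelscope.utils.constant import Fields

    # 'sentence-piece' is the Preprocessors.sentence_piece constant added above.
    cfg = {'type': 'sentence-piece', 'model_dir': '/path/to/model_dir'}
    preprocessor = build_preprocessor(cfg, Fields.nlp)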
@@ -30,7 +30,8 @@ if TYPE_CHECKING:
                               InformationExtractionModel,
                               SequenceClassificationModel,
                               SingleBackboneTaskModelBase,
-                              TokenClassificationModel)
+                              TokenClassificationModel,
+                              TaskModelForTextGeneration)
     from .token_classification import SbertForTokenClassification
     from .sentence_embedding import SentenceEmbedding
     from .passage_ranking import PassageRanking
@@ -69,6 +70,7 @@ else:
             'SequenceClassificationModel',
             'SingleBackboneTaskModelBase',
             'TokenClassificationModel',
+            'TaskModelForTextGeneration',
         ],
         'token_classification': ['SbertForTokenClassification'],
         'table_question_answering': ['TableQuestionAnswering'],
@@ -0,0 +1,15 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from transformers import GPTNeoConfig
+from transformers import GPTNeoModel as GPTNeoModelTransform
+
+from modelscope.metainfo import Models
+from modelscope.models.builder import BACKBONES
+from modelscope.utils.constant import Fields
+
+
+@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.gpt_neo)
+class GPTNeoModel(GPTNeoModelTransform):
+
+    def __init__(self, **kwargs):
+        config = GPTNeoConfig(**kwargs)
+        super().__init__(config)
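A quick instantiation sketch for the new backbone. The kwargs are forwarded to transformers.GPTNeoConfig, so any of its fields may appear in the backbone config; the values below are illustrative and match no real checkpoint, and the import path is assumed since the diff does not show where the new file lives:

    # Import path assumed for illustration only.
    from modelscope.models.nlp.backbones import GPTNeoModel

    # Randomly initialized; real usage loads a checkpoint via the task model.
    backbone = GPTNeoModel(
        vocab_size=50257,
        hidden_size=768,
        num_layers=12,
        num_heads=12,
        attention_types=[[['global', 'local'], 6]],  # alternating pattern over 12 layers
    )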
@@ -0,0 +1,35 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Dict
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from modelscope.metainfo import Heads
+from modelscope.models.base import TorchHead
+from modelscope.models.builder import HEADS
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+
+
+@HEADS.register_module(
+    Tasks.text_generation, module_name=Heads.text_generation)
+class TextGenerationHead(TorchHead):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        config = self.config
+        self.linear = nn.Linear(
+            config['hidden_size'], config['vocab_size'], bias=False)
+
+    def get_output_embeddings(self):
+        return self.linear
+
+    def forward(self, inputs=None):
+        logits = self.linear(inputs)
+        return {OutputKeys.LOGITS: logits}
+
+    def compute_loss(self, outputs: Dict[str, torch.Tensor],
+                     labels) -> Dict[str, torch.Tensor]:
+        logits = outputs[OutputKeys.LOGITS]
+        return {OutputKeys.LOSS: F.cross_entropy(logits, labels)}
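The head is a single bias-free projection from hidden states to vocabulary logits. Below is a self-contained shape check in plain PyTorch, independent of the PR's classes and with illustrative sizes. Note that F.cross_entropy expects the class dimension second, so compute_loss implicitly assumes the caller aligns logits and labels accordingly (e.g. flattened to (N, vocab) and (N,)):

    import torch
    from torch import nn

    hidden_size, vocab_size = 768, 50257
    linear = nn.Linear(hidden_size, vocab_size, bias=False)  # same as TextGenerationHead.linear
    hidden_states = torch.randn(2, 8, hidden_size)           # (batch, seq_len, hidden_size)
    logits = linear(hidden_states)                           # (batch, seq_len, vocab_size)
    assert logits.shape == (2, 8, vocab_size)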
@@ -10,6 +10,7 @@ if TYPE_CHECKING:
     from .sequence_classification import SequenceClassificationModel
     from .task_model import SingleBackboneTaskModelBase
     from .token_classification import TokenClassificationModel
+    from .text_generation import TaskModelForTextGeneration

 else:
     _import_structure = {
@@ -19,6 +20,7 @@ else:
         'sequence_classification': ['SequenceClassificationModel'],
         'task_model': ['SingleBackboneTaskModelBase'],
         'token_classification': ['TokenClassificationModel'],
+        'text_generation': ['TaskModelForTextGeneration'],
     }

     import sys
@@ -0,0 +1,79 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict
+
+import addict
+import numpy as np
+from transformers.modeling_utils import PreTrainedModel
+
+from modelscope.metainfo import TaskModels
+from modelscope.models.builder import MODELS
+from modelscope.models.nlp.task_models.task_model import \
+    SingleBackboneTaskModelBase
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+
+__all__ = ['TaskModelForTextGeneration']
+
+
+@MODELS.register_module(
+    Tasks.text_generation, module_name=TaskModels.text_generation)
+class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the text generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        if 'base_model_prefix' in kwargs:
+            self._base_model_prefix = kwargs['base_model_prefix']
+
+        self.build_backbone(self.backbone_cfg)
+        self.build_head(self.head_cfg)
+
+        if self.config.get('shared_embedding', False):
+            input_embeddings = self.backbone.get_input_embeddings()
+            output_embeddings = self.head.get_output_embeddings()
+            output_embeddings.weight = input_embeddings.weight
+
+    def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        # the backbone does not need labels; only the head needs them
+        # to compute the loss
+        labels = input.pop(OutputKeys.LABELS, None)
+        backbone_outputs = super().forward(input)
+        hidden_states = backbone_outputs[0]
+
+        outputs = self.head.forward(hidden_states)
+        if labels is not None:
+            input[OutputKeys.LABELS] = labels
+            loss = self.compute_loss(outputs, labels)
+            outputs.update(loss)
+        return addict.Dict(outputs)
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        token_type_ids = kwargs.get('token_type_ids', None)
+        # only keep the last token of input_ids if past is defined in kwargs
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get('attention_mask', None)
+        position_ids = kwargs.get('position_ids', None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+
+        return {
+            'input_ids': input_ids,
+            'past_key_values': past,
+            'use_cache': kwargs.get('use_cache'),
+            'position_ids': position_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids,
+        }
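The on-the-fly position_ids logic above deserves a worked example. This standalone check runs the same two lines on a made-up left-padded attention mask:

    import torch

    attention_mask = torch.tensor([[0, 1, 1, 1]])          # first slot is padding
    position_ids = attention_mask.long().cumsum(-1) - 1    # tensor([[-1, 0, 1, 2]])
    position_ids.masked_fill_(attention_mask == 0, 1)      # tensor([[ 1, 0, 1, 2]])
    # Real tokens get positions 0..2; the padded slot gets a dummy value of 1,
    # which is harmless because attention masks it out anyway.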
@@ -6,10 +6,12 @@ import torch

 from modelscope.metainfo import Pipelines
 from modelscope.models.base import Model
+from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline, Tensor
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import TextGenerationPreprocessor
-from modelscope.utils.constant import Tasks
+from modelscope.preprocessors import Preprocessor, build_preprocessor
+from modelscope.utils.constant import Fields, Tasks
+from modelscope.utils.hub import read_config

 __all__ = ['TextGenerationPipeline']
@@ -20,7 +22,7 @@ class TextGenerationPipeline(Pipeline):
     def __init__(self,
                  model: Union[Model, str],
-                 preprocessor: Optional[TextGenerationPreprocessor] = None,
+                 preprocessor: Optional[Preprocessor] = None,
                  first_sequence='sentence',
                  **kwargs):
         """Use `model` and `preprocessor` to create a generation pipeline for prediction.
@@ -50,19 +52,34 @@ class TextGenerationPipeline(Pipeline):
         """
         model = model if isinstance(model,
                                     Model) else Model.from_pretrained(model)
+        cfg = read_config(model.model_dir)
+        self.postprocessor = cfg.pop('postprocessor', None)
         if preprocessor is None:
-            preprocessor = TextGenerationPreprocessor(
+            preprocessor_cfg = cfg.preprocessor
+            preprocessor_cfg.update({
+                'model_dir':
                 model.model_dir,
-                first_sequence=first_sequence,
-                second_sequence=None,
-                sequence_length=kwargs.pop('sequence_length', 128))
+                'first_sequence':
+                first_sequence,
+                'second_sequence':
+                None,
+                'sequence_length':
+                kwargs.pop('sequence_length', 128)
+            })
+            preprocessor = build_preprocessor(preprocessor_cfg, Fields.nlp)
         model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)

+    def _sanitize_parameters(self, **pipeline_parameters):
+        return {}, pipeline_parameters, {}
+
     def forward(self, inputs: Dict[str, Any],
                 **forward_params) -> Dict[str, Any]:
         with torch.no_grad():
-            return self.model.generate(inputs)
+            return self.model.generate(inputs, **forward_params)
+
+    def sentence_piece(self, inputs) -> Dict[str, Tensor]:
+        return self.preprocessor.tokenizer.decode(inputs.tolist())[0]

     def postprocess(self, inputs: Dict[str, Tensor],
                     **postprocess_params) -> Dict[str, str]:
@@ -74,4 +91,7 @@ class TextGenerationPipeline(Pipeline):
         Returns:
             Dict[str, str]: the prediction results
         """
-        return inputs
+        return inputs if self.postprocessor is None else {
+            OutputKeys.TEXT:
+            getattr(self, self.postprocessor.replace('-', '_'))(inputs)
+        }
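With _sanitize_parameters returning every call-time kwarg as a forward param, generation arguments now flow straight through the pipeline call into model.generate, and the configured postprocessor name ('sentence-piece' here) is dispatched to the matching pipeline method via getattr after replacing '-' with '_'. A usage sketch based on the test added below; the checkpoint is not yet on the model hub, so this cannot actually run today:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(task=Tasks.text_generation, model='Langboat/mengzi-gpt-neo-base')
    # these kwargs are forwarded verbatim to model.generate(...)
    print(pipe('我是', do_sample=True, top_k=5, max_length=20))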
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
         Tokenize,
         WordSegmentationBlankSetToLabelPreprocessor,
         ZeroShotClassificationPreprocessor,
+        SentencePiecePreprocessor,
     )
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
@@ -71,6 +72,7 @@ else:
             'Text2TextGenerationPreprocessor',
             'WordSegmentationBlankSetToLabelPreprocessor',
             'ZeroShotClassificationPreprocessor',
+            'SentencePiecePreprocessor',
         ],
         'space': [
             'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
         Tokenize,
         WordSegmentationBlankSetToLabelPreprocessor,
         ZeroShotClassificationPreprocessor,
+        SentencePiecePreprocessor,
     )

 else:
@@ -41,6 +42,7 @@ else:
             'Text2TextGenerationPreprocessor',
             'WordSegmentationBlankSetToLabelPreprocessor',
             'ZeroShotClassificationPreprocessor',
+            'SentencePiecePreprocessor',
         ],
         'text_error_correction': [
             'TextErrorCorrectionPreprocessor',
@@ -5,6 +5,7 @@ import re
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import numpy as np
+import sentencepiece as spm
 import torch
 from transformers import AutoTokenizer
@@ -1160,3 +1161,23 @@ class FillMaskPoNetPreprocessor(NLPTokenizerPreprocessorBase):
             self.labels_to_id(labels, output)
         return output
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.sentence_piece)
+class SentencePiecePreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        import os
+        super().__init__(*args, **kwargs)
+        self.tokenizer = None
+        for file_name in os.listdir(model_dir):
+            if file_name.endswith('.model'):
+                m_file = osp.join(model_dir, file_name)
+                self.tokenizer = spm.SentencePieceProcessor(model_file=m_file)
+                break
+        assert self.tokenizer is not None, 'Can not find .model file'
+
+    def __call__(self, data: str) -> Dict[str, Any]:
+        return torch.tensor(self.tokenizer.encode([data]), dtype=torch.long)
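A short usage sketch for the new preprocessor. It scans model_dir for the first file ending in '.model' and loads it as a SentencePiece model; the directory and file name below are illustrative:

    # Assumes /path/to/model_dir contains a trained SentencePiece file,
    # e.g. 'spiece.model' (name illustrative).
    sp = SentencePiecePreprocessor('/path/to/model_dir')
    ids = sp('我是')   # LongTensor of shape (1, seq_len), ready for model.generate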
@@ -133,6 +133,19 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
     def test_demo_compatibility(self):
         self.compatibility_check()

+    @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub")
+    def test_gpt_neo(self):
+        pipe = pipeline(
+            task=Tasks.text_generation, model='Langboat/mengzi-gpt-neo-base')
+        print(
+            pipe(
+                '我是',
+                do_sample=True,
+                top_k=5,
+                top_p=1,
+                max_length=20,
+                repetition_penalty=0.5))
+

 if __name__ == '__main__':
     unittest.main()
@@ -41,7 +41,7 @@ class AstScaningTest(unittest.TestCase):
         self.assertIsInstance(from_imports, dict)
         self.assertIsInstance(decorators, list)
         self.assertListEqual(list(set(imports.keys()) - set(['torch'])), [])
-        self.assertEqual(len(from_imports.keys()), 7)
+        self.assertEqual(len(from_imports.keys()), 9)
         self.assertTrue(from_imports['modelscope.metainfo'] is not None)
         self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines'])
         self.assertEqual(decorators,