1. Add the gpt_neo model. Because the checkpoint belongs to Langboat and has not yet been uploaded to the model hub, testing was completed offline.
2. Add text-generation task models and a head. Follow-up work will unify the already-launched text generation models such as gpt3 and palm into task models with the backbone + head structure.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10404249
Target branch: master
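For context on point 2, "backbone + head" means a pretrained trunk that produces hidden states plus a thin task-specific projection on top, so several generation models can share one task-model wrapper. A minimal, library-independent sketch of the pattern (all class and parameter names below are illustrative, not ModelScope APIs):

    import torch
    from torch import nn


    class ToyBackbone(nn.Module):
        """Stand-in for a pretrained transformer; returns hidden states."""

        def __init__(self, vocab_size=100, hidden_size=16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)

        def forward(self, input_ids):
            return self.embed(input_ids)  # [batch, seq, hidden]


    class ToyLMHead(nn.Module):
        """Projects hidden states back to vocabulary logits."""

        def __init__(self, vocab_size=100, hidden_size=16):
            super().__init__()
            self.linear = nn.Linear(hidden_size, vocab_size, bias=False)

        def forward(self, hidden_states):
            return self.linear(hidden_states)  # [batch, seq, vocab]


    backbone, head = ToyBackbone(), ToyLMHead()
    logits = head(backbone(torch.randint(0, 100, (2, 8))))
    print(logits.shape)  # torch.Size([2, 8, 100])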
@@ -71,6 +71,7 @@ class Models(object):
     gcnncrf = 'gcnn-crf'
     bart = 'bart'
     gpt3 = 'gpt3'
+    gpt_neo = 'gpt-neo'
     plug = 'plug'
     bert_for_ds = 'bert-for-document-segmentation'
     ponet = 'ponet'
@@ -101,6 +102,7 @@ class TaskModels(object):
     information_extraction = 'information-extraction'
     fill_mask = 'fill-mask'
     feature_extraction = 'feature-extraction'
+    text_generation = 'text-generation'


 class Heads(object):
@@ -116,6 +118,8 @@ class Heads(object):
     token_classification = 'token-classification'
     # extraction
     information_extraction = 'information-extraction'
+    # text gen
+    text_generation = 'text-generation'


 class Pipelines(object):
@@ -341,6 +345,7 @@ class Preprocessors(object):
     re_tokenizer = 're-tokenizer'
     document_segmentation = 'document-segmentation'
     feature_extraction = 'feature-extraction'
+    sentence_piece = 'sentence-piece'

     # audio preprocessor
     linear_aec_fbank = 'linear-aec-fbank'
@@ -30,7 +30,8 @@ if TYPE_CHECKING:
         InformationExtractionModel,
         SequenceClassificationModel,
         SingleBackboneTaskModelBase,
-        TokenClassificationModel)
+        TokenClassificationModel,
+        TaskModelForTextGeneration)
     from .token_classification import SbertForTokenClassification
     from .sentence_embedding import SentenceEmbedding
     from .passage_ranking import PassageRanking
@@ -69,6 +70,7 @@ else:
             'SequenceClassificationModel',
             'SingleBackboneTaskModelBase',
             'TokenClassificationModel',
+            'TaskModelForTextGeneration',
         ],
         'token_classification': ['SbertForTokenClassification'],
         'table_question_answering': ['TableQuestionAnswering'],
@@ -0,0 +1,15 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from transformers import GPTNeoConfig
+from transformers import GPTNeoModel as GPTNeoModelTransform
+
+from modelscope.metainfo import Models
+from modelscope.models.builder import BACKBONES
+from modelscope.utils.constant import Fields
+
+
+@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.gpt_neo)
+class GPTNeoModel(GPTNeoModelTransform):
+
+    def __init__(self, **kwargs):
+        config = GPTNeoConfig(**kwargs)
+        super().__init__(config)
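The decorator above registers the class under Fields.nlp / Models.gpt_neo, and the constructor forwards its kwargs straight into transformers.GPTNeoConfig. A hedged construction sketch (the import path is assumed from the diff context, and the config values are arbitrary smoke-test sizes):

    # Assumed module path; the diff does not show the file location explicitly.
    from modelscope.models.nlp.backbones.gpt_neo import GPTNeoModel

    # Any transformers.GPTNeoConfig field can be passed through **kwargs.
    backbone = GPTNeoModel(
        vocab_size=50257,
        hidden_size=768,                             # width the head must match
        num_layers=2,                                # tiny depth, smoke test only
        attention_types=[[['global', 'local'], 1]],  # expands to the 2 layers
    )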
@@ -0,0 +1,35 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Dict
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from modelscope.metainfo import Heads
+from modelscope.models.base import TorchHead
+from modelscope.models.builder import HEADS
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+
+
+@HEADS.register_module(
+    Tasks.text_generation, module_name=Heads.text_generation)
+class TextGenerationHead(TorchHead):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        config = self.config
+        self.linear = nn.Linear(
+            config['hidden_size'], config['vocab_size'], bias=False)
+
+    def get_output_embeddings(self):
+        return self.linear
+
+    def forward(self, inputs=None):
+        logits = self.linear(inputs)
+        return {OutputKeys.LOGITS: logits}
+
+    def compute_loss(self, outputs: Dict[str, torch.Tensor],
+                     labels) -> Dict[str, torch.Tensor]:
+        logits = outputs[OutputKeys.LOGITS]
+        return {OutputKeys.LOSS: F.cross_entropy(logits, labels)}
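A quick sanity check of the head's contract: it maps hidden states of width hidden_size to vocabulary logits, and compute_loss applies cross-entropy. Note that F.cross_entropy expects class logits along dimension 1, so for a [batch, seq, vocab] output the caller is assumed to flatten (and, for causal LM, shift) logits and labels first. A minimal sketch of that convention (shapes and sizes are illustrative):

    import torch
    import torch.nn.functional as F

    batch, seq, hidden, vocab = 2, 8, 16, 100
    hidden_states = torch.randn(batch, seq, hidden)
    linear = torch.nn.Linear(hidden, vocab, bias=False)  # mirrors the head's projection

    logits = linear(hidden_states)                       # [batch, seq, vocab]
    labels = torch.randint(0, vocab, (batch, seq))

    # Causal-LM convention: predict token t+1 from position t, then flatten.
    shift_logits = logits[:, :-1, :].reshape(-1, vocab)  # [batch*(seq-1), vocab]
    shift_labels = labels[:, 1:].reshape(-1)             # [batch*(seq-1)]
    loss = F.cross_entropy(shift_logits, shift_labels)
    print(loss.item())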
@@ -10,6 +10,7 @@ if TYPE_CHECKING:
     from .sequence_classification import SequenceClassificationModel
     from .task_model import SingleBackboneTaskModelBase
     from .token_classification import TokenClassificationModel
+    from .text_generation import TaskModelForTextGeneration
 else:
     _import_structure = {
@@ -19,6 +20,7 @@ else:
         'sequence_classification': ['SequenceClassificationModel'],
         'task_model': ['SingleBackboneTaskModelBase'],
         'token_classification': ['TokenClassificationModel'],
+        'text_generation': ['TaskModelForTextGeneration'],
     }

     import sys
@@ -0,0 +1,79 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict
+
+import addict
+import numpy as np
+from transformers.modeling_utils import PreTrainedModel
+
+from modelscope.metainfo import TaskModels
+from modelscope.models.builder import MODELS
+from modelscope.models.nlp.task_models.task_model import \
+    SingleBackboneTaskModelBase
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+
+__all__ = ['TaskModelForTextGeneration']
+
+
+@MODELS.register_module(
+    Tasks.text_generation, module_name=TaskModels.text_generation)
+class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the text generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        if 'base_model_prefix' in kwargs:
+            self._base_model_prefix = kwargs['base_model_prefix']
+
+        self.build_backbone(self.backbone_cfg)
+        self.build_head(self.head_cfg)
+        if self.config.get('shared_embedding', False):
+            input_embeddings = self.backbone.get_input_embeddings()
+            output_embeddings = self.head.get_output_embeddings()
+            output_embeddings.weight = input_embeddings.weight
+
+    def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        # The backbone does not need labels; only the head needs them to
+        # compute the loss.
+        labels = input.pop(OutputKeys.LABELS, None)
+
+        backbone_outputs = super().forward(input)
+        hidden_states = backbone_outputs[0]
+
+        outputs = self.head.forward(hidden_states)
+        if labels is not None:
+            input[OutputKeys.LABELS] = labels
+            loss = self.compute_loss(outputs, labels)
+            outputs.update(loss)
+        return addict.Dict(outputs)
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        token_type_ids = kwargs.get('token_type_ids', None)
+        # keep only the last token of input_ids if past is defined in kwargs
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get('attention_mask', None)
+        position_ids = kwargs.get('position_ids', None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+
+        return {
+            'input_ids': input_ids,
+            'past_key_values': past,
+            'use_cache': kwargs.get('use_cache'),
+            'position_ids': position_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids,
+        }
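The position_ids computation above matters for left-padded batch generation: positions count only real (non-pad) tokens, and padding positions are clamped to a dummy value so they never produce negative indices. A standalone illustration of just that computation:

    import torch

    # Two sequences, the first left-padded to length 5 (0 = padding in the mask).
    attention_mask = torch.tensor([
        [0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ])
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    print(position_ids)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])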
@@ -6,10 +6,12 @@ import torch
 from modelscope.metainfo import Pipelines
 from modelscope.models.base import Model
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline, Tensor
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import TextGenerationPreprocessor
-from modelscope.utils.constant import Tasks
+from modelscope.preprocessors import Preprocessor, build_preprocessor
+from modelscope.utils.constant import Fields, Tasks
+from modelscope.utils.hub import read_config

 __all__ = ['TextGenerationPipeline']

@@ -20,7 +22,7 @@ class TextGenerationPipeline(Pipeline):

     def __init__(self,
                  model: Union[Model, str],
-                 preprocessor: Optional[TextGenerationPreprocessor] = None,
+                 preprocessor: Optional[Preprocessor] = None,
                  first_sequence='sentence',
                  **kwargs):
         """Use `model` and `preprocessor` to create a generation pipeline for prediction.

@@ -50,19 +52,34 @@ class TextGenerationPipeline(Pipeline):
         """
         model = model if isinstance(model,
                                     Model) else Model.from_pretrained(model)
+        cfg = read_config(model.model_dir)
+        self.postprocessor = cfg.pop('postprocessor', None)
         if preprocessor is None:
-            preprocessor = TextGenerationPreprocessor(
-                model.model_dir,
-                first_sequence=first_sequence,
-                second_sequence=None,
-                sequence_length=kwargs.pop('sequence_length', 128))
+            preprocessor_cfg = cfg.preprocessor
+            preprocessor_cfg.update({
+                'model_dir': model.model_dir,
+                'first_sequence': first_sequence,
+                'second_sequence': None,
+                'sequence_length': kwargs.pop('sequence_length', 128)
+            })
+            preprocessor = build_preprocessor(preprocessor_cfg, Fields.nlp)
         model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)

+    def _sanitize_parameters(self, **pipeline_parameters):
+        return {}, pipeline_parameters, {}
+
     def forward(self, inputs: Dict[str, Any],
                 **forward_params) -> Dict[str, Any]:
         with torch.no_grad():
-            return self.model.generate(inputs)
+            return self.model.generate(inputs, **forward_params)
+
+    def sentence_piece(self, inputs) -> Dict[str, Tensor]:
+        return self.preprocessor.tokenizer.decode(inputs.tolist())[0]
+
     def postprocess(self, inputs: Dict[str, Tensor],
                     **postprocess_params) -> Dict[str, str]:

@@ -74,4 +91,7 @@ class TextGenerationPipeline(Pipeline):
         Returns:
             Dict[str, str]: the prediction results
         """
-        return inputs
+        return inputs if self.postprocessor is None else {
+            OutputKeys.TEXT:
+            getattr(self, self.postprocessor.replace('-', '_'))(inputs)
+        }
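The postprocess change dispatches on the `postprocessor` key popped from the model configuration in the constructor: the value is resolved to a pipeline method by replacing '-' with '_', so 'sentence-piece' maps to the sentence_piece method above. A hedged sketch of the mechanism in isolation (the key name follows the diff; everything else is illustrative):

    class DispatchDemo:

        def sentence_piece(self, inputs):
            # Stand-in for decoding generated token ids back to text.
            return f'decoded({inputs})'

        def postprocess(self, inputs, postprocessor='sentence-piece'):
            if postprocessor is None:
                return inputs
            # 'sentence-piece' -> self.sentence_piece
            return {'text': getattr(self, postprocessor.replace('-', '_'))(inputs)}


    print(DispatchDemo().postprocess([1, 2, 3]))  # {'text': 'decoded([1, 2, 3])'}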
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
         Tokenize,
         WordSegmentationBlankSetToLabelPreprocessor,
         ZeroShotClassificationPreprocessor,
+        SentencePiecePreprocessor,
     )
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
@@ -71,6 +72,7 @@ else:
             'Text2TextGenerationPreprocessor',
             'WordSegmentationBlankSetToLabelPreprocessor',
             'ZeroShotClassificationPreprocessor',
+            'SentencePiecePreprocessor',
         ],
         'space': [
             'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
         Tokenize,
         WordSegmentationBlankSetToLabelPreprocessor,
         ZeroShotClassificationPreprocessor,
+        SentencePiecePreprocessor,
     )
 else:
@@ -41,6 +42,7 @@ else:
             'Text2TextGenerationPreprocessor',
             'WordSegmentationBlankSetToLabelPreprocessor',
             'ZeroShotClassificationPreprocessor',
+            'SentencePiecePreprocessor',
         ],
         'text_error_correction': [
             'TextErrorCorrectionPreprocessor',
@@ -5,6 +5,7 @@ import re
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import numpy as np
+import sentencepiece as spm
 import torch
 from transformers import AutoTokenizer
@@ -1160,3 +1161,23 @@ class FillMaskPoNetPreprocessor(NLPTokenizerPreprocessorBase):
         self.labels_to_id(labels, output)

         return output
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.sentence_piece)
+class SentencePiecePreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        import os
+        super().__init__(*args, **kwargs)
+        self.tokenizer = None
+        for file_name in os.listdir(model_dir):
+            if file_name.endswith('.model'):
+                m_file = osp.join(model_dir, file_name)
+                self.tokenizer = spm.SentencePieceProcessor(model_file=m_file)
+                break
+        assert self.tokenizer is not None, 'Can not find .model file'
+
+    def __call__(self, data: str) -> Dict[str, Any]:
+        return torch.tensor(self.tokenizer.encode([data]), dtype=torch.long)
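Usage-wise, the preprocessor scans model_dir for the first *.model file, loads it into a SentencePieceProcessor, and __call__ returns a [1, seq_len] LongTensor (note that encode([data]) tokenizes a one-element batch). A hedged sketch, assuming a directory containing a trained SentencePiece model named sp.model:

    import sentencepiece as spm
    import torch

    # Assumes 'sp.model' is a trained SentencePiece model file on disk.
    tokenizer = spm.SentencePieceProcessor(model_file='sp.model')
    ids = torch.tensor(tokenizer.encode(['我是']), dtype=torch.long)
    print(ids.shape)                          # torch.Size([1, seq_len]), a batch of one
    print(tokenizer.decode(ids.tolist())[0])  # round-trips back to the input text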
@@ -133,6 +133,19 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
     def test_demo_compatibility(self):
         self.compatibility_check()

+    @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub")
+    def test_gpt_neo(self):
+        pipe = pipeline(
+            task=Tasks.text_generation, model='Langboat/mengzi-gpt-neo-base')
+        print(
+            pipe(
+                '我是',
+                do_sample=True,
+                top_k=5,
+                top_p=1,
+                max_length=20,
+                repetition_penalty=0.5))
+

 if __name__ == '__main__':
     unittest.main()
@@ -41,7 +41,7 @@ class AstScaningTest(unittest.TestCase):
         self.assertIsInstance(from_imports, dict)
         self.assertIsInstance(decorators, list)
         self.assertListEqual(list(set(imports.keys()) - set(['torch'])), [])
-        self.assertEqual(len(from_imports.keys()), 7)
+        self.assertEqual(len(from_imports.keys()), 9)
         self.assertTrue(from_imports['modelscope.metainfo'] is not None)
         self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines'])
         self.assertEqual(decorators,