From 271e2a2a9916de3bd64e40dd4c836d341fed4b77 Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Mon, 17 Oct 2022 20:54:29 +0800 Subject: [PATCH] [to #42322933] Add gpt_neo model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 添加 gpt_neo 模型,因 checkpoint 归属于 Langboat 还未上传到模型库,已线下完成测试 2. 添加 text-generation task models 与 head,后续会将 gpt3,palm 等已上线文本生成模型统一为 backbone + head 结构的 task models Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10404249 --- modelscope/metainfo.py | 5 ++ modelscope/models/nlp/__init__.py | 4 +- modelscope/models/nlp/backbones/gpt_neo.py | 15 ++++ .../models/nlp/heads/text_generation_head.py | 35 ++++++++ modelscope/models/nlp/task_models/__init__.py | 2 + .../models/nlp/task_models/text_generation.py | 79 +++++++++++++++++++ .../pipelines/nlp/text_generation_pipeline.py | 38 ++++++--- modelscope/preprocessors/__init__.py | 2 + modelscope/preprocessors/nlp/__init__.py | 2 + modelscope/preprocessors/nlp/nlp_base.py | 21 +++++ tests/pipelines/test_text_generation.py | 13 +++ tests/utils/test_ast.py | 2 +- 12 files changed, 207 insertions(+), 11 deletions(-) create mode 100644 modelscope/models/nlp/backbones/gpt_neo.py create mode 100644 modelscope/models/nlp/heads/text_generation_head.py create mode 100644 modelscope/models/nlp/task_models/text_generation.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 2e3fed98..fb99bc71 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -71,6 +71,7 @@ class Models(object): gcnncrf = 'gcnn-crf' bart = 'bart' gpt3 = 'gpt3' + gpt_neo = 'gpt-neo' plug = 'plug' bert_for_ds = 'bert-for-document-segmentation' ponet = 'ponet' @@ -101,6 +102,7 @@ class TaskModels(object): information_extraction = 'information-extraction' fill_mask = 'fill-mask' feature_extraction = 'feature-extraction' + text_generation = 'text-generation' class Heads(object): @@ -116,6 +118,8 @@ class Heads(object): token_classification = 'token-classification' # extraction information_extraction = 'information-extraction' + # text gen + text_generation = 'text-generation' class Pipelines(object): @@ -341,6 +345,7 @@ class Preprocessors(object): re_tokenizer = 're-tokenizer' document_segmentation = 'document-segmentation' feature_extraction = 'feature-extraction' + sentence_piece = 'sentence-piece' # audio preprocessor linear_aec_fbank = 'linear-aec-fbank' diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 8ef96365..9e830d17 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -30,7 +30,8 @@ if TYPE_CHECKING: InformationExtractionModel, SequenceClassificationModel, SingleBackboneTaskModelBase, - TokenClassificationModel) + TokenClassificationModel, + TaskModelForTextGeneration) from .token_classification import SbertForTokenClassification from .sentence_embedding import SentenceEmbedding from .passage_ranking import PassageRanking @@ -69,6 +70,7 @@ else: 'SequenceClassificationModel', 'SingleBackboneTaskModelBase', 'TokenClassificationModel', + 'TaskModelForTextGeneration', ], 'token_classification': ['SbertForTokenClassification'], 'table_question_answering': ['TableQuestionAnswering'], diff --git a/modelscope/models/nlp/backbones/gpt_neo.py b/modelscope/models/nlp/backbones/gpt_neo.py new file mode 100644 index 00000000..a2d0c374 --- /dev/null +++ b/modelscope/models/nlp/backbones/gpt_neo.py @@ -0,0 +1,15 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from transformers import GPTNeoConfig +from transformers import GPTNeoModel as GPTNeoModelTransform + +from modelscope.metainfo import Models +from modelscope.models.builder import BACKBONES +from modelscope.utils.constant import Fields + + +@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.gpt_neo) +class GPTNeoModel(GPTNeoModelTransform): + + def __init__(self, **kwargs): + config = GPTNeoConfig(**kwargs) + super().__init__(config) diff --git a/modelscope/models/nlp/heads/text_generation_head.py b/modelscope/models/nlp/heads/text_generation_head.py new file mode 100644 index 00000000..606d5a1f --- /dev/null +++ b/modelscope/models/nlp/heads/text_generation_head.py @@ -0,0 +1,35 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Heads +from modelscope.models.base import TorchHead +from modelscope.models.builder import HEADS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + + +@HEADS.register_module( + Tasks.text_generation, module_name=Heads.text_generation) +class TextGenerationHead(TorchHead): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + config = self.config + self.linear = nn.Linear( + config['hidden_size'], config['vocab_size'], bias=False) + + def get_output_embeddings(self): + return self.linear + + def forward(self, inputs=None): + logits = self.linear(inputs) + return {OutputKeys.LOGITS: logits} + + def compute_loss(self, outputs: Dict[str, torch.Tensor], + labels) -> Dict[str, torch.Tensor]: + logits = outputs[OutputKeys.LOGITS] + return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} diff --git a/modelscope/models/nlp/task_models/__init__.py b/modelscope/models/nlp/task_models/__init__.py index 90f22aa1..38359044 100644 --- a/modelscope/models/nlp/task_models/__init__.py +++ b/modelscope/models/nlp/task_models/__init__.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from .sequence_classification import SequenceClassificationModel from .task_model import SingleBackboneTaskModelBase from .token_classification import TokenClassificationModel + from .text_generation import TaskModelForTextGeneration else: _import_structure = { @@ -19,6 +20,7 @@ else: 'sequence_classification': ['SequenceClassificationModel'], 'task_model': ['SingleBackboneTaskModelBase'], 'token_classification': ['TokenClassificationModel'], + 'text_generation': ['TaskModelForTextGeneration'], } import sys diff --git a/modelscope/models/nlp/task_models/text_generation.py b/modelscope/models/nlp/task_models/text_generation.py new file mode 100644 index 00000000..973198ae --- /dev/null +++ b/modelscope/models/nlp/task_models/text_generation.py @@ -0,0 +1,79 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import addict +import numpy as np +from transformers.modeling_utils import PreTrainedModel + +from modelscope.metainfo import TaskModels +from modelscope.models.builder import MODELS +from modelscope.models.nlp.task_models.task_model import \ + SingleBackboneTaskModelBase +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + +__all__ = ['TaskModelForTextGeneration'] + + +@MODELS.register_module( + Tasks.text_generation, module_name=TaskModels.text_generation) +class TaskModelForTextGeneration(SingleBackboneTaskModelBase, PreTrainedModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + if 'base_model_prefix' in kwargs: + self._base_model_prefix = kwargs['base_model_prefix'] + + self.build_backbone(self.backbone_cfg) + self.build_head(self.head_cfg) + if self.config.get('shared_embedding', False): + input_embeddings = self.backbone.get_input_embeddings() + output_embeddings = self.head.get_output_embeddings() + output_embeddings.weight = input_embeddings.weight + + def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]: + # backbone do not need labels, only head need for loss compute + labels = input.pop(OutputKeys.LABELS, None) + + backbone_outputs = super().forward(input) + hidden_states = backbone_outputs[0] + + outputs = self.head.forward(hidden_states) + if labels is not None: + input[OutputKeys.LABELS] = labels + loss = self.compute_loss(outputs, labels) + outputs.update(loss) + return addict.Dict(outputs) + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get('token_type_ids', None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get('attention_mask', None) + position_ids = kwargs.get('position_ids', None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + 'input_ids': input_ids, + 'past_key_values': past, + 'use_cache': kwargs.get('use_cache'), + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'token_type_ids': token_type_ids, + } diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index ea35763f..ae92f26a 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -6,10 +6,12 @@ import torch from modelscope.metainfo import Pipelines from modelscope.models.base import Model +from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline, Tensor from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import TextGenerationPreprocessor -from modelscope.utils.constant import Tasks +from modelscope.preprocessors import Preprocessor, build_preprocessor +from modelscope.utils.constant import Fields, Tasks +from modelscope.utils.hub import read_config __all__ = ['TextGenerationPipeline'] @@ -20,7 +22,7 @@ class TextGenerationPipeline(Pipeline): def __init__(self, model: Union[Model, str], - preprocessor: Optional[TextGenerationPreprocessor] = None, + preprocessor: Optional[Preprocessor] = None, first_sequence='sentence', **kwargs): """Use `model` and `preprocessor` to create a generation pipeline for prediction. @@ -50,19 +52,34 @@ class TextGenerationPipeline(Pipeline): """ model = model if isinstance(model, Model) else Model.from_pretrained(model) + cfg = read_config(model.model_dir) + self.postprocessor = cfg.pop('postprocessor', None) if preprocessor is None: - preprocessor = TextGenerationPreprocessor( + preprocessor_cfg = cfg.preprocessor + preprocessor_cfg.update({ + 'model_dir': model.model_dir, - first_sequence=first_sequence, - second_sequence=None, - sequence_length=kwargs.pop('sequence_length', 128)) + 'first_sequence': + first_sequence, + 'second_sequence': + None, + 'sequence_length': + kwargs.pop('sequence_length', 128) + }) + preprocessor = build_preprocessor(preprocessor_cfg, Fields.nlp) model.eval() super().__init__(model=model, preprocessor=preprocessor, **kwargs) + def _sanitize_parameters(self, **pipeline_parameters): + return {}, pipeline_parameters, {} + def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: with torch.no_grad(): - return self.model.generate(inputs) + return self.model.generate(inputs, **forward_params) + + def sentence_piece(self, inputs) -> Dict[str, Tensor]: + return self.preprocessor.tokenizer.decode(inputs.tolist())[0] def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params) -> Dict[str, str]: @@ -74,4 +91,7 @@ class TextGenerationPipeline(Pipeline): Returns: Dict[str, str]: the prediction results """ - return inputs + return inputs if self.postprocessor is None else { + OutputKeys.TEXT: + getattr(self, self.postprocessor.replace('-', '_'))(inputs) + } diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 90303b65..43fa64a7 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: Tokenize, WordSegmentationBlankSetToLabelPreprocessor, ZeroShotClassificationPreprocessor, + SentencePiecePreprocessor, ) from .space import (DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, @@ -71,6 +72,7 @@ else: 'Text2TextGenerationPreprocessor', 'WordSegmentationBlankSetToLabelPreprocessor', 'ZeroShotClassificationPreprocessor', + 'SentencePiecePreprocessor', ], 'space': [ 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', diff --git a/modelscope/preprocessors/nlp/__init__.py b/modelscope/preprocessors/nlp/__init__.py index dfbb5c81..a753fe6c 100644 --- a/modelscope/preprocessors/nlp/__init__.py +++ b/modelscope/preprocessors/nlp/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: Tokenize, WordSegmentationBlankSetToLabelPreprocessor, ZeroShotClassificationPreprocessor, + SentencePiecePreprocessor, ) else: @@ -41,6 +42,7 @@ else: 'Text2TextGenerationPreprocessor', 'WordSegmentationBlankSetToLabelPreprocessor', 'ZeroShotClassificationPreprocessor', + 'SentencePiecePreprocessor', ], 'text_error_correction': [ 'TextErrorCorrectionPreprocessor', diff --git a/modelscope/preprocessors/nlp/nlp_base.py b/modelscope/preprocessors/nlp/nlp_base.py index bec7e4e1..3d708634 100644 --- a/modelscope/preprocessors/nlp/nlp_base.py +++ b/modelscope/preprocessors/nlp/nlp_base.py @@ -5,6 +5,7 @@ import re from typing import Any, Dict, Iterable, Optional, Tuple, Union import numpy as np +import sentencepiece as spm import torch from transformers import AutoTokenizer @@ -1160,3 +1161,23 @@ class FillMaskPoNetPreprocessor(NLPTokenizerPreprocessorBase): self.labels_to_id(labels, output) return output + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sentence_piece) +class SentencePiecePreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + import os + + super().__init__(*args, **kwargs) + self.tokenizer = None + for file_name in os.listdir(model_dir): + if file_name.endswith('.model'): + m_file = osp.join(model_dir, file_name) + self.tokenizer = spm.SentencePieceProcessor(model_file=m_file) + break + assert self.tokenizer is not None, 'Can not find .model file' + + def __call__(self, data: str) -> Dict[str, Any]: + return torch.tensor(self.tokenizer.encode([data]), dtype=torch.long) diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index 66f9c9da..5a270f83 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -133,6 +133,19 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): def test_demo_compatibility(self): self.compatibility_check() + @unittest.skip("Langboat's checkpoint has not been uploaded to modelhub") + def test_gpt_neo(self): + pipe = pipeline( + task=Tasks.text_generation, model='Langboat/mengzi-gpt-neo-base') + print( + pipe( + '我是', + do_sample=True, + top_k=5, + top_p=1, + max_length=20, + repetition_penalty=0.5)) + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_ast.py b/tests/utils/test_ast.py index 9a8ab828..c0624679 100644 --- a/tests/utils/test_ast.py +++ b/tests/utils/test_ast.py @@ -41,7 +41,7 @@ class AstScaningTest(unittest.TestCase): self.assertIsInstance(from_imports, dict) self.assertIsInstance(decorators, list) self.assertListEqual(list(set(imports.keys()) - set(['torch'])), []) - self.assertEqual(len(from_imports.keys()), 7) + self.assertEqual(len(from_imports.keys()), 9) self.assertTrue(from_imports['modelscope.metainfo'] is not None) self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines']) self.assertEqual(decorators,