Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10191736 · add T5 for generation (target branch: master)
@@ -65,6 +65,7 @@ class Models(object):
     plug = 'plug'
     bert_for_ds = 'bert-for-document-segmentation'
     ponet = 'ponet'
+    T5 = 'T5'

     # audio models
     sambert_hifigan = 'sambert-hifigan'
@@ -179,6 +180,7 @@ class Pipelines(object):
     part_of_speech = 'part-of-speech'
     named_entity_recognition = 'named-entity-recognition'
    text_generation = 'text-generation'
+    text2text_generation = 'text2text-generation'
     sentiment_analysis = 'sentiment-analysis'
     sentiment_classification = 'sentiment-classification'
     text_classification = 'text-classification'
@@ -280,6 +282,7 @@ class Preprocessors(object):
     cross_encoder_tokenizer = 'cross-encoder-tokenizer'
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     text_gen_tokenizer = 'text-gen-tokenizer'
+    text2text_gen_preprocessor = 'text2text-gen-preprocessor'
     token_cls_tokenizer = 'token-cls-tokenizer'
     ner_tokenizer = 'ner-tokenizer'
     nli_tokenizer = 'nli-tokenizer'
@@ -0,0 +1,21 @@
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .t5_for_text_generation import T5ForConditionalGeneration
+else:
+    _import_structure = {
+        't5_for_text_generation': ['T5ForConditionalGeneration'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
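For reference, the lazy registration above means importing the symbol from the package triggers the actual module load only on first access. A minimal usage sketch (the local checkpoint path is a placeholder):

```python
# The attribute access resolves through LazyImportModule and imports
# t5_for_text_generation on demand, so plain `import modelscope` stays cheap.
from modelscope.models.nlp.T5 import T5ForConditionalGeneration

model = T5ForConditionalGeneration('/path/to/t5')  # placeholder model dir
```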
@@ -0,0 +1,174 @@
+# Copyright 2020, The T5 Authors and HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" T5 model configuration"""
+from typing import Mapping
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.onnx import OnnxSeq2SeqConfigWithPast
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class T5Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
+    instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the T5
+    [t5-small](https://huggingface.co/t5-small) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Arguments:
+        vocab_size (`int`, *optional*, defaults to 32128):
+            Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`T5Model`] or [`TFT5Model`].
+        d_model (`int`, *optional*, defaults to 512):
+            Size of the encoder layers and the pooler layer.
+        d_kv (`int`, *optional*, defaults to 64):
+            Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
+            num_heads`.
+        d_ff (`int`, *optional*, defaults to 2048):
+            Size of the intermediate feed forward layer in each `T5Block`.
+        num_layers (`int`, *optional*, defaults to 6):
+            Number of hidden layers in the Transformer encoder.
+        num_decoder_layers (`int`, *optional*):
+            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
+        num_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+            The number of buckets to use for each attention layer.
+        relative_attention_max_distance (`int`, *optional*, defaults to 128):
+            The maximum distance of the longer sequences for the bucket separation.
+        dropout_rate (`float`, *optional*, defaults to 0.1):
+            The ratio for all dropout layers.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
+            The epsilon used by the layer normalization layers.
+        initializer_factor (`float`, *optional*, defaults to 1):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+        feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
+            Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
+            `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+    """
+    model_type = 't5'
+    keys_to_ignore_at_inference = ['past_key_values']
+    attribute_map = {
+        'hidden_size': 'd_model',
+        'num_attention_heads': 'num_heads',
+        'num_hidden_layers': 'num_layers'
+    }
+
+    def __init__(self,
+                 vocab_size=32128,
+                 d_model=512,
+                 d_kv=64,
+                 d_ff=2048,
+                 num_layers=6,
+                 num_decoder_layers=None,
+                 num_heads=8,
+                 relative_attention_num_buckets=32,
+                 relative_attention_max_distance=128,
+                 dropout_rate=0.1,
+                 layer_norm_epsilon=1e-6,
+                 initializer_factor=1.0,
+                 feed_forward_proj='relu',
+                 is_encoder_decoder=True,
+                 use_cache=True,
+                 pad_token_id=0,
+                 eos_token_id=1,
+                 **kwargs):
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.d_kv = d_kv
+        self.d_ff = d_ff
+        self.num_layers = num_layers
+        self.num_decoder_layers = (
+            num_decoder_layers
+            if num_decoder_layers is not None else self.num_layers
+        )  # default = symmetry
+        self.num_heads = num_heads
+        self.relative_attention_num_buckets = relative_attention_num_buckets
+        self.relative_attention_max_distance = relative_attention_max_distance
+        self.dropout_rate = dropout_rate
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.initializer_factor = initializer_factor
+        self.feed_forward_proj = feed_forward_proj
+        self.use_cache = use_cache
+
+        act_info = self.feed_forward_proj.split('-')
+        self.dense_act_fn = act_info[-1]
+        self.is_gated_act = act_info[0] == 'gated'
+
+        if len(act_info) > 1 and act_info[0] != 'gated' or len(act_info) > 2:
+            raise ValueError(
+                f'`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer.'
+                'Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. '
+                "'gated-gelu' or 'relu'")
+
+        # for backwards compatibility
+        if feed_forward_proj == 'gated-gelu':
+            self.dense_act_fn = 'gelu_new'
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            **kwargs,
+        )
+
+
+class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = {
+            'input_ids': {
+                0: 'batch',
+                1: 'encoder_sequence'
+            },
+            'attention_mask': {
+                0: 'batch',
+                1: 'encoder_sequence'
+            },
+        }
+        if self.use_past:
+            common_inputs['attention_mask'][
+                1] = 'past_encoder_sequence + sequence'
+            common_inputs['decoder_input_ids'] = {0: 'batch'}
+            common_inputs['decoder_attention_mask'] = {
+                0: 'batch',
+                1: 'past_decoder_sequence + sequence'
+            }
+        else:
+            common_inputs['decoder_input_ids'] = {
+                0: 'batch',
+                1: 'decoder_sequence'
+            }
+            common_inputs['decoder_attention_mask'] = {
+                0: 'batch',
+                1: 'decoder_sequence'
+            }
+
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction='inputs')
+
+        return common_inputs
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 13
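The `feed_forward_proj` handling in `__init__` is the one non-obvious piece of this config; here is a standalone sketch mirroring that logic (no modelscope imports; `parse_feed_forward_proj` is a name introduced purely for illustration):

```python
def parse_feed_forward_proj(feed_forward_proj: str):
    # Mirrors T5Config.__init__ above: the last '-' segment names the dense
    # activation, and a leading 'gated' marks T5v1.1-style gated projections.
    act_info = feed_forward_proj.split('-')
    dense_act_fn = act_info[-1]
    is_gated_act = act_info[0] == 'gated'
    if len(act_info) > 1 and act_info[0] != 'gated' or len(act_info) > 2:
        raise ValueError(f'{feed_forward_proj} is not a valid feed_forward_proj')
    if feed_forward_proj == 'gated-gelu':
        dense_act_fn = 'gelu_new'  # backwards-compatibility alias
    return dense_act_fn, is_gated_act


assert parse_feed_forward_proj('relu') == ('relu', False)
assert parse_feed_forward_proj('gated-gelu') == ('gelu_new', True)
```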
@@ -0,0 +1,56 @@
+from typing import Optional, Tuple
+
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Tensor, TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+from .modeling_t5 import T5Config
+from .modeling_t5 import T5ForConditionalGeneration as T5ForGeneration
+
+
+@MODELS.register_module(
+    group_key=Tasks.text2text_generation,
+    module_name=Models.T5,
+)
+class T5ForConditionalGeneration(TorchModel):
+
+    def __init__(self, model_dir=None, *args, **kwargs):
+        """Initialize the text2text generation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        self.model = T5ForGeneration.from_pretrained(model_dir)
+        self.generate = self.model.generate
+        self.config = self.model.config
+
+    def forward(self,
+                input_ids: Optional[torch.LongTensor] = None,
+                attention_mask: Optional[torch.FloatTensor] = None,
+                decoder_input_ids: Optional[torch.LongTensor] = None,
+                decoder_attention_mask: Optional[torch.BoolTensor] = None,
+                head_mask: Optional[torch.FloatTensor] = None,
+                decoder_head_mask: Optional[torch.FloatTensor] = None,
+                cross_attn_head_mask: Optional[torch.Tensor] = None,
+                encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+                past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+                labels: Optional[torch.LongTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                return_dict: Optional[bool] = None,
+                **kwargs):
+        # self.model.forward is a bound method, so `self` must not be
+        # forwarded as a positional argument.
+        return self.model.forward(
+            input_ids, attention_mask, decoder_input_ids,
+            decoder_attention_mask, head_mask, decoder_head_mask,
+            cross_attn_head_mask, encoder_outputs, past_key_values,
+            inputs_embeds, decoder_inputs_embeds, labels, use_cache,
+            output_attentions, output_hidden_states, return_dict, **kwargs)
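Since the wrapper re-exposes `generate` and `config` from the wrapped HF-style module, it can be driven directly once a checkpoint directory is available; a hedged sketch (the model dir and token ids are placeholders):

```python
import torch

from modelscope.models.nlp import T5ForConditionalGeneration

model = T5ForConditionalGeneration('/path/to/t5')  # placeholder checkpoint dir
input_ids = torch.tensor([[13959, 1566, 12, 2968, 10, 1]])  # placeholder token ids
with torch.no_grad():
    output_ids = model.generate(input_ids, max_length=20)
print(output_ids)
```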
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
     from .token_classification import SbertForTokenClassification
     from .sentence_embedding import SentenceEmbedding
     from .passage_ranking import PassageRanking
+    from .T5 import T5ForConditionalGeneration
 else:
     _import_structure = {
         'backbones': ['SbertModel'],
@@ -68,6 +68,7 @@ else:
         'table_question_answering': ['TableQuestionAnswering'],
         'sentence_embedding': ['SentenceEmbedding'],
         'passage_ranking': ['PassageRanking'],
+        'T5': ['T5ForConditionalGeneration'],
     }

     import sys
@@ -390,12 +390,19 @@ TASK_OUTPUTS = {
     Tasks.text_error_correction: [OutputKeys.OUTPUT],
     Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES],
     Tasks.passage_ranking: [OutputKeys.SCORES],

     # text generation result for single sample
     # {
     #     "text": "this is the text generated by a model."
     # }
     Tasks.text_generation: [OutputKeys.TEXT],

+    # text2text generation result for single sample
+    # {
+    #     "text": "北京"
+    # }
+    Tasks.text2text_generation: [OutputKeys.TEXT],
+
     # fill mask result for single sample
     # {
     #     "text": "this is the text which masks filled by model."
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
     from .document_segmentation_pipeline import DocumentSegmentationPipeline
     from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
     from .fill_mask_pipeline import FillMaskPipeline
-    from .fill_mask_ponet_pipeline import FillMaskPoNetPreprocessor
+    from .fill_mask_ponet_pipeline import FillMaskPonetPipeline
     from .information_extraction_pipeline import InformationExtractionPipeline
     from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
     from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
     from .text_classification_pipeline import TextClassificationPipeline
     from .text_error_correction_pipeline import TextErrorCorrectionPipeline
     from .text_generation_pipeline import TextGenerationPipeline
+    from .text2text_generation_pipeline import Text2TextGenerationPipeline
     from .token_classification_pipeline import TokenClassificationPipeline
     from .translation_pipeline import TranslationPipeline
     from .word_segmentation_pipeline import WordSegmentationPipeline
@@ -54,6 +55,7 @@ else:
         'text_classification_pipeline': ['TextClassificationPipeline'],
         'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
         'text_generation_pipeline': ['TextGenerationPipeline'],
+        'text2text_generation_pipeline': ['Text2TextGenerationPipeline'],
         'token_classification_pipeline': ['TokenClassificationPipeline'],
         'translation_pipeline': ['TranslationPipeline'],
         'word_segmentation_pipeline': ['WordSegmentationPipeline'],
@@ -0,0 +1,87 @@
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.base import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline, Tensor
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import Text2TextGenerationPreprocessor
+from modelscope.utils.constant import Tasks
+
+__all__ = ['Text2TextGenerationPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.text2text_generation, module_name=Pipelines.text2text_generation)
+class Text2TextGenerationPipeline(Pipeline):
+
+    def __init__(
+            self,
+            model: Union[Model, str],
+            preprocessor: Optional[Text2TextGenerationPreprocessor] = None,
+            first_sequence='sentence',
+            **kwargs):
+        """Use `model` and `preprocessor` to create a text2text generation pipeline for prediction.
+
+        Args:
+            model (str or Model): Either a local model dir that supports the text2text
+                generation task, a model id from the model hub, or a torch model instance.
+            preprocessor (Preprocessor): An optional preprocessor instance; please make sure
+                the preprocessor fits the model if supplied.
+            first_sequence: The key used to read the input sentence from a dict input.
+            sequence_length: Max sequence length in the user's custom scenario; 128 is used
+                as a default value.
+
+            NOTE: Inputs of type 'str' are also supported, in which case the
+            'first_sequence' param has no effect.
+
+        Example:
+            >>> from modelscope.pipelines import pipeline
+            >>> pipeline_ins = pipeline(task='text2text-generation',
+            >>>    model='damo/t5-cn-base-test')
+            >>> sentence1 = '中国的首都位于<extra_id_0>。'
+            >>> print(pipeline_ins(sentence1))
+            >>> # Or use the dict input:
+            >>> print(pipeline_ins({'sentence': sentence1}))
+
+            To view other examples please check tests/pipelines/test_text2text_generation.py.
+        """
+        model = model if isinstance(model,
+                                    Model) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = Text2TextGenerationPreprocessor(
+                model.model_dir,
+                sequence_length=kwargs.pop('sequence_length', 128))
+        self.tokenizer = preprocessor.tokenizer
+        model.eval()
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        forward_params['min_length'] = forward_params.get(
+            'min_length', self.model.config.min_length)
+        forward_params['max_length'] = forward_params.get(
+            'max_length', self.model.config.max_length)
+        with torch.no_grad():
+            output_ids = self.model.generate(**inputs, **forward_params)
+        return {'output_ids': output_ids}
+
+    def postprocess(self, inputs: Dict[str, Tensor],
+                    **postprocess_params) -> Dict[str, str]:
+        """Process the prediction results.
+
+        Args:
+            inputs (Dict[str, Tensor]): the model outputs, containing 'output_ids'.
+
+        Returns:
+            Dict[str, str]: the decoded prediction results.
+        """
+        output = self.tokenizer.decode(
+            inputs['output_ids'][0],
+            skip_special_tokens=True,
+        )
+        return {OutputKeys.TEXT: output}
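For a concrete sense of how the pieces compose, a sketch mirroring this PR's tests; note that `min_length`/`max_length` fall back to `model.config` values (HF `PretrainedConfig` defaults are 0 and 20), so longer outputs need explicit overrides:

```python
from modelscope.models import Model
from modelscope.pipelines.nlp import Text2TextGenerationPipeline
from modelscope.preprocessors import Text2TextGenerationPreprocessor

model = Model.from_pretrained('damo/t5-cn-base-test')  # model id from this PR's tests
preprocessor = Text2TextGenerationPreprocessor(model.model_dir)
pipe = Text2TextGenerationPipeline(model, preprocessor)

print(pipe('中国的首都位于<extra_id_0>。'))  # decodes the first generated sequence
```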
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
         TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor,
         SequenceLabelingPreprocessor, RelationExtractionPreprocessor,
         DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor,
-        PassageRankingPreprocessor,
+        PassageRankingPreprocessor, Text2TextGenerationPreprocessor,
         WordSegmentationBlankSetToLabelPreprocessor)
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
@@ -57,6 +57,7 @@ else:
             'TextErrorCorrectionPreprocessor',
             'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
             'RelationExtractionPreprocessor',
+            'Text2TextGenerationPreprocessor',
             'WordSegmentationBlankSetToLabelPreprocessor',
             'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor'
         ],
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
         Tokenize, SequenceClassificationPreprocessor,
         TextGenerationPreprocessor, TokenClassificationPreprocessor,
         SingleSentenceClassificationPreprocessor,
+        Text2TextGenerationPreprocessor,
         PairSentenceClassificationPreprocessor, FillMaskPreprocessor,
         ZeroShotClassificationPreprocessor, NERPreprocessor,
         FaqQuestionAnsweringPreprocessor, SequenceLabelingPreprocessor,
@@ -27,6 +28,7 @@ else:
         'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor',
         'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
         'RelationExtractionPreprocessor',
+        'Text2TextGenerationPreprocessor',
         'WordSegmentationBlankSetToLabelPreprocessor',
         'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor'
     ],
@@ -26,6 +26,7 @@ __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
     'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
     'PairSentenceClassificationPreprocessor',
+    'Text2TextGenerationPreprocessor',
     'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
     'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
     'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor',
@@ -442,6 +443,40 @@ class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         return features


+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.text2text_gen_preprocessor)
+class Text2TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
+    """The tokenizer preprocessor used in text2text generation.
+    """
+
+    def __init__(self,
+                 model_dir: str,
+                 tokenizer=None,
+                 mode=ModeKeys.INFERENCE,
+                 **kwargs):
+        self.tokenizer = self.build_tokenizer(
+            model_dir) if tokenizer is None else tokenizer
+        kwargs['truncation'] = kwargs.get('truncation', 'do_not_truncate')
+        kwargs['padding'] = kwargs.get('padding', False)
+        kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
+                                                     False)
+        kwargs['max_length'] = kwargs.pop('sequence_length', 128)
+        super().__init__(model_dir, pair=False, mode=mode, **kwargs)
+
+    def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]:
+        text_a, _, _ = self.parse_text_and_label(data)
+        inputs = self.tokenizer(
+            text_a,
+            return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None,
+            **self.tokenize_kwargs)
+
+        # Produced by some tokenizers, but not a valid kwarg for generate()
+        if 'token_type_ids' in inputs:
+            del inputs['token_type_ids']
+        return inputs
+
+
 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=Preprocessors.text_gen_tokenizer)
 class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
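A quick sketch of what the new preprocessor emits and why the `token_type_ids` deletion matters: `generate(**inputs)` would reject that key if a tokenizer produced it (the checkpoint dir below is a placeholder):

```python
from modelscope.preprocessors import Text2TextGenerationPreprocessor

preprocessor = Text2TextGenerationPreprocessor('/path/to/t5', sequence_length=128)
inputs = preprocessor('中国的首都位于<extra_id_0>。')

# Expected keys: input_ids and attention_mask as 'pt' tensors in inference mode;
# token_type_ids has been stripped so inputs can go straight into generate().
print(inputs.keys())
```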
@@ -97,6 +97,7 @@ class NLPTasks(object):
     token_classification = 'token-classification'
     conversational = 'conversational'
     text_generation = 'text-generation'
+    text2text_generation = 'text2text-generation'
     task_oriented_conversation = 'task-oriented-conversation'
     dialog_intent_prediction = 'dialog-intent-prediction'
     dialog_state_tracking = 'dialog-state-tracking'
@@ -0,0 +1,61 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import T5ForConditionalGeneration
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import Text2TextGenerationPipeline
+from modelscope.preprocessors import Text2TextGenerationPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/t5-cn-base-test'
+        self.input = '中国的首都位于<extra_id_0>。'
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_T5(self):
+        cache_path = snapshot_download(self.model_id)
+        model = T5ForConditionalGeneration(cache_path)
+        preprocessor = Text2TextGenerationPreprocessor(cache_path)
+        pipeline1 = Text2TextGenerationPipeline(model, preprocessor)
+        pipeline2 = pipeline(
+            Tasks.text2text_generation, model=model, preprocessor=preprocessor)
+        print(
+            f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}'
+        )
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_pipeline_with_model_instance(self):
+        model = Model.from_pretrained(self.model_id)
+        preprocessor = Text2TextGenerationPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.text2text_generation,
+            model=model,
+            preprocessor=preprocessor)
+        print(pipeline_ins(self.input))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_pipeline_with_model_id(self):
+        pipeline_ins = pipeline(
+            task=Tasks.text2text_generation, model=self.model_id)
+        print(pipeline_ins(self.input))
+
+    @unittest.skip(
+        'only for test cases, there is no default official model yet')
+    def test_run_pipeline_without_model_id(self):
+        pipeline_ins = pipeline(task=Tasks.text2text_generation)
+        print(pipeline_ins(self.input))
+
+    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+    def test_demo_compatibility(self):
+        self.compatibility_check()
+
+
+if __name__ == '__main__':
+    unittest.main()
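To run the gated cases locally: `test_level()` reads an environment variable (assumed here to be `TEST_LEVEL`, per modelscope's usual test setup), and the module path below is hypothetical:

```python
import os
import unittest

os.environ['TEST_LEVEL'] = '2'  # assumption: enables the test_level() >= 2 case
unittest.main(module='tests.pipelines.test_text2text_generation', exit=False)
```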