Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10191736 * add T5 for generation (branch: master)
| @@ -65,6 +65,7 @@ class Models(object): | |||
| plug = 'plug' | |||
| bert_for_ds = 'bert-for-document-segmentation' | |||
| ponet = 'ponet' | |||
| T5 = 'T5' | |||
| # audio models | |||
| sambert_hifigan = 'sambert-hifigan' | |||
| @@ -179,6 +180,7 @@ class Pipelines(object): | |||
| part_of_speech = 'part-of-speech' | |||
| named_entity_recognition = 'named-entity-recognition' | |||
| text_generation = 'text-generation' | |||
| text2text_generation = 'text2text-generation' | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| sentiment_classification = 'sentiment-classification' | |||
| text_classification = 'text-classification' | |||
| @@ -280,6 +282,7 @@ class Preprocessors(object): | |||
| cross_encoder_tokenizer = 'cross-encoder-tokenizer' | |||
| bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | |||
| text_gen_tokenizer = 'text-gen-tokenizer' | |||
| text2text_gen_preprocessor = 'text2text-gen-preprocessor' | |||
| token_cls_tokenizer = 'token-cls-tokenizer' | |||
| ner_tokenizer = 'ner-tokenizer' | |||
| nli_tokenizer = 'nli-tokenizer' | |||
| @@ -0,0 +1,21 @@ | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .t5_for_text_generation import T5ForConditionalGeneration | |||
| else: | |||
| _import_structure = { | |||
| 't5_for_text_generation': ['T5ForConditionalGeneration'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
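For context on the block above: `LazyImportModule` swaps itself into `sys.modules`, so importing the new `T5` subpackage stays cheap and the torch-heavy submodule is only loaded on demand. A minimal consumer-side sketch (the exact lazy-resolution timing is an assumption about `LazyImportModule`, not verified here):

    # Importing the subpackage returns the LazyImportModule proxy; the heavy
    # t5_for_text_generation module is not imported yet.
    from modelscope.models.nlp import T5

    # Assumption: the real import happens on this first attribute access.
    model_cls = T5.T5ForConditionalGeneration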
| @@ -0,0 +1,174 @@ | |||
| # Copyright 2020, The T5 Authors and HuggingFace Inc. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ T5 model configuration""" | |||
| from typing import Mapping | |||
| from transformers.configuration_utils import PretrainedConfig | |||
| from transformers.onnx import OnnxSeq2SeqConfigWithPast | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger(__name__) | |||
| class T5Config(PretrainedConfig): | |||
| r""" | |||
| This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to | |||
| instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a | |||
| configuration with the defaults will yield a similar configuration to that of the T5 | |||
| [t5-small](https://huggingface.co/t5-small) architecture. | |||
| Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | |||
| documentation from [`PretrainedConfig`] for more information. | |||
| Arguments: | |||
| vocab_size (`int`, *optional*, defaults to 32128): | |||
| Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the | |||
| `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`]. | |||
| d_model (`int`, *optional*, defaults to 512): | |||
| Size of the encoder layers and the pooler layer. | |||
| d_kv (`int`, *optional*, defaults to 64): | |||
| Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model // | |||
| num_heads`. | |||
| d_ff (`int`, *optional*, defaults to 2048): | |||
| Size of the intermediate feed forward layer in each `T5Block`. | |||
| num_layers (`int`, *optional*, defaults to 6): | |||
| Number of hidden layers in the Transformer encoder. | |||
| num_decoder_layers (`int`, *optional*): | |||
| Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. | |||
| num_heads (`int`, *optional*, defaults to 8): | |||
| Number of attention heads for each attention layer in the Transformer encoder. | |||
| relative_attention_num_buckets (`int`, *optional*, defaults to 32): | |||
| The number of buckets to use for each attention layer. | |||
| relative_attention_max_distance (`int`, *optional*, defaults to 128): | |||
 | The maximum relative distance considered for bucket separation; longer distances share the final bucket. | |||
| dropout_rate (`float`, *optional*, defaults to 0.1): | |||
| The ratio for all dropout layers. | |||
 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-6): | |||
| The epsilon used by the layer normalization layers. | |||
| initializer_factor (`float`, *optional*, defaults to 1): | |||
| A factor for initializing all weight matrices (should be kept to 1, used internally for initialization | |||
| testing). | |||
| feed_forward_proj (`string`, *optional*, defaults to `"relu"`): | |||
| Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the | |||
| `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`. | |||
| use_cache (`bool`, *optional*, defaults to `True`): | |||
| Whether or not the model should return the last key/values attentions (not used by all models). | |||
| """ | |||
| model_type = 't5' | |||
| keys_to_ignore_at_inference = ['past_key_values'] | |||
| attribute_map = { | |||
| 'hidden_size': 'd_model', | |||
| 'num_attention_heads': 'num_heads', | |||
| 'num_hidden_layers': 'num_layers' | |||
| } | |||
| def __init__(self, | |||
| vocab_size=32128, | |||
| d_model=512, | |||
| d_kv=64, | |||
| d_ff=2048, | |||
| num_layers=6, | |||
| num_decoder_layers=None, | |||
| num_heads=8, | |||
| relative_attention_num_buckets=32, | |||
| relative_attention_max_distance=128, | |||
| dropout_rate=0.1, | |||
| layer_norm_epsilon=1e-6, | |||
| initializer_factor=1.0, | |||
| feed_forward_proj='relu', | |||
| is_encoder_decoder=True, | |||
| use_cache=True, | |||
| pad_token_id=0, | |||
| eos_token_id=1, | |||
| **kwargs): | |||
| self.vocab_size = vocab_size | |||
| self.d_model = d_model | |||
| self.d_kv = d_kv | |||
| self.d_ff = d_ff | |||
| self.num_layers = num_layers | |||
| self.num_decoder_layers = (num_decoder_layers if num_decoder_layers | |||
| is not None else self.num_layers | |||
| ) # default = symmetry | |||
| self.num_heads = num_heads | |||
| self.relative_attention_num_buckets = relative_attention_num_buckets | |||
| self.relative_attention_max_distance = relative_attention_max_distance | |||
| self.dropout_rate = dropout_rate | |||
| self.layer_norm_epsilon = layer_norm_epsilon | |||
| self.initializer_factor = initializer_factor | |||
| self.feed_forward_proj = feed_forward_proj | |||
| self.use_cache = use_cache | |||
| act_info = self.feed_forward_proj.split('-') | |||
| self.dense_act_fn = act_info[-1] | |||
| self.is_gated_act = act_info[0] == 'gated' | |||
 | if (len(act_info) > 1 and act_info[0] != 'gated') or len(act_info) > 2: | |||
| raise ValueError( | |||
| f'`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer.' | |||
| 'Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. ' | |||
| "'gated-gelu' or 'relu'") | |||
| # for backwards compatibility | |||
| if feed_forward_proj == 'gated-gelu': | |||
| self.dense_act_fn = 'gelu_new' | |||
| super().__init__( | |||
| pad_token_id=pad_token_id, | |||
| eos_token_id=eos_token_id, | |||
| is_encoder_decoder=is_encoder_decoder, | |||
| **kwargs, | |||
| ) | |||
| class T5OnnxConfig(OnnxSeq2SeqConfigWithPast): | |||
| @property | |||
| def inputs(self) -> Mapping[str, Mapping[int, str]]: | |||
| common_inputs = { | |||
| 'input_ids': { | |||
| 0: 'batch', | |||
| 1: 'encoder_sequence' | |||
| }, | |||
| 'attention_mask': { | |||
| 0: 'batch', | |||
| 1: 'encoder_sequence' | |||
| }, | |||
| } | |||
| if self.use_past: | |||
| common_inputs['attention_mask'][ | |||
| 1] = 'past_encoder_sequence + sequence' | |||
| common_inputs['decoder_input_ids'] = {0: 'batch'} | |||
| common_inputs['decoder_attention_mask'] = { | |||
| 0: 'batch', | |||
| 1: 'past_decoder_sequence + sequence' | |||
| } | |||
| else: | |||
| common_inputs['decoder_input_ids'] = { | |||
| 0: 'batch', | |||
| 1: 'decoder_sequence' | |||
| } | |||
| common_inputs['decoder_attention_mask'] = { | |||
| 0: 'batch', | |||
| 1: 'decoder_sequence' | |||
| } | |||
| if self.use_past: | |||
| self.fill_with_past_key_values_(common_inputs, direction='inputs') | |||
| return common_inputs | |||
| @property | |||
| def default_onnx_opset(self) -> int: | |||
| return 13 | |||
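To make the `feed_forward_proj` parsing above concrete, a small sketch of the two supported formats (assuming `T5Config` is importable from this new module; the import path is illustrative):

    # 'gated-gelu' splits into ['gated', 'gelu']: gated activation, and the
    # backwards-compatibility branch then remaps dense_act_fn to 'gelu_new'.
    cfg = T5Config(feed_forward_proj='gated-gelu')
    assert cfg.is_gated_act and cfg.dense_act_fn == 'gelu_new'

    # The default 'relu' has no 'gated-' prefix.
    cfg = T5Config()
    assert not cfg.is_gated_act and cfg.dense_act_fn == 'relu'

    # A malformed value such as 'foo-bar-baz' raises the ValueError above.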
| @@ -0,0 +1,56 @@ | |||
| from typing import Optional, Tuple | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Tensor, TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import Tasks | |||
| from .modeling_t5 import T5Config | |||
| from .modeling_t5 import T5ForConditionalGeneration as T5ForGeneration | |||
| @MODELS.register_module( | |||
| group_key=Tasks.text2text_generation, | |||
| module_name=Models.T5, | |||
| ) | |||
| class T5ForConditionalGeneration(TorchModel): | |||
| def __init__(self, model_dir=None, *args, **kwargs): | |||
| """initialize the text generation model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| model_cls (Optional[Any], optional): model loader, if None, use the | |||
| default loader to load model weights, by default None. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model = T5ForGeneration.from_pretrained(model_dir) | |||
| self.generate = self.model.generate | |||
| self.config = self.model.config | |||
| def forward(self, | |||
| input_ids: Optional[torch.LongTensor] = None, | |||
| attention_mask: Optional[torch.FloatTensor] = None, | |||
| decoder_input_ids: Optional[torch.LongTensor] = None, | |||
| decoder_attention_mask: Optional[torch.BoolTensor] = None, | |||
| head_mask: Optional[torch.FloatTensor] = None, | |||
| decoder_head_mask: Optional[torch.FloatTensor] = None, | |||
| cross_attn_head_mask: Optional[torch.Tensor] = None, | |||
| encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, | |||
| past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, | |||
| inputs_embeds: Optional[torch.FloatTensor] = None, | |||
| decoder_inputs_embeds: Optional[torch.FloatTensor] = None, | |||
| labels: Optional[torch.LongTensor] = None, | |||
| use_cache: Optional[bool] = None, | |||
| output_attentions: Optional[bool] = None, | |||
| output_hidden_states: Optional[bool] = None, | |||
| return_dict: Optional[bool] = None, | |||
| **kwargs): | |||
 | return self.model.forward( | |||
 | input_ids, attention_mask, decoder_input_ids, | |||
 | decoder_attention_mask, head_mask, decoder_head_mask, | |||
 | cross_attn_head_mask, encoder_outputs, past_key_values, | |||
 | inputs_embeds, decoder_inputs_embeds, labels, use_cache, | |||
 | output_attentions, output_hidden_states, return_dict, **kwargs) | |||
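Usage sketch for the wrapper above; the checkpoint directory and token ids are hypothetical, and since `self.generate` is aliased to the underlying `transformers` `generate()`, standard generation kwargs pass straight through:

    import torch
    from modelscope.models.nlp import T5ForConditionalGeneration

    # Hypothetical local dir holding a HuggingFace-format T5 checkpoint.
    model = T5ForConditionalGeneration('/path/to/t5_checkpoint')
    input_ids = torch.tensor([[644, 4598, 229, 3, 1]])  # hypothetical token ids
    output_ids = model.generate(input_ids, max_length=32, num_beams=4)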
| @@ -32,7 +32,7 @@ if TYPE_CHECKING: | |||
| from .token_classification import SbertForTokenClassification | |||
| from .sentence_embedding import SentenceEmbedding | |||
| from .passage_ranking import PassageRanking | |||
| from .T5 import T5ForConditionalGeneration | |||
| else: | |||
| _import_structure = { | |||
| 'backbones': ['SbertModel'], | |||
| @@ -68,6 +68,7 @@ else: | |||
| 'table_question_answering': ['TableQuestionAnswering'], | |||
| 'sentence_embedding': ['SentenceEmbedding'], | |||
| 'passage_ranking': ['PassageRanking'], | |||
| 'T5': ['T5ForConditionalGeneration'], | |||
| } | |||
| import sys | |||
| @@ -390,12 +390,19 @@ TASK_OUTPUTS = { | |||
| Tasks.text_error_correction: [OutputKeys.OUTPUT], | |||
| Tasks.sentence_embedding: [OutputKeys.TEXT_EMBEDDING, OutputKeys.SCORES], | |||
| Tasks.passage_ranking: [OutputKeys.SCORES], | |||
| # text generation result for single sample | |||
| # { | |||
| # "text": "this is the text generated by a model." | |||
| # } | |||
| Tasks.text_generation: [OutputKeys.TEXT], | |||
 | # text2text generation result for single sample | |||
| # { | |||
| # "text": "北京" | |||
| # } | |||
| Tasks.text2text_generation: [OutputKeys.TEXT], | |||
| # fill mask result for single sample | |||
| # { | |||
| # "text": "this is the text which masks filled by model." | |||
| @@ -12,7 +12,7 @@ if TYPE_CHECKING: | |||
| from .document_segmentation_pipeline import DocumentSegmentationPipeline | |||
| from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline | |||
| from .fill_mask_pipeline import FillMaskPipeline | |||
| from .fill_mask_ponet_pipeline import FillMaskPoNetPreprocessor | |||
| from .fill_mask_ponet_pipeline import FillMaskPonetPipeline | |||
| from .information_extraction_pipeline import InformationExtractionPipeline | |||
| from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | |||
| from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline | |||
| @@ -22,6 +22,7 @@ if TYPE_CHECKING: | |||
| from .text_classification_pipeline import TextClassificationPipeline | |||
| from .text_error_correction_pipeline import TextErrorCorrectionPipeline | |||
| from .text_generation_pipeline import TextGenerationPipeline | |||
| from .text2text_generation_pipeline import Text2TextGenerationPipeline | |||
| from .token_classification_pipeline import TokenClassificationPipeline | |||
| from .translation_pipeline import TranslationPipeline | |||
| from .word_segmentation_pipeline import WordSegmentationPipeline | |||
| @@ -54,6 +55,7 @@ else: | |||
| 'text_classification_pipeline': ['TextClassificationPipeline'], | |||
| 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], | |||
| 'text_generation_pipeline': ['TextGenerationPipeline'], | |||
| 'text2text_generation_pipeline': ['Text2TextGenerationPipeline'], | |||
| 'token_classification_pipeline': ['TokenClassificationPipeline'], | |||
| 'translation_pipeline': ['TranslationPipeline'], | |||
| 'word_segmentation_pipeline': ['WordSegmentationPipeline'], | |||
| @@ -0,0 +1,87 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.base import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline, Tensor | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import Text2TextGenerationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['Text2TextGenerationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.text2text_generation, module_name=Pipelines.text2text_generation) | |||
| class Text2TextGenerationPipeline(Pipeline): | |||
| def __init__( | |||
| self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Text2TextGenerationPreprocessor] = None, | |||
| first_sequence='sentence', | |||
| **kwargs): | |||
| """Use `model` and `preprocessor` to create a text to text generation pipeline for prediction. | |||
| Args: | |||
| model (str or Model): Supply either a local model dir which supported the text generation task, | |||
| or a model id from the model hub, or a torch model instance. | |||
| preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||
| the model if supplied. | |||
| first_sequence: The key to read the first sentence in. | |||
| sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value. | |||
| NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' | |||
| param will have no effect. | |||
| Example: | |||
| >>> from modelscope.pipelines import pipeline | |||
| >>> pipeline_ins = pipeline(task='text-generation', | |||
| >>> model='damo/nlp_palm2.0_text-generation_chinese-base') | |||
| >>> sentence1 = '本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:' | |||
| >>> '1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代' | |||
| >>> print(pipeline_ins(sentence1)) | |||
| >>> # Or use the dict input: | |||
| >>> print(pipeline_ins({'sentence': sentence1})) | |||
| To view other examples plese check the tests/pipelines/test_text_generation.py. | |||
| """ | |||
| model = model if isinstance(model, | |||
| Model) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = Text2TextGenerationPreprocessor( | |||
| model.model_dir, | |||
| sequence_length=kwargs.pop('sequence_length', 128)) | |||
| self.tokenizer = preprocessor.tokenizer | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| forward_params['min_length'] = forward_params.get( | |||
| 'min_length', self.model.config.min_length) | |||
| forward_params['max_length'] = forward_params.get( | |||
| 'max_length', self.model.config.max_length) | |||
| with torch.no_grad(): | |||
| output_ids = self.model.generate(**inputs, **forward_params) | |||
| return {'output_ids': output_ids} | |||
| def postprocess(self, inputs: Dict[str, Tensor], | |||
| **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| output = self.tokenizer.decode( | |||
| inputs['output_ids'][0], | |||
| skip_special_tokens=True, | |||
| ) | |||
| return {OutputKeys.TEXT: output} | |||
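End to end, the pipeline can be driven as in the tests at the bottom of this review; the model id below is the test checkpoint used by this PR, and whether extra call kwargs reach `forward_params` depends on the base `Pipeline` plumbing (assumed, not shown here):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(task=Tasks.text2text_generation, model='damo/t5-cn-base-test')
    # T5 span-corruption style input; the result arrives under OutputKeys.TEXT.
    print(pipe('中国的首都位于<extra_id_0>。'))  # e.g. {'text': '北京'}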
| @@ -24,7 +24,7 @@ if TYPE_CHECKING: | |||
| TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor, | |||
| SequenceLabelingPreprocessor, RelationExtractionPreprocessor, | |||
| DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor, | |||
| PassageRankingPreprocessor, | |||
| PassageRankingPreprocessor, Text2TextGenerationPreprocessor, | |||
| WordSegmentationBlankSetToLabelPreprocessor) | |||
| from .space import (DialogIntentPredictionPreprocessor, | |||
| DialogModelingPreprocessor, | |||
| @@ -57,6 +57,7 @@ else: | |||
| 'TextErrorCorrectionPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'Text2TextGenerationPreprocessor', | |||
| 'WordSegmentationBlankSetToLabelPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | |||
| ], | |||
| @@ -9,6 +9,7 @@ if TYPE_CHECKING: | |||
| Tokenize, SequenceClassificationPreprocessor, | |||
| TextGenerationPreprocessor, TokenClassificationPreprocessor, | |||
| SingleSentenceClassificationPreprocessor, | |||
| Text2TextGenerationPreprocessor, | |||
| PairSentenceClassificationPreprocessor, FillMaskPreprocessor, | |||
| ZeroShotClassificationPreprocessor, NERPreprocessor, | |||
| FaqQuestionAnsweringPreprocessor, SequenceLabelingPreprocessor, | |||
| @@ -27,6 +28,7 @@ else: | |||
| 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'Text2TextGenerationPreprocessor', | |||
| 'WordSegmentationBlankSetToLabelPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | |||
| ], | |||
| @@ -26,6 +26,7 @@ __all__ = [ | |||
| 'Tokenize', 'SequenceClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor', 'TokenClassificationPreprocessor', | |||
| 'PairSentenceClassificationPreprocessor', | |||
| 'Text2TextGenerationPreprocessor', | |||
| 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', | |||
| @@ -442,6 +443,40 @@ class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| return features | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.text2text_gen_preprocessor) | |||
| class Text2TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in text generation. | |||
| """ | |||
| def __init__(self, | |||
| model_dir: str, | |||
| tokenizer=None, | |||
| mode=ModeKeys.INFERENCE, | |||
| **kwargs): | |||
| self.tokenizer = self.build_tokenizer( | |||
| model_dir) if tokenizer is None else tokenizer | |||
| kwargs['truncation'] = kwargs.get('truncation', 'do_not_truncate') | |||
| kwargs['padding'] = kwargs.get('padding', False) | |||
| kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', | |||
| False) | |||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | |||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||
| def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]: | |||
| text_a, _, _ = self.parse_text_and_label(data) | |||
| inputs = self.tokenizer( | |||
| text_a, | |||
| return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, | |||
| **self.tokenize_kwargs) | |||
 | # 'token_type_ids' is produced by some tokenizers but is not a valid generate() kwarg | |||
| if 'token_type_ids' in inputs: | |||
| del inputs['token_type_ids'] | |||
| return inputs | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.text_gen_tokenizer) | |||
| class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): | |||
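A sketch of what the new `Text2TextGenerationPreprocessor` emits at inference time (the checkpoint path is hypothetical; the key names follow its tokenizer call above):

    from modelscope.preprocessors import Text2TextGenerationPreprocessor

    preprocessor = Text2TextGenerationPreprocessor('/path/to/t5_checkpoint')
    features = preprocessor('中国的首都位于<extra_id_0>。')
    # At ModeKeys.INFERENCE the output is batched 'pt' tensors, so the pipeline
    # can call model.generate(**features) directly, roughly:
    # {'input_ids': tensor([[...]]), 'attention_mask': tensor([[...]])}
    # 'token_type_ids' was deleted above because generate() rejects it.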
| @@ -97,6 +97,7 @@ class NLPTasks(object): | |||
| token_classification = 'token-classification' | |||
| conversational = 'conversational' | |||
| text_generation = 'text-generation' | |||
| text2text_generation = 'text2text-generation' | |||
| task_oriented_conversation = 'task-oriented-conversation' | |||
| dialog_intent_prediction = 'dialog-intent-prediction' | |||
| dialog_state_tracking = 'dialog-state-tracking' | |||
| @@ -0,0 +1,61 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import T5ForConditionalGeneration | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import Text2TextGenerationPipeline | |||
| from modelscope.preprocessors import Text2TextGenerationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
| class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/t5-cn-base-test' | |||
| self.input = '中国的首都位于<extra_id_0>。' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_T5(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| model = T5ForConditionalGeneration(cache_path) | |||
| preprocessor = Text2TextGenerationPreprocessor(cache_path) | |||
| pipeline1 = Text2TextGenerationPipeline(model, preprocessor) | |||
| pipeline2 = pipeline( | |||
| Tasks.text2text_generation, model=model, preprocessor=preprocessor) | |||
| print( | |||
| f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}' | |||
| ) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_pipeline_with_model_instance(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = Text2TextGenerationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.text2text_generation, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| print(pipeline_ins(self.input)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_pipeline_with_model_id(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.text2text_generation, model=self.model_id) | |||
| print(pipeline_ins(self.input)) | |||
| @unittest.skip( | |||
| 'only for test cases, there is no default official model yet') | |||
| def test_run_pipeline_without_model_id(self): | |||
| pipeline_ins = pipeline(task=Tasks.text2text_generation) | |||
| print(pipeline_ins(self.input)) | |||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
| def test_demo_compatibility(self): | |||
| self.compatibility_check() | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||