Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10383770 * T5 support translate (master)
@@ -228,6 +228,9 @@ class Pipelines(object):
    relation_extraction = 'relation-extraction'
    document_segmentation = 'document-segmentation'
    feature_extraction = 'feature-extraction'
    translation_en_to_de = 'translation_en_to_de'  # keep the underscores
    translation_en_to_ro = 'translation_en_to_ro'  # keep the underscores
    translation_en_to_fr = 'translation_en_to_fr'  # keep the underscores

    # audio tasks
    sambert_hifigan_tts = 'sambert-hifigan-tts'

@@ -314,6 +317,7 @@ class Preprocessors(object):
    bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
    text_gen_tokenizer = 'text-gen-tokenizer'
    text2text_gen_preprocessor = 'text2text-gen-preprocessor'
    text2text_translate_preprocessor = 'text2text-translate-preprocessor'
    token_cls_tokenizer = 'token-cls-tokenizer'
    ner_tokenizer = 'ner-tokenizer'
    nli_tokenizer = 'nli-tokenizer'
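A note on the "keep the underscores" comments: these three identifiers presumably have to match the keys that stock Hugging Face T5 configs expose under task_specific_params, because the use_task_specific_params helper added further down in this change looks the pipeline type up in that mapping to pull per-task generation settings and the input prefix. For reference, a t5-base style config carries entries roughly shaped like the sketch below; the concrete values are illustrative and are not part of this diff.

    # Rough shape of task_specific_params in a stock Hugging Face T5 config.
    # Values are illustrative; the point is that the keys match the new pipeline names.
    task_specific_params = {
        'translation_en_to_de': {
            'early_stopping': True,
            'max_length': 300,
            'num_beams': 4,
            'prefix': 'translate English to German: ',
        },
        'translation_en_to_ro': {
            'prefix': 'translate English to Romanian: ',
            # ... other generation settings
        },
        'translation_en_to_fr': {
            'prefix': 'translate English to French: ',
            # ... other generation settings
        },
    }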
@@ -1,21 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from numpy import isin

from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.base import Input, Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Text2TextGenerationPreprocessor
from modelscope.utils.config import use_task_specific_params
from modelscope.utils.constant import Tasks

__all__ = ['Text2TextGenerationPipeline']

TRANSLATE_PIPELINES = [
    Pipelines.translation_en_to_de,
    Pipelines.translation_en_to_ro,
    Pipelines.translation_en_to_fr,
]


@PIPELINES.register_module(
    Tasks.text2text_generation, module_name=Pipelines.text2text_generation)
@PIPELINES.register_module(
    Tasks.text2text_generation, module_name=Pipelines.translation_en_to_de)
@PIPELINES.register_module(
    Tasks.text2text_generation, module_name=Pipelines.translation_en_to_ro)
@PIPELINES.register_module(
    Tasks.text2text_generation, module_name=Pipelines.translation_en_to_fr)
class Text2TextGenerationPipeline(Pipeline):

    def __init__(
@@ -39,13 +53,13 @@ class Text2TextGenerationPipeline(Pipeline):
        Example:
        >>> from modelscope.pipelines import pipeline
        >>> pipeline_ins = pipeline(task='text-generation',
        >>>     model='damo/nlp_palm2.0_text-generation_chinese-base')
        >>> sentence1 = '本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:'
        >>>     '1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代'
        >>> pipeline_ins = pipeline(task='text2text-generation',
        >>>     model='damo/nlp_t5_text2text-generation_chinese-base')
        >>> sentence1 = '中国的首都位于<extra_id_0>。'
        >>> print(pipeline_ins(sentence1))
        >>> # Or use the dict input:
        >>> print(pipeline_ins({'sentence': sentence1}))
        >>> # 北京
        To view other examples please check tests/pipelines/test_text_generation.py.
        """
@@ -56,9 +70,22 @@ class Text2TextGenerationPipeline(Pipeline):
                model.model_dir,
                sequence_length=kwargs.pop('sequence_length', 128))
        self.tokenizer = preprocessor.tokenizer
        self.pipeline = model.pipeline.type
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
        """Provide a dedicated preprocess step for the text2text generation
        pipeline in order to handle multiple tasks.
        """
        if not isinstance(inputs, str):
            raise ValueError(f'Not supported input type: {type(inputs)}')

        if self.pipeline in TRANSLATE_PIPELINES:
            use_task_specific_params(self.model, self.pipeline)
            inputs = self.model.config.prefix + inputs

        return super().preprocess(inputs, **preprocess_params)

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
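For orientation: the module_name that selects one of the translation registrations above appears to come from the pipeline type recorded in the model's configuration.json (read as model.pipeline.type in __init__), and preprocess() then prepends the model's task prefix before tokenization. Below is a minimal usage sketch of that path; it reuses the test checkpoint id that appears later in this change, which may not be a publicly released model.

    # A minimal sketch of the new translation path, not an official example.
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # The checkpoint's configuration.json declares a pipeline type such as
    # 'translation_en_to_de', which routes to Text2TextGenerationPipeline.
    pipe = pipeline(
        task=Tasks.text2text_generation,
        model='damo/t5-translate-base-test')

    # preprocess() prepends the prefix from task_specific_params
    # (e.g. 'translate English to German: ') and generation proceeds as usual.
    print(pipe('My name is Wolfgang and I live in Berlin'))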
@@ -12,7 +12,8 @@ from modelscope.metainfo import Models, Preprocessors
from modelscope.outputs import OutputKeys
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.config import Config, ConfigFields
from modelscope.utils.config import (Config, ConfigFields,
                                     use_task_specific_params)
from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile
from modelscope.utils.hub import get_model_type, parse_label_mapping
from modelscope.utils.logger import get_logger
@@ -633,6 +633,16 @@ def check_config(cfg: Union[str, ConfigDict]):
    check_attr(ConfigFields.evaluation)


def use_task_specific_params(model, task):
    """Update the model config with task-specific params."""
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f'using task specific params for {task}: {pars}')
        model.config.update(pars)


class JSONIteratorEncoder(json.JSONEncoder):
    """Implemented so that arbitrary iterators are supported: returns
    a serializable object for ``obj``, or calls the base implementation
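To make the behaviour of the new helper concrete, here is a small usage sketch. It assumes a Hugging Face style model whose config carries task_specific_params (stock T5 checkpoints such as t5-small do); it is illustrative only and not part of the patch.

    # Illustrative only: merge the per-task overrides from config.task_specific_params
    # into the top-level config, then read the generation settings back.
    from transformers import T5ForConditionalGeneration

    from modelscope.utils.config import use_task_specific_params

    model = T5ForConditionalGeneration.from_pretrained('t5-small')
    use_task_specific_params(model, 'translation_en_to_de')

    print(model.config.prefix)     # e.g. 'translate English to German: '
    print(model.config.num_beams)  # e.g. 4, taken from the translation_en_to_de entry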
@@ -15,42 +15,44 @@ from modelscope.utils.test_utils import test_level


class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.model_id = 'damo/t5-cn-base-test'
        self.input = '中国的首都位于<extra_id_0>。'
        self.model_id_generate = 'damo/t5-cn-base-test'
        self.input_generate = '中国的首都位于<extra_id_0>。'
        self.model_id_translate = 'damo/t5-translate-base-test'
        self.input_translate = 'My name is Wolfgang and I live in Berlin'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_T5(self):
        cache_path = snapshot_download(self.model_id)
        model = T5ForConditionalGeneration(cache_path)
        cache_path = snapshot_download(self.model_id_generate)
        model = T5ForConditionalGeneration.from_pretrained(cache_path)
        preprocessor = Text2TextGenerationPreprocessor(cache_path)
        pipeline1 = Text2TextGenerationPipeline(model, preprocessor)
        pipeline2 = pipeline(
            Tasks.text2text_generation, model=model, preprocessor=preprocessor)
        print(
            f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}'
            f'pipeline1: {pipeline1(self.input_generate)}\npipeline2: {pipeline2(self.input_generate)}'
        )

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_pipeline_with_model_instance(self):
        model = Model.from_pretrained(self.model_id)
        model = Model.from_pretrained(self.model_id_translate)
        preprocessor = Text2TextGenerationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.text2text_generation,
            model=model,
            preprocessor=preprocessor)
        print(pipeline_ins(self.input))
        print(pipeline_ins(self.input_translate))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_pipeline_with_model_id(self):
        pipeline_ins = pipeline(
            task=Tasks.text2text_generation, model=self.model_id)
        print(pipeline_ins(self.input))
            task=Tasks.text2text_generation, model=self.model_id_translate)
        print(pipeline_ins(self.input_translate))

    @unittest.skip(
        'only for test cases, there is no default official model yet')
    def test_run_pipeline_without_model_id(self):
        pipeline_ins = pipeline(task=Tasks.text2text_generation)
        print(pipeline_ins(self.input))
        print(pipeline_ins(self.input_generate))

    @unittest.skip('demo compatibility test is only enabled on an as-needed basis')
    def test_demo_compatibility(self):