Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10383770 * T5 support translatemaster
| @@ -228,6 +228,9 @@ class Pipelines(object): | |||||
| relation_extraction = 'relation-extraction' | relation_extraction = 'relation-extraction' | ||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
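| translation_en_to_de = 'translation_en_to_de' # keep the underscores: the name is looked up as a key in the model's task_specific_params | |||||
| translation_en_to_ro = 'translation_en_to_ro' # keep the underscores: the name is looked up as a key in the model's task_specific_params | |||||
| translation_en_to_fr = 'translation_en_to_fr' # keep the underscores: the name is looked up as a key in the model's task_specific_params | |||||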
| # audio tasks | # audio tasks | ||||
| sambert_hifigan_tts = 'sambert-hifigan-tts' | sambert_hifigan_tts = 'sambert-hifigan-tts' | ||||
| @@ -314,6 +317,7 @@ class Preprocessors(object): | |||||
| bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | ||||
| text_gen_tokenizer = 'text-gen-tokenizer' | text_gen_tokenizer = 'text-gen-tokenizer' | ||||
| text2text_gen_preprocessor = 'text2text-gen-preprocessor' | text2text_gen_preprocessor = 'text2text-gen-preprocessor' | ||||
| text2text_translate_preprocessor = 'text2text-translate-preprocessor' | |||||
| token_cls_tokenizer = 'token-cls-tokenizer' | token_cls_tokenizer = 'token-cls-tokenizer' | ||||
| ner_tokenizer = 'ner-tokenizer' | ner_tokenizer = 'ner-tokenizer' | ||||
| nli_tokenizer = 'nli-tokenizer' | nli_tokenizer = 'nli-tokenizer' | ||||
| @@ -1,21 +1,35 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any, Dict, Optional, Union | |||||
| from typing import Any, Dict, List, Optional, Union | |||||
| import torch | import torch | ||||
| from numpy import isin | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.pipelines.base import Pipeline, Tensor | |||||
| from modelscope.pipelines.base import Input, Pipeline, Tensor | |||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import Text2TextGenerationPreprocessor | from modelscope.preprocessors import Text2TextGenerationPreprocessor | ||||
| from modelscope.utils.config import use_task_specific_params | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| __all__ = ['Text2TextGenerationPipeline'] | __all__ = ['Text2TextGenerationPipeline'] | ||||
| TRANSLATE_PIPELINES = [ | |||||
| Pipelines.translation_en_to_de, | |||||
| Pipelines.translation_en_to_ro, | |||||
| Pipelines.translation_en_to_fr, | |||||
| ] | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.text2text_generation, module_name=Pipelines.text2text_generation) | Tasks.text2text_generation, module_name=Pipelines.text2text_generation) | ||||
| @PIPELINES.register_module( | |||||
| Tasks.text2text_generation, module_name=Pipelines.translation_en_to_de) | |||||
| @PIPELINES.register_module( | |||||
| Tasks.text2text_generation, module_name=Pipelines.translation_en_to_ro) | |||||
| @PIPELINES.register_module( | |||||
| Tasks.text2text_generation, module_name=Pipelines.translation_en_to_fr) | |||||
| class Text2TextGenerationPipeline(Pipeline): | class Text2TextGenerationPipeline(Pipeline): | ||||
| def __init__( | def __init__( | ||||
| @@ -39,13 +53,13 @@ class Text2TextGenerationPipeline(Pipeline): | |||||
| Example: | Example: | ||||
| >>> from modelscope.pipelines import pipeline | >>> from modelscope.pipelines import pipeline | ||||
| >>> pipeline_ins = pipeline(task='text-generation', | |||||
| >>> model='damo/nlp_palm2.0_text-generation_chinese-base') | |||||
| >>> sentence1 = '本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:' | |||||
| >>> '1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代' | |||||
| >>> pipeline_ins = pipeline(task='text2text-generation', | |||||
| >>> model='damo/nlp_t5_text2text-generation_chinese-base') | |||||
| >>> sentence1 = '中国的首都位于<extra_id_0>。' | |||||
| >>> print(pipeline_ins(sentence1)) | >>> print(pipeline_ins(sentence1)) | ||||
| >>> # Or use the dict input: | >>> # Or use the dict input: | ||||
| >>> print(pipeline_ins({'sentence': sentence1})) | >>> print(pipeline_ins({'sentence': sentence1})) | ||||
| >>> # 北京 | |||||
| To view other examples, please check tests/pipelines/test_text_generation.py. | To view other examples, please check tests/pipelines/test_text_generation.py. | ||||
| """ | """ | ||||
| @@ -56,9 +70,22 @@ class Text2TextGenerationPipeline(Pipeline): | |||||
| model.model_dir, | model.model_dir, | ||||
| sequence_length=kwargs.pop('sequence_length', 128)) | sequence_length=kwargs.pop('sequence_length', 128)) | ||||
| self.tokenizer = preprocessor.tokenizer | self.tokenizer = preprocessor.tokenizer | ||||
| self.pipeline = model.pipeline.type | |||||
| model.eval() | model.eval() | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | super().__init__(model=model, preprocessor=preprocessor, **kwargs) | ||||
| def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: | |||||
| """ Provide a specific preprocess step for the text2text generation pipeline in order to handle multiple tasks | |||||
| """ | |||||
| if not isinstance(inputs, str): | |||||
| raise ValueError(f'Not supported input type: {type(inputs)}') | |||||
| if self.pipeline in TRANSLATE_PIPELINES: | |||||
| use_task_specific_params(self.model, self.pipeline) | |||||
| inputs = self.model.config.prefix + inputs | |||||
| return super().preprocess(inputs, **preprocess_params) | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -12,7 +12,8 @@ from modelscope.metainfo import Models, Preprocessors | |||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.preprocessors.base import Preprocessor | from modelscope.preprocessors.base import Preprocessor | ||||
| from modelscope.preprocessors.builder import PREPROCESSORS | from modelscope.preprocessors.builder import PREPROCESSORS | ||||
| from modelscope.utils.config import Config, ConfigFields | |||||
| from modelscope.utils.config import (Config, ConfigFields, | |||||
| use_task_specific_params) | |||||
| from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | ||||
| from modelscope.utils.hub import get_model_type, parse_label_mapping | from modelscope.utils.hub import get_model_type, parse_label_mapping | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -633,6 +633,16 @@ def check_config(cfg: Union[str, ConfigDict]): | |||||
| check_attr(ConfigFields.evaluation) | check_attr(ConfigFields.evaluation) | ||||
| def use_task_specific_params(model, task): | |||||
| """Update the model config with the parameters specific to the given task, if any are defined.""" | |||||
| task_specific_params = model.config.task_specific_params | |||||
| if task_specific_params is not None: | |||||
| pars = task_specific_params.get(task, {}) | |||||
| logger.info(f'using task specific params for {task}: {pars}') | |||||
| model.config.update(pars) | |||||
| class JSONIteratorEncoder(json.JSONEncoder): | class JSONIteratorEncoder(json.JSONEncoder): | ||||
| """Implement this method in order that supporting arbitrary iterators, it returns | """Implement this method in order that supporting arbitrary iterators, it returns | ||||
| a serializable object for ``obj``, or calls the base implementation | a serializable object for ``obj``, or calls the base implementation | ||||
| @@ -15,42 +15,44 @@ from modelscope.utils.test_utils import test_level | |||||
| class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): | class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/t5-cn-base-test' | |||||
| self.input = '中国的首都位于<extra_id_0>。' | |||||
| self.model_id_generate = 'damo/t5-cn-base-test' | |||||
| self.input_generate = '中国的首都位于<extra_id_0>。' | |||||
| self.model_id_translate = 'damo/t5-translate-base-test' | |||||
| self.input_translate = 'My name is Wolfgang and I live in Berlin' | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_T5(self): | def test_run_T5(self): | ||||
| cache_path = snapshot_download(self.model_id) | |||||
| model = T5ForConditionalGeneration(cache_path) | |||||
| cache_path = snapshot_download(self.model_id_generate) | |||||
| model = T5ForConditionalGeneration.from_pretrained(cache_path) | |||||
| preprocessor = Text2TextGenerationPreprocessor(cache_path) | preprocessor = Text2TextGenerationPreprocessor(cache_path) | ||||
| pipeline1 = Text2TextGenerationPipeline(model, preprocessor) | pipeline1 = Text2TextGenerationPipeline(model, preprocessor) | ||||
| pipeline2 = pipeline( | pipeline2 = pipeline( | ||||
| Tasks.text2text_generation, model=model, preprocessor=preprocessor) | Tasks.text2text_generation, model=model, preprocessor=preprocessor) | ||||
| print( | print( | ||||
| f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}' | |||||
| f'pipeline1: {pipeline1(self.input_generate)}\npipeline2: {pipeline2(self.input_generate)}' | |||||
| ) | ) | ||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_pipeline_with_model_instance(self): | def test_run_pipeline_with_model_instance(self): | ||||
| model = Model.from_pretrained(self.model_id) | |||||
| model = Model.from_pretrained(self.model_id_translate) | |||||
| preprocessor = Text2TextGenerationPreprocessor(model.model_dir) | preprocessor = Text2TextGenerationPreprocessor(model.model_dir) | ||||
| pipeline_ins = pipeline( | pipeline_ins = pipeline( | ||||
| task=Tasks.text2text_generation, | task=Tasks.text2text_generation, | ||||
| model=model, | model=model, | ||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| print(pipeline_ins(self.input)) | |||||
| print(pipeline_ins(self.input_translate)) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_pipeline_with_model_id(self): | def test_run_pipeline_with_model_id(self): | ||||
| pipeline_ins = pipeline( | pipeline_ins = pipeline( | ||||
| task=Tasks.text2text_generation, model=self.model_id) | |||||
| print(pipeline_ins(self.input)) | |||||
| task=Tasks.text2text_generation, model=self.model_id_translate) | |||||
| print(pipeline_ins(self.input_translate)) | |||||
| @unittest.skip( | @unittest.skip( | ||||
| 'only for test cases, there is no default official model yet') | 'only for test cases, there is no default official model yet') | ||||
| def test_run_pipeline_without_model_id(self): | def test_run_pipeline_without_model_id(self): | ||||
| pipeline_ins = pipeline(task=Tasks.text2text_generation) | pipeline_ins = pipeline(task=Tasks.text2text_generation) | ||||
| print(pipeline_ins(self.input)) | |||||
| print(pipeline_ins(self.input_generate)) | |||||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | @unittest.skip('demo compatibility test is only enabled on a needed-basis') | ||||
| def test_demo_compatibility(self): | def test_demo_compatibility(self): | ||||