
[to #42322933] support t5_with_translation

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10383770

    * Support translation with T5
Branch: master
Authors: zhangzhicheng.zzc, yingda.chen · 3 years ago
Commit: 0eb823b764
5 changed files with 63 additions and 19 deletions
  1. modelscope/metainfo.py  (+4, -0)
  2. modelscope/pipelines/nlp/text2text_generation_pipeline.py  (+33, -6)
  3. modelscope/preprocessors/nlp/nlp_base.py  (+2, -1)
  4. modelscope/utils/config.py  (+10, -0)
  5. tests/pipelines/test_text2text_generation.py  (+14, -12)

modelscope/metainfo.py  (+4, -0)

@@ -228,6 +228,9 @@ class Pipelines(object):
     relation_extraction = 'relation-extraction'
     document_segmentation = 'document-segmentation'
     feature_extraction = 'feature-extraction'
+    translation_en_to_de = 'translation_en_to_de'  # keep it underscore
+    translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
+    translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore
 
     # audio tasks
     sambert_hifigan_tts = 'sambert-hifigan-tts'

@@ -314,6 +317,7 @@ class Preprocessors(object):
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     text_gen_tokenizer = 'text-gen-tokenizer'
     text2text_gen_preprocessor = 'text2text-gen-preprocessor'
+    text2text_translate_preprocessor = 'text2text-translate-preprocessor'
     token_cls_tokenizer = 'token-cls-tokenizer'
     ner_tokenizer = 'ner-tokenizer'
     nli_tokenizer = 'nli-tokenizer'
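
The three new Pipelines names deliberately keep underscores (see the inline comments), presumably so that they line up with the task keys that T5 checkpoints carry in config.task_specific_params. A minimal sketch of how the constants are meant to be consumed; the configuration fragment is hypothetical and not part of this diff:

    # Hypothetical illustration: the registry keys are plain strings, so a
    # model's configuration.json can select the translation variant of the
    # text2text-generation pipeline by name.
    from modelscope.metainfo import Pipelines

    print(Pipelines.translation_en_to_de)  # 'translation_en_to_de'
    # a model card could then carry: {"pipeline": {"type": "translation_en_to_de"}}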


modelscope/pipelines/nlp/text2text_generation_pipeline.py  (+33, -6)

@@ -1,21 +1,35 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
+from numpy import isin
 
 from modelscope.metainfo import Pipelines
 from modelscope.models.base import Model
 from modelscope.outputs import OutputKeys
-from modelscope.pipelines.base import Pipeline, Tensor
+from modelscope.pipelines.base import Input, Pipeline, Tensor
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import Text2TextGenerationPreprocessor
+from modelscope.utils.config import use_task_specific_params
 from modelscope.utils.constant import Tasks
 
 __all__ = ['Text2TextGenerationPipeline']
 
+TRANSLATE_PIPELINES = [
+    Pipelines.translation_en_to_de,
+    Pipelines.translation_en_to_ro,
+    Pipelines.translation_en_to_fr,
+]
+
 
 @PIPELINES.register_module(
     Tasks.text2text_generation, module_name=Pipelines.text2text_generation)
+@PIPELINES.register_module(
+    Tasks.text2text_generation, module_name=Pipelines.translation_en_to_de)
+@PIPELINES.register_module(
+    Tasks.text2text_generation, module_name=Pipelines.translation_en_to_ro)
+@PIPELINES.register_module(
+    Tasks.text2text_generation, module_name=Pipelines.translation_en_to_fr)
 class Text2TextGenerationPipeline(Pipeline):
 
     def __init__(

@@ -39,13 +53,13 @@ class Text2TextGenerationPipeline(Pipeline):
 
         Example:
         >>> from modelscope.pipelines import pipeline
-        >>> pipeline_ins = pipeline(task='text-generation',
-        >>>    model='damo/nlp_palm2.0_text-generation_chinese-base')
-        >>> sentence1 = '本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:'
-        >>> '1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代'
+        >>> pipeline_ins = pipeline(task='text2text-generation',
+        >>>    model='damo/nlp_t5_text2text-generation_chinese-base')
+        >>> sentence1 = '中国的首都位于<extra_id_0>。'
         >>> print(pipeline_ins(sentence1))
         >>> # Or use the dict input:
         >>> print(pipeline_ins({'sentence': sentence1}))
+        >>> # 北京
 
         To view other examples plese check the tests/pipelines/test_text_generation.py.
         """

@@ -56,9 +70,22 @@ class Text2TextGenerationPipeline(Pipeline):
             model.model_dir,
             sequence_length=kwargs.pop('sequence_length', 128))
         self.tokenizer = preprocessor.tokenizer
+        self.pipeline = model.pipeline.type
         model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
 
+    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
+        """ Provide specific preprocess for text2text generation pipeline in order to handl multi tasks
+        """
+        if not isinstance(inputs, str):
+            raise ValueError(f'Not supported input type: {type(inputs)}')
+
+        if self.pipeline in TRANSLATE_PIPELINES:
+            use_task_specific_params(self.model, self.pipeline)
+            inputs = self.model.config.prefix + inputs
+
+        return super().preprocess(inputs, **preprocess_params)
+
     def forward(self, inputs: Dict[str, Any],
                 **forward_params) -> Dict[str, Any]:
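
With the extra registrations above and the preprocess override, a translation model is served through the same text2text-generation task. A usage sketch based on the updated test file below; 'damo/t5-translate-base-test' is the checkpoint named there and may be a test-only model:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # The pipeline type is read from the model's configuration; for a
    # translation type, preprocess() loads the task-specific params and
    # prepends model.config.prefix (e.g. 'translate English to German: ')
    # to the raw input before tokenization.
    pipe = pipeline(task=Tasks.text2text_generation,
                    model='damo/t5-translate-base-test')
    print(pipe('My name is Wolfgang and I live in Berlin'))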




modelscope/preprocessors/nlp/nlp_base.py  (+2, -1)

@@ -12,7 +12,8 @@ from modelscope.metainfo import Models, Preprocessors
 from modelscope.outputs import OutputKeys
 from modelscope.preprocessors.base import Preprocessor
 from modelscope.preprocessors.builder import PREPROCESSORS
-from modelscope.utils.config import Config, ConfigFields
+from modelscope.utils.config import (Config, ConfigFields,
+                                     use_task_specific_params)
 from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile
 from modelscope.utils.hub import get_model_type, parse_label_mapping
 from modelscope.utils.logger import get_logger


modelscope/utils/config.py  (+10, -0)

@@ -633,6 +633,16 @@ def check_config(cfg: Union[str, ConfigDict]):
     check_attr(ConfigFields.evaluation)
 
 
+def use_task_specific_params(model, task):
+    """Update config with summarization specific params."""
+    task_specific_params = model.config.task_specific_params
+
+    if task_specific_params is not None:
+        pars = task_specific_params.get(task, {})
+        logger.info(f'using task specific params for {task}: {pars}')
+        model.config.update(pars)
+
+
 class JSONIteratorEncoder(json.JSONEncoder):
     """Implement this method in order that supporting arbitrary iterators, it returns
     a serializable object for ``obj``, or calls the base implementation
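
use_task_specific_params appears to mirror the helper of the same name in the Hugging Face seq2seq examples (hence the "summarization" wording in the docstring): it merges a per-task block from config.task_specific_params into the top-level config. A sketch of the data it expects; the values follow the stock Hugging Face t5-base config and are illustrative, not taken from this diff:

    # Illustrative shape of a T5 config's task_specific_params:
    task_specific_params = {
        'translation_en_to_de': {
            'early_stopping': True,
            'max_length': 300,
            'num_beams': 4,
            'prefix': 'translate English to German: ',
        },
        # ... plus 'translation_en_to_fr', 'translation_en_to_ro', 'summarization'
    }
    # use_task_specific_params(model, 'translation_en_to_de') copies the matching
    # sub-dict into model.config, so model.config.prefix becomes the prompt that
    # Text2TextGenerationPipeline.preprocess prepends to the input sentence.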


tests/pipelines/test_text2text_generation.py  (+14, -12)

@@ -15,42 +15,44 @@ from modelscope.utils.test_utils import test_level
 class Text2TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
 
     def setUp(self) -> None:
-        self.model_id = 'damo/t5-cn-base-test'
-        self.input = '中国的首都位于<extra_id_0>。'
+        self.model_id_generate = 'damo/t5-cn-base-test'
+        self.input_generate = '中国的首都位于<extra_id_0>。'
+        self.model_id_translate = 'damo/t5-translate-base-test'
+        self.input_translate = 'My name is Wolfgang and I live in Berlin'
 
-    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_T5(self):
-        cache_path = snapshot_download(self.model_id)
-        model = T5ForConditionalGeneration(cache_path)
+        cache_path = snapshot_download(self.model_id_generate)
+        model = T5ForConditionalGeneration.from_pretrained(cache_path)
         preprocessor = Text2TextGenerationPreprocessor(cache_path)
         pipeline1 = Text2TextGenerationPipeline(model, preprocessor)
         pipeline2 = pipeline(
             Tasks.text2text_generation, model=model, preprocessor=preprocessor)
         print(
-            f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}'
+            f'pipeline1: {pipeline1(self.input_generate)}\npipeline2: {pipeline2(self.input_generate)}'
         )
 
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_pipeline_with_model_instance(self):
-        model = Model.from_pretrained(self.model_id)
+        model = Model.from_pretrained(self.model_id_translate)
         preprocessor = Text2TextGenerationPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.text2text_generation,
             model=model,
             preprocessor=preprocessor)
-        print(pipeline_ins(self.input))
+        print(pipeline_ins(self.input_translate))
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_pipeline_with_model_id(self):
         pipeline_ins = pipeline(
-            task=Tasks.text2text_generation, model=self.model_id)
-        print(pipeline_ins(self.input))
+            task=Tasks.text2text_generation, model=self.model_id_translate)
+        print(pipeline_ins(self.input_translate))
 
     @unittest.skip(
         'only for test cases, there is no default official model yet')
     def test_run_pipeline_without_model_id(self):
         pipeline_ins = pipeline(task=Tasks.text2text_generation)
-        print(pipeline_ins(self.input))
+        print(pipeline_ins(self.input_generate))
 
     @unittest.skip('demo compatibility test is only enabled on a needed-basis')
     def test_demo_compatibility(self):
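
To exercise the new translation path locally, the updated suite can be run on its own; which cases actually execute still depends on test_level(). A sketch, assuming the repository root as the working directory:

    import unittest

    # Discover and run only this test module; the skipUnless decorators above
    # gate the heavier cases by test level.
    suite = unittest.defaultTestLoader.discover(
        'tests/pipelines', pattern='test_text2text_generation.py')
    unittest.TextTestRunner(verbosity=2).run(suite)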

