
[to #42322933][NLP] Add text error correction task

Add a text error correction task to NLP.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9540716
Branch: master
Authors: klayzhang.zb, yingda.chen (3 years ago)
Parent commit: ca4b5b2565
12 changed files with 283 additions and 6 deletions
1. modelscope/metainfo.py (+3, -0)
2. modelscope/models/nlp/__init__.py (+2, -0)
3. modelscope/models/nlp/bart_for_text_error_correction.py (+93, -0)
4. modelscope/outputs.py (+7, -2)
5. modelscope/pipelines/builder.py (+3, -0)
6. modelscope/pipelines/nlp/__init__.py (+2, -0)
7. modelscope/pipelines/nlp/text_error_correction_pipeline.py (+69, -0)
8. modelscope/preprocessors/__init__.py (+4, -2)
9. modelscope/preprocessors/nlp.py (+43, -2)
10. modelscope/utils/constant.py (+1, -0)
11. requirements/nlp.txt (+1, -0)
12. tests/pipelines/test_text_error_correction.py (+55, -0)
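Taken together, these changes wire the new task end to end. A minimal usage sketch, mirroring test_run_with_model_name in the added test file (the model id, input, and expected output all come from this commit):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# The model id below is the default registered in modelscope/pipelines/builder.py.
corrector = pipeline(
    task=Tasks.text_error_correction,
    model='damo/nlp_bart_text-error-correction_chinese')
print(corrector('随着中国经济突飞猛近,建造工业与日俱增'))
# e.g. {'output': '随着中国经济突飞猛进,建造工业与日俱增'}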

modelscope/metainfo.py (+3, -0)

@@ -24,6 +24,7 @@ class Models(object):
translation = 'csanmt-translation'
space = 'space'
tcrf = 'transformer-crf'
bart = 'bart'

# audio models
sambert_hifigan = 'sambert-hifigan'
@@ -98,6 +99,7 @@ class Pipelines(object):
dialog_modeling = 'dialog-modeling'
dialog_state_tracking = 'dialog-state-tracking'
zero_shot_classification = 'zero-shot-classification'
text_error_correction = 'text-error-correction'

# audio tasks
sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -161,6 +163,7 @@ class Preprocessors(object):
dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor'
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
text_error_correction = 'text-error-correction'

# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'


modelscope/models/nlp/__init__.py (+2, -0)

@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from .space_for_dialog_modeling import SpaceForDialogModeling
from .space_for_dialog_state_tracking import SpaceForDialogStateTracking
from .task_model import SingleBackboneTaskModelBase
from .bart_for_text_error_correction import BartForTextErrorCorrection

else:
_import_structure = {
@@ -46,6 +47,7 @@ else:
'space_for_dialog_modeling': ['SpaceForDialogModeling'],
'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'],
'task_model': ['SingleBackboneTaskModelBase'],
'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
}

import sys


modelscope/models/nlp/bart_for_text_error_correction.py (+93, -0)

@@ -0,0 +1,93 @@
import os.path as osp
from typing import Any, Dict

import torch.cuda

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['BartForTextErrorCorrection']


@MODELS.register_module(Tasks.text_error_correction, module_name=Models.bart)
class BartForTextErrorCorrection(TorchModel):

def __init__(self, model_dir, *args, **kwargs):
"""Initialize the text error correction model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir=model_dir, *args, **kwargs)
ckpt_name = ModelFile.TORCH_MODEL_FILE
local_model = osp.join(model_dir, ckpt_name)

bart_vocab_dir = model_dir
# turn on cuda if GPU is available
from fairseq import checkpoint_utils, utils
if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self.use_fp16 = kwargs.get('use_fp16', False) and torch.cuda.is_available()

overrides = {
'data': bart_vocab_dir,
'beam': 2,
}
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
utils.split_paths(local_model), arg_overrides=overrides)
# Move models to GPU
for model in models:
model.eval()
model.to(self._device)
if self.use_fp16:
model.half()
model.prepare_for_inference_(cfg)
self.models = models
# Initialize generator
self.generator = task.build_generator(models, 'translation')

self.task = task

def forward(self, input: Dict[str, Dict]) -> Dict[str, Any]:
"""return the result by the model

Args:
input (Dict[str, Tensor]): the preprocessed data
Example:
a single sentence:
{'net_input':
{'src_tokens':tensor([2478,242,24,4]),
'src_lengths': tensor([4])}
}


Returns:
Dict[str, Tensor]: results
Example:
{
'predictions': Tensor([1377, 4959, 2785, 6392...]), # token ids to be decoded by the tokenizer
}
"""
import fairseq.utils

if len(input['net_input']['src_tokens'].size()) == 1:
input['net_input']['src_tokens'] = input['net_input'][
'src_tokens'].view(1, -1)

if torch.cuda.is_available():
input = fairseq.utils.move_to_cuda(input, device=self._device)

sample = input

translations = self.task.inference_step(self.generator, self.models,
sample)

# get 1-best List[Tensor]
preds = translations[0][0]['tokens']
return {'predictions': preds}
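The model can also be driven directly with a preprocessed fairseq-style sample. A minimal sketch of that flow, following test_run_with_direct_download in the added tests:

import torch

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.nlp import BartForTextErrorCorrection
from modelscope.preprocessors import TextErrorCorrectionPreprocessor

cache_path = snapshot_download('damo/nlp_bart_text-error-correction_chinese')
model = BartForTextErrorCorrection(cache_path)
preprocessor = TextErrorCorrectionPreprocessor(cache_path)

sample = preprocessor('随着中国经济突飞猛近,建造工业与日俱增')
with torch.no_grad():  # the pipeline's forward wraps the call the same way
    result = model.forward(sample)
# result['predictions'] holds the 1-best token ids, still BPE-encoded.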

modelscope/outputs.py (+7, -2)

@@ -352,10 +352,15 @@ TASK_OUTPUTS = {
# "text": "this is the text generated by a model."
# }
Tasks.visual_question_answering: [OutputKeys.TEXT],

# auto_speech_recognition result for a single sample
# {
# "text": "每天都要快乐喔"
# }
Tasks.auto_speech_recognition: [OutputKeys.TEXT],

# text_error_correction result for a single sample
# {
# "output": "我想吃苹果"
# }
Tasks.text_error_correction: [OutputKeys.OUTPUT]
}

modelscope/pipelines/builder.py (+3, -0)

@@ -51,6 +51,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/nlp_space_dialog-modeling'),
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
'damo/nlp_space_dialog-state-tracking'),
Tasks.text_error_correction:
(Pipelines.text_error_correction,
'damo/nlp_bart_text-error-correction_chinese'),
Tasks.image_captioning: (Pipelines.image_captioning,
'damo/ofa_image-caption_coco_large_en'),
Tasks.image_generation:


modelscope/pipelines/nlp/__init__.py (+2, -0)

@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from .translation_pipeline import TranslationPipeline
from .word_segmentation_pipeline import WordSegmentationPipeline
from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
from .text_error_correction_pipeline import TextErrorCorrectionPipeline

else:
_import_structure = {
@@ -37,6 +38,7 @@ else:
'named_entity_recognition_pipeline':
['NamedEntityRecognitionPipeline'],
'translation_pipeline': ['TranslationPipeline'],
'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
}

import sys


modelscope/pipelines/nlp/text_error_correction_pipeline.py (+69, -0)

@@ -0,0 +1,69 @@
from typing import Any, Dict, Optional, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import BartForTextErrorCorrection
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TextErrorCorrectionPreprocessor
from modelscope.utils.constant import Tasks

__all__ = ['TextErrorCorrectionPipeline']


@PIPELINES.register_module(
Tasks.text_error_correction, module_name=Pipelines.text_error_correction)
class TextErrorCorrectionPipeline(Pipeline):

def __init__(
self,
model: Union[BartForTextErrorCorrection, str],
preprocessor: Optional[TextErrorCorrectionPreprocessor] = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text generation pipeline for prediction

Args:
model (BartForTextErrorCorrection): a model instance
preprocessor (TextErrorCorrectionPreprocessor): a preprocessor instance
"""
model = model if isinstance(
model,
BartForTextErrorCorrection) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
self.vocab = preprocessor.vocab
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return super().forward(inputs, **forward_params)

def postprocess(self, inputs: Dict[str, Tensor],
**postprocess_params) -> Dict[str, str]:
"""
Args:
inputs (Dict[str, Tensor])
Example:
{
'predictions': Tensor([1377, 4959, 2785, 6392...]), # token ids to be decoded by the tokenizer
}
Returns:
Dict[str, str]
Example:
{
'output': '随着中国经济突飞猛进,建造工业与日俱增'
}


"""

pred_str = self.vocab.string(
inputs['predictions'],
'@@',
extra_symbols_to_ignore={self.vocab.pad()})

return {OutputKeys.OUTPUT: ''.join(pred_str.split())}
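A note on the postprocess step: the model emits subword ids, so vocab.string first decodes them into a spaced string while stripping the '@@' BPE continuation markers, and the final ''.join(pred_str.split()) removes the spaces that character-level tokenization introduced. A tiny illustration with a hypothetical decoded string:

pred_str = '随 着 中 国 经 济 突 飞 猛 进'  # hypothetical vocab.string output
print(''.join(pred_str.split()))  # -> '随着中国经济突飞猛进'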

modelscope/preprocessors/__init__.py (+4, -2)

@@ -21,7 +21,8 @@ if TYPE_CHECKING:
TokenClassificationPreprocessor, NLIPreprocessor,
SentimentClassificationPreprocessor,
SentenceSimilarityPreprocessor, FillMaskPreprocessor,
ZeroShotClassificationPreprocessor, NERPreprocessor,
TextErrorCorrectionPreprocessor)
from .space import (DialogIntentPredictionPreprocessor,
DialogModelingPreprocessor,
DialogStateTrackingPreprocessor)
@@ -49,7 +50,8 @@ else:
'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
'NLIPreprocessor', 'SentimentClassificationPreprocessor',
'SentenceSimilarityPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'TextErrorCorrectionPreprocessor'
],
'space': [
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',


modelscope/preprocessors/nlp.py (+43, -2)

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
import uuid
from typing import Any, Dict, Union

@@ -17,8 +18,9 @@ __all__ = [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
'NLIPreprocessor', 'SentimentClassificationPreprocessor',
'FillMaskPreprocessor', 'SentenceSimilarityPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'TextErrorCorrectionPreprocessor'
]


@@ -431,3 +433,42 @@ class NERPreprocessor(Preprocessor):
'label_mask': label_mask,
'offset_mapping': offset_mapping
}


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.text_error_correction)
class TextErrorCorrectionPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""Preprocess the data via the `dict.src.txt` vocabulary from the `model_dir` path.

Args:
model_dir (str): model path
"""
from fairseq.data import Dictionary
super().__init__(*args, **kwargs)
self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))

def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'随着中国经济突飞猛近,建造工业与日俱增'
Returns:
Dict[str, Any]: the preprocessed data
Example:
{'net_input':
{'src_tokens':tensor([1,2,3,4]),
'src_lengths': tensor([4])}
}
"""

text = ' '.join(data)
inputs = self.vocab.encode_line(
text, append_eos=True, add_if_not_exist=False)
lengths = inputs.size()
sample = dict()
sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
return sample
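The preprocessing itself is plain character-level tokenization followed by a fairseq dictionary lookup. An equivalent standalone sketch (the model_dir path is hypothetical; dict.src.txt ships with the model):

import os.path as osp

from fairseq.data import Dictionary

model_dir = '/path/to/nlp_bart_text-error-correction_chinese'  # hypothetical local dir
vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))

text = ' '.join('随着中国经济')  # one token per character: '随 着 中 国 经 济'
tokens = vocab.encode_line(text, append_eos=True, add_if_not_exist=False)
sample = {'net_input': {'src_tokens': tokens, 'src_lengths': tokens.size()}}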

modelscope/utils/constant.py (+1, -0)

@@ -67,6 +67,7 @@ class NLPTasks(object):
question_answering = 'question-answering'
zero_shot_classification = 'zero-shot-classification'
backbone = 'backbone'
text_error_correction = 'text-error-correction'


class AudioTasks(object):


requirements/nlp.txt (+1, -0)

@@ -1,4 +1,5 @@
en_core_web_sm>=2.3.5
fairseq>=0.10.2
pai-easynlp
# rough-score was just recently updated from 0.0.4 to 0.0.7
# which introduced compatibility issues that are being investigated


tests/pipelines/test_text_error_correction.py (+55, -0)

@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import BartForTextErrorCorrection
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextErrorCorrectionPipeline
from modelscope.preprocessors import TextErrorCorrectionPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TextErrorCorrectionTest(unittest.TestCase):
model_id = 'damo/nlp_bart_text-error-correction_chinese'
input = '随着中国经济突飞猛近,建造工业与日俱增'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_direct_download(self):
cache_path = snapshot_download(self.model_id)
model = BartForTextErrorCorrection(cache_path)
preprocessor = TextErrorCorrectionPreprocessor(cache_path)
pipeline1 = TextErrorCorrectionPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.text_error_correction,
model=model,
preprocessor=preprocessor)
print(
f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}'
)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.text_error_correction,
model=model,
preprocessor=preprocessor)
print(pipeline_ins(self.input))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.text_error_correction, model=self.model_id)
print(pipeline_ins(self.input))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.text_error_correction)
print(pipeline_ins(self.input))


if __name__ == '__main__':
unittest.main()
