From ca4b5b2565ddc6a8de72d9769d02c54867cf5346 Mon Sep 17 00:00:00 2001
From: "klayzhang.zb"
Date: Thu, 28 Jul 2022 23:01:28 +0800
Subject: [PATCH] [to #42322933][NLP] Add text error correction task
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add a text error correction task to NLP.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9540716
---
 modelscope/metainfo.py                         |  3 +
 modelscope/models/nlp/__init__.py              |  2 +
 .../nlp/bart_for_text_error_correction.py      | 93 +++++++++++++++++++
 modelscope/outputs.py                          |  9 +-
 modelscope/pipelines/builder.py                |  3 +
 modelscope/pipelines/nlp/__init__.py           |  2 +
 .../nlp/text_error_correction_pipeline.py      | 69 ++++++++++++++
 modelscope/preprocessors/__init__.py           |  6 +-
 modelscope/preprocessors/nlp.py                | 45 ++++++++-
 modelscope/utils/constant.py                   |  1 +
 requirements/nlp.txt                           |  1 +
 tests/pipelines/test_text_error_correction.py  | 55 +++++++++++
 12 files changed, 283 insertions(+), 6 deletions(-)
 create mode 100644 modelscope/models/nlp/bart_for_text_error_correction.py
 create mode 100644 modelscope/pipelines/nlp/text_error_correction_pipeline.py
 create mode 100644 tests/pipelines/test_text_error_correction.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 5efc724c..8ea9e7ed 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -24,6 +24,7 @@ class Models(object):
     translation = 'csanmt-translation'
     space = 'space'
     tcrf = 'transformer-crf'
+    bart = 'bart'
 
     # audio models
     sambert_hifigan = 'sambert-hifigan'
@@ -98,6 +99,7 @@ class Pipelines(object):
     dialog_modeling = 'dialog-modeling'
     dialog_state_tracking = 'dialog-state-tracking'
     zero_shot_classification = 'zero-shot-classification'
+    text_error_correction = 'text-error-correction'
 
     # audio tasks
     sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -161,6 +163,7 @@ class Preprocessors(object):
     dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor'
     sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
     zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
+    text_error_correction = 'text-error-correction'
 
     # audio preprocessor
     linear_aec_fbank = 'linear-aec-fbank'
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index 52cffb0c..23041168 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
     from .space_for_dialog_modeling import SpaceForDialogModeling
     from .space_for_dialog_state_tracking import SpaceForDialogStateTracking
     from .task_model import SingleBackboneTaskModelBase
+    from .bart_for_text_error_correction import BartForTextErrorCorrection
 
 else:
     _import_structure = {
@@ -46,6 +47,7 @@ else:
         'space_for_dialog_modeling': ['SpaceForDialogModeling'],
         'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'],
         'task_model': ['SingleBackboneTaskModelBase'],
+        'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
     }
 
     import sys
diff --git a/modelscope/models/nlp/bart_for_text_error_correction.py b/modelscope/models/nlp/bart_for_text_error_correction.py
new file mode 100644
index 00000000..2339f221
--- /dev/null
+++ b/modelscope/models/nlp/bart_for_text_error_correction.py
@@ -0,0 +1,93 @@
+import os.path as osp
+from typing import Any, Dict
+
+import torch.cuda
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import ModelFile, Tasks
+
+__all__ = ['BartForTextErrorCorrection']
+
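+# A minimal sketch of how this model is driven end to end; the
+# preprocessor and pipeline referenced here are added elsewhere in this
+# patch, and the local model directory below is a placeholder:
+#
+#     from modelscope.preprocessors import TextErrorCorrectionPreprocessor
+#
+#     model_dir = '/path/to/nlp_bart_text-error-correction_chinese'
+#     model = BartForTextErrorCorrection(model_dir)
+#     preprocessor = TextErrorCorrectionPreprocessor(model_dir)
+#     sample = preprocessor('随着中国经济突飞猛近,建造工业与日俱增')
+#     token_ids = model.forward(sample)['predictions']
+#     # the pipeline's postprocess step decodes these ids back to text
+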
+@MODELS.register_module(Tasks.text_error_correction, module_name=Models.bart)
+class BartForTextErrorCorrection(TorchModel):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        """Initialize the text error correction model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir=model_dir, *args, **kwargs)
+        ckpt_name = ModelFile.TORCH_MODEL_FILE
+        local_model = osp.join(model_dir, ckpt_name)
+
+        bart_vocab_dir = model_dir
+        # turn on cuda if a GPU is available
+        from fairseq import checkpoint_utils, utils
+        if torch.cuda.is_available():
+            self._device = torch.device('cuda')
+        else:
+            self._device = torch.device('cpu')
+        # run in fp16 only when requested and a GPU is available
+        self.use_fp16 = kwargs.get('use_fp16',
+                                   False) and torch.cuda.is_available()
+
+        overrides = {
+            'data': bart_vocab_dir,
+            'beam': 2,
+        }
+        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+            utils.split_paths(local_model), arg_overrides=overrides)
+        # Move models to the selected device
+        for model in models:
+            model.eval()
+            model.to(self._device)
+            if self.use_fp16:
+                model.half()
+            model.prepare_for_inference_(cfg)
+        self.models = models
+        # Initialize generator
+        self.generator = task.build_generator(models, 'translation')
+
+        self.task = task
+
+    def forward(self, input: Dict[str, Dict]) -> Dict[str, Any]:
+        """Return the result computed by the model.
+
+        Args:
+            input (Dict[str, Dict]): the preprocessed data
+                Example (one sentence):
+                    {'net_input':
+                        {'src_tokens': tensor([2478, 242, 24, 4]),
+                         'src_lengths': tensor([4])}
+                    }
+
+        Returns:
+            Dict[str, Tensor]: results
+                Example:
+                {
+                    'predictions': Tensor([1377, 4959, 2785, 6392...]),  # token ids to be decoded back to text
+                }
+        """
+        import fairseq.utils
+
+        if len(input['net_input']['src_tokens'].size()) == 1:
+            input['net_input']['src_tokens'] = input['net_input'][
+                'src_tokens'].view(1, -1)
+
+        if torch.cuda.is_available():
+            input = fairseq.utils.move_to_cuda(input, device=self._device)
+
+        sample = input
+
+        translations = self.task.inference_step(self.generator, self.models,
+                                                sample)
+
+        # take the 1-best hypothesis for the single input sentence
+        preds = translations[0][0]['tokens']
+        return {'predictions': preds}
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index cffbc05f..dee31a4f 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -352,10 +352,15 @@ TASK_OUTPUTS = {
     # "text": "this is the text generated by a model."
     # }
     Tasks.visual_question_answering: [OutputKeys.TEXT],
-
     # auto_speech_recognition result for a single sample
     # {
     #     "text": "每天都要快乐喔"
     # }
-    Tasks.auto_speech_recognition: [OutputKeys.TEXT]
+    Tasks.auto_speech_recognition: [OutputKeys.TEXT],
+
+    # text_error_correction result for a single sample
+    # {
+    #     "output": "我想吃苹果"
+    # }
+    Tasks.text_error_correction: [OutputKeys.OUTPUT]
 }
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index cf8b1147..15a367b9 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -51,6 +51,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                              'damo/nlp_space_dialog-modeling'),
     Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
                                   'damo/nlp_space_dialog-state-tracking'),
+    Tasks.text_error_correction:
+    (Pipelines.text_error_correction,
+     'damo/nlp_bart_text-error-correction_chinese'),
     Tasks.image_captioning: (Pipelines.image_captioning,
                              'damo/ofa_image-caption_coco_large_en'),
     Tasks.image_generation:
diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index 6b3ca000..561ced1a 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
     from .translation_pipeline import TranslationPipeline
     from .word_segmentation_pipeline import WordSegmentationPipeline
     from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
+    from .text_error_correction_pipeline import TextErrorCorrectionPipeline
 
 else:
     _import_structure = {
@@ -37,6 +38,7 @@ else:
         'named_entity_recognition_pipeline':
         ['NamedEntityRecognitionPipeline'],
         'translation_pipeline': ['TranslationPipeline'],
+        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
     }
 
     import sys
diff --git a/modelscope/pipelines/nlp/text_error_correction_pipeline.py b/modelscope/pipelines/nlp/text_error_correction_pipeline.py
new file mode 100644
index 00000000..44fae08f
--- /dev/null
+++ b/modelscope/pipelines/nlp/text_error_correction_pipeline.py
@@ -0,0 +1,69 @@
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.models.nlp import BartForTextErrorCorrection
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline, Tensor
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import TextErrorCorrectionPreprocessor
+from modelscope.utils.constant import Tasks
+
+__all__ = ['TextErrorCorrectionPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.text_error_correction, module_name=Pipelines.text_error_correction)
+class TextErrorCorrectionPipeline(Pipeline):
+
+    def __init__(
+            self,
+            model: Union[BartForTextErrorCorrection, str],
+            preprocessor: Optional[TextErrorCorrectionPreprocessor] = None,
+            **kwargs):
+        """Use `model` and `preprocessor` to create an NLP text error correction pipeline for prediction.
+
+        Args:
+            model (BartForTextErrorCorrection): a model instance
+            preprocessor (TextErrorCorrectionPreprocessor): a preprocessor instance
+        """
+        model = model if isinstance(
+            model,
+            BartForTextErrorCorrection) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
+        self.vocab = preprocessor.vocab
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
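+
+    # Note: `postprocess` below relies on the fairseq vocabulary to merge
+    # sub-word pieces marked with '@@' back into surface tokens, then
+    # strips the remaining spaces, which yields plain Chinese text.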
+    def postprocess(self, inputs: Dict[str, Tensor],
+                    **postprocess_params) -> Dict[str, str]:
+        """
+        Args:
+            inputs (Dict[str, Tensor]):
+                Example:
+                {
+                    'predictions': Tensor([1377, 4959, 2785, 6392...]),  # token ids to be decoded back to text
+                }
+
+        Returns:
+            Dict[str, str]:
+                Example:
+                {
+                    'output': '随着中国经济突飞猛进,建造工业与日俱增'
+                }
+        """
+        pred_str = self.vocab.string(
+            inputs['predictions'],
+            '@@',
+            extra_symbols_to_ignore={self.vocab.pad()})
+
+        return {OutputKeys.OUTPUT: ''.join(pred_str.split())}
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index 1aba9107..38fe3b9a 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -21,7 +21,8 @@ if TYPE_CHECKING:
                       TokenClassificationPreprocessor, NLIPreprocessor,
                       SentimentClassificationPreprocessor,
                       SentenceSimilarityPreprocessor, FillMaskPreprocessor,
-                      ZeroShotClassificationPreprocessor, NERPreprocessor)
+                      ZeroShotClassificationPreprocessor, NERPreprocessor,
+                      TextErrorCorrectionPreprocessor)
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
                         DialogStateTrackingPreprocessor)
@@ -49,7 +50,8 @@ else:
             'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
             'NLIPreprocessor', 'SentimentClassificationPreprocessor',
             'SentenceSimilarityPreprocessor', 'FillMaskPreprocessor',
-            'ZeroShotClassificationPreprocessor', 'NERPreprocessor'
+            'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
+            'TextErrorCorrectionPreprocessor'
         ],
         'space': [
            'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index cd170fc1..0da17cb0 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
 import uuid
 from typing import Any, Dict, Union
 
@@ -17,8 +18,9 @@ __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
     'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
     'NLIPreprocessor', 'SentimentClassificationPreprocessor',
-    'SentenceSimilarityPreprocessor', 'FillMaskPreprocessor',
-    'ZeroShotClassificationPreprocessor', 'NERPreprocessor'
+    'FillMaskPreprocessor', 'SentenceSimilarityPreprocessor',
+    'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
+    'TextErrorCorrectionPreprocessor'
 ]
 
 
@@ -431,3 +433,42 @@ class NERPreprocessor(Preprocessor):
             'label_mask': label_mask,
             'offset_mapping': offset_mapping
         }
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.text_error_correction)
+class TextErrorCorrectionPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data via the fairseq dictionary `dict.src.txt` under the `model_dir` path.
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+        from fairseq.data import Dictionary
+        self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
+
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    '随着中国经济突飞猛近,建造工业与日俱增'
+        Returns:
+            Dict[str, Any]: the preprocessed data
+                Example:
+                {'net_input':
+                    {'src_tokens': tensor([1, 2, 3, 4]),
+                     'src_lengths': tensor([4])}
+                }
+        """
+        # split the sentence into space-separated characters to match the dictionary
+        text = ' '.join([x for x in data])
+        inputs = self.vocab.encode_line(
+            text, append_eos=True, add_if_not_exist=False)
+        lengths = inputs.size()
+        sample = dict()
+        sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
+        return sample
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index eececd8d..4bb6ba5d 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -67,6 +67,7 @@ class NLPTasks(object):
     question_answering = 'question-answering'
     zero_shot_classification = 'zero-shot-classification'
     backbone = 'backbone'
+    text_error_correction = 'text-error-correction'
 
 
 class AudioTasks(object):
diff --git a/requirements/nlp.txt b/requirements/nlp.txt
index 85e0bbb7..deb6a5bd 100644
--- a/requirements/nlp.txt
+++ b/requirements/nlp.txt
@@ -1,4 +1,5 @@
 en_core_web_sm>=2.3.5
+fairseq>=0.10.2
 pai-easynlp
 # rough-score was just recently updated from 0.0.4 to 0.0.7
 # which introduced compatability issues that are being investigated
diff --git a/tests/pipelines/test_text_error_correction.py b/tests/pipelines/test_text_error_correction.py
new file mode 100644
index 00000000..0ccc003c
--- /dev/null
+++ b/tests/pipelines/test_text_error_correction.py
@@ -0,0 +1,55 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
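+#
+# A quick manual check, assuming the default model registered in
+# modelscope/pipelines/builder.py is available on the model hub:
+#
+#     from modelscope.pipelines import pipeline
+#     from modelscope.utils.constant import Tasks
+#
+#     corrector = pipeline(task=Tasks.text_error_correction)
+#     print(corrector('随着中国经济突飞猛近,建造工业与日俱增'))
+#     # expected: {'output': '随着中国经济突飞猛进,建造工业与日俱增'}
+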
+import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import BartForTextErrorCorrection +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextErrorCorrectionPipeline +from modelscope.preprocessors import TextErrorCorrectionPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TextErrorCorrectionTest(unittest.TestCase): + model_id = 'damo/nlp_bart_text-error-correction_chinese' + input = '随着中国经济突飞猛近,建造工业与日俱增' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_direct_download(self): + cache_path = snapshot_download(self.model_id) + model = BartForTextErrorCorrection(cache_path) + preprocessor = TextErrorCorrectionPreprocessor(cache_path) + pipeline1 = TextErrorCorrectionPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.text_error_correction, + model=model, + preprocessor=preprocessor) + print( + f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}' + ) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = TextErrorCorrectionPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.text_error_correction, + model=model, + preprocessor=preprocessor) + print(pipeline_ins(self.input)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.text_error_correction, model=self.model_id) + print(pipeline_ins(self.input)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.text_error_correction) + print(pipeline_ins(self.input)) + + +if __name__ == '__main__': + unittest.main()