NLP: add text error correction task
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9540716
master
@@ -24,6 +24,7 @@ class Models(object):
     translation = 'csanmt-translation'
     space = 'space'
     tcrf = 'transformer-crf'
+    bart = 'bart'

     # audio models
     sambert_hifigan = 'sambert-hifigan'
@@ -98,6 +99,7 @@ class Pipelines(object):
     dialog_modeling = 'dialog-modeling'
     dialog_state_tracking = 'dialog-state-tracking'
     zero_shot_classification = 'zero-shot-classification'
+    text_error_correction = 'text-error-correction'

     # audio tasks
     sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -161,6 +163,7 @@ class Preprocessors(object):
     dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor'
     sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
     zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
+    text_error_correction = 'text-error-correction'

     # audio preprocessor
     linear_aec_fbank = 'linear-aec-fbank'
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
     from .space_for_dialog_modeling import SpaceForDialogModeling
     from .space_for_dialog_state_tracking import SpaceForDialogStateTracking
     from .task_model import SingleBackboneTaskModelBase
+    from .bart_for_text_error_correction import BartForTextErrorCorrection
 else:
     _import_structure = {
@@ -46,6 +47,7 @@ else:
         'space_for_dialog_modeling': ['SpaceForDialogModeling'],
         'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'],
         'task_model': ['SingleBackboneTaskModelBase'],
+        'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
     }

     import sys
@@ -0,0 +1,93 @@
+import os.path as osp
+from typing import Any, Dict
+
+import torch.cuda
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import ModelFile, Tasks
+
+__all__ = ['BartForTextErrorCorrection']
+
+
+@MODELS.register_module(Tasks.text_error_correction, module_name=Models.bart)
+class BartForTextErrorCorrection(TorchModel):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        """Initialize the text error correction model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir=model_dir, *args, **kwargs)
+        ckpt_name = ModelFile.TORCH_MODEL_FILE
+        local_model = osp.join(model_dir, ckpt_name)
+        bart_vocab_dir = model_dir
+
+        from fairseq import checkpoint_utils, utils
+
+        # run on GPU if one is available
+        if torch.cuda.is_available():
+            self._device = torch.device('cuda')
+        else:
+            self._device = torch.device('cpu')
+        # fp16 is only meaningful on GPU
+        self.use_fp16 = bool(kwargs.get('use_fp16',
+                                        False)) and torch.cuda.is_available()
+
+        overrides = {
+            'data': bart_vocab_dir,
+            'beam': 2,
+        }
+        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+            utils.split_paths(local_model), arg_overrides=overrides)
+
+        # move models to the target device
+        for model in models:
+            model.eval()
+            model.to(self._device)
+            if self.use_fp16:
+                model.half()
+            model.prepare_for_inference_(cfg)
+        self.models = models
+
+        # initialize the sequence generator
+        self.generator = task.build_generator(models, 'translation')
+        self.task = task
+
+    def forward(self, input: Dict[str, Dict]) -> Dict[str, Any]:
+        """Return the prediction of the model.
+
+        Args:
+            input (Dict[str, Tensor]): the preprocessed data
+                Example (one sentence):
+                    {'net_input':
+                        {'src_tokens': tensor([2478, 242, 24, 4]),
+                         'src_lengths': tensor([4])}
+                    }
+
+        Returns:
+            Dict[str, Tensor]: results
+                Example:
+                    {
+                        'predictions': Tensor([1377, 4959, 2785, 6392...]),  # token ids still to be decoded by the tokenizer
+                    }
+        """
+        import fairseq.utils
+
+        # the generator expects a batch dimension
+        if len(input['net_input']['src_tokens'].size()) == 1:
+            input['net_input']['src_tokens'] = input['net_input'][
+                'src_tokens'].view(1, -1)
+        if torch.cuda.is_available():
+            input = fairseq.utils.move_to_cuda(input, device=self._device)
+        sample = input
+        translations = self.task.inference_step(self.generator, self.models,
+                                                sample)
+        # keep the 1-best hypothesis: List[Tensor]
+        preds = translations[0][0]['tokens']
+        return {'predictions': preds}
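For reviewers who want to poke at the model class in isolation, here is a minimal sketch of driving `forward` directly. The model id comes from this PR's pipeline defaults; the token ids in the sample are placeholders (real ids come from `TextErrorCorrectionPreprocessor`, added further down):

    import torch

    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.models.nlp import BartForTextErrorCorrection

    model_dir = snapshot_download('damo/nlp_bart_text-error-correction_chinese')
    model = BartForTextErrorCorrection(model_dir)

    # placeholder token ids, shaped like the preprocessor output documented above
    sample = {
        'net_input': {
            'src_tokens': torch.tensor([2478, 242, 24, 4]),
            'src_lengths': torch.tensor([4]),
        }
    }
    out = model.forward(sample)
    print(out['predictions'])  # 1-D tensor of output token ids; decoded in the pipeline's postprocess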
@@ -352,10 +352,15 @@ TASK_OUTPUTS = {
     # "text": "this is the text generated by a model."
     # }
     Tasks.visual_question_answering: [OutputKeys.TEXT],

     # auto_speech_recognition result for a single sample
     # {
     #   "text": "每天都要快乐喔"
     # }
-    Tasks.auto_speech_recognition: [OutputKeys.TEXT]
+    Tasks.auto_speech_recognition: [OutputKeys.TEXT],
+    # text_error_correction result for a single sample
+    # {
+    #   "output": "我想吃苹果"
+    # }
+    Tasks.text_error_correction: [OutputKeys.OUTPUT]
 }
@@ -51,6 +51,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                               'damo/nlp_space_dialog-modeling'),
     Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
                                   'damo/nlp_space_dialog-state-tracking'),
+    Tasks.text_error_correction:
+    (Pipelines.text_error_correction,
+     'damo/nlp_bart_text-error-correction_chinese'),
     Tasks.image_captioning: (Pipelines.image_captioning,
                              'damo/ofa_image-caption_coco_large_en'),
     Tasks.image_generation:
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
     from .translation_pipeline import TranslationPipeline
     from .word_segmentation_pipeline import WordSegmentationPipeline
     from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
+    from .text_error_correction_pipeline import TextErrorCorrectionPipeline
 else:
     _import_structure = {
@@ -37,6 +38,7 @@ else:
         'named_entity_recognition_pipeline':
         ['NamedEntityRecognitionPipeline'],
         'translation_pipeline': ['TranslationPipeline'],
+        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
     }

     import sys
@@ -0,0 +1,69 @@
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.models.nlp import BartForTextErrorCorrection
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline, Tensor
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import TextErrorCorrectionPreprocessor
+from modelscope.utils.constant import Tasks
+
+__all__ = ['TextErrorCorrectionPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.text_error_correction, module_name=Pipelines.text_error_correction)
+class TextErrorCorrectionPipeline(Pipeline):
+
+    def __init__(
+            self,
+            model: Union[BartForTextErrorCorrection, str],
+            preprocessor: Optional[TextErrorCorrectionPreprocessor] = None,
+            **kwargs):
+        """Use `model` and `preprocessor` to create an NLP text error
+        correction pipeline for prediction.
+
+        Args:
+            model (BartForTextErrorCorrection): a model instance or a model id
+            preprocessor (TextErrorCorrectionPreprocessor): a preprocessor instance
+        """
+        model = model if isinstance(
+            model,
+            BartForTextErrorCorrection) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
+        self.vocab = preprocessor.vocab
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        # inference only: no gradients needed
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self, inputs: Dict[str, Tensor],
+                    **postprocess_params) -> Dict[str, str]:
+        """Decode the predicted token ids back into a sentence.
+
+        Args:
+            inputs (Dict[str, Tensor]):
+                Example:
+                    {
+                        'predictions': Tensor([1377, 4959, 2785, 6392...]),  # token ids still to be decoded by the tokenizer
+                    }
+
+        Returns:
+            Dict[str, str]:
+                Example:
+                    {
+                        'output': '随着中国经济突飞猛进,建造工业与日俱增'
+                    }
+        """
+        # detokenize: strip BPE continuation markers and padding, then drop spaces
+        pred_str = self.vocab.string(
+            inputs['predictions'],
+            '@@',
+            extra_symbols_to_ignore={self.vocab.pad()})
+        return {OutputKeys.OUTPUT: ''.join(pred_str.split())}
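As a usage note, the pipeline can be exercised end to end the same way the tests at the bottom of this PR do; a minimal sketch, using the model id registered in the pipeline defaults above:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    corrector = pipeline(
        Tasks.text_error_correction,
        model='damo/nlp_bart_text-error-correction_chinese')
    print(corrector('随着中国经济突飞猛近,建造工业与日俱增'))
    # -> {'output': '随着中国经济突飞猛进,建造工业与日俱增'}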
@@ -21,7 +21,8 @@ if TYPE_CHECKING:
                       TokenClassificationPreprocessor, NLIPreprocessor,
                       SentimentClassificationPreprocessor,
                       SentenceSimilarityPreprocessor, FillMaskPreprocessor,
-                      ZeroShotClassificationPreprocessor, NERPreprocessor)
+                      ZeroShotClassificationPreprocessor, NERPreprocessor,
+                      TextErrorCorrectionPreprocessor)
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
                         DialogStateTrackingPreprocessor)
@@ -49,7 +50,8 @@ else:
         'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
         'NLIPreprocessor', 'SentimentClassificationPreprocessor',
         'SentenceSimilarityPreprocessor', 'FillMaskPreprocessor',
-        'ZeroShotClassificationPreprocessor', 'NERPreprocessor'
+        'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
+        'TextErrorCorrectionPreprocessor'
     ],
     'space': [
         'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp
+import uuid
 from typing import Any, Dict, Union
@@ -17,8 +18,9 @@ __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
     'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
     'NLIPreprocessor', 'SentimentClassificationPreprocessor',
-    'SentenceSimilarityPreprocessor', 'FillMaskPreprocessor',
-    'ZeroShotClassificationPreprocessor', 'NERPreprocessor'
+    'FillMaskPreprocessor', 'SentenceSimilarityPreprocessor',
+    'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
+    'TextErrorCorrectionPreprocessor'
 ]
@@ -431,3 +433,42 @@ class NERPreprocessor(Preprocessor):
             'label_mask': label_mask,
             'offset_mapping': offset_mapping
         }
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.text_error_correction)
+class TextErrorCorrectionPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Build the vocabulary from the `dict.src.txt` file under `model_dir`.
+
+        Args:
+            model_dir (str): model path
+        """
+        from fairseq.data import Dictionary
+        super().__init__(*args, **kwargs)
+        self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
+
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    '随着中国经济突飞猛近,建造工业与日俱增'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+                Example:
+                    {'net_input':
+                        {'src_tokens': tensor([1, 2, 3, 4]),
+                         'src_lengths': tensor([4])}
+                    }
+        """
+        # split the sentence into single characters separated by spaces
+        text = ' '.join(data)
+        inputs = self.vocab.encode_line(
+            text, append_eos=True, add_if_not_exist=False)
+        lengths = inputs.size()
+        sample = dict()
+        sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
+        return sample
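The preprocessor is character-level: it space-separates the input characters and encodes them with the fairseq `Dictionary`. A small sketch of its output shape (the path is hypothetical; actual ids depend on the model's `dict.src.txt`):

    from modelscope.preprocessors import TextErrorCorrectionPreprocessor

    # hypothetical local path; in practice use model.model_dir or snapshot_download
    preprocessor = TextErrorCorrectionPreprocessor('/path/to/model_dir')
    sample = preprocessor('随着中国经济突飞猛近,建造工业与日俱增')
    print(sample['net_input']['src_tokens'])   # 1-D tensor of character ids plus EOS
    print(sample['net_input']['src_lengths'])  # torch.Size holding the token count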
@@ -67,6 +67,7 @@ class NLPTasks(object):
     question_answering = 'question-answering'
     zero_shot_classification = 'zero-shot-classification'
     backbone = 'backbone'
+    text_error_correction = 'text-error-correction'


 class AudioTasks(object):
@@ -1,4 +1,5 @@
 en_core_web_sm>=2.3.5
+fairseq>=0.10.2
 pai-easynlp
 # rouge-score was just recently updated from 0.0.4 to 0.0.7
 # which introduced compatibility issues that are being investigated
@@ -0,0 +1,55 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import BartForTextErrorCorrection
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import TextErrorCorrectionPipeline
+from modelscope.preprocessors import TextErrorCorrectionPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class TextErrorCorrectionTest(unittest.TestCase):
+    model_id = 'damo/nlp_bart_text-error-correction_chinese'
+    input = '随着中国经济突飞猛近,建造工业与日俱增'
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_direct_download(self):
+        cache_path = snapshot_download(self.model_id)
+        model = BartForTextErrorCorrection(cache_path)
+        preprocessor = TextErrorCorrectionPreprocessor(cache_path)
+        pipeline1 = TextErrorCorrectionPipeline(model, preprocessor)
+        pipeline2 = pipeline(
+            Tasks.text_error_correction,
+            model=model,
+            preprocessor=preprocessor)
+        print(
+            f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}'
+        )
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.text_error_correction,
+            model=model,
+            preprocessor=preprocessor)
+        print(pipeline_ins(self.input))
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.text_error_correction, model=self.model_id)
+        print(pipeline_ins(self.input))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.text_error_correction)
+        print(pipeline_ins(self.input))
+
+
+if __name__ == '__main__':
+    unittest.main()