NLP: add text error correction task
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9540716
master
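
Summary: this change wires a new `text-error-correction` task through the stack — task/model/pipeline/preprocessor constants in `metainfo.py` and `constant.py`, a fairseq-based `BartForTextErrorCorrection` model, a `TextErrorCorrectionPipeline` with its preprocessor, a default-model mapping, the task output schema, a `fairseq` requirement, and pipeline tests.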
modelscope/metainfo.py
@@ -24,6 +24,7 @@ class Models(object):
     translation = 'csanmt-translation'
     space = 'space'
     tcrf = 'transformer-crf'
+    bart = 'bart'
 
     # audio models
     sambert_hifigan = 'sambert-hifigan'
@@ -98,6 +99,7 @@ class Pipelines(object):
     dialog_modeling = 'dialog-modeling'
     dialog_state_tracking = 'dialog-state-tracking'
     zero_shot_classification = 'zero-shot-classification'
+    text_error_correction = 'text-error-correction'
 
     # audio tasks
     sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -161,6 +163,7 @@ class Preprocessors(object):
     dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor'
     sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
     zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
+    text_error_correction = 'text-error-correction'
 
     # audio preprocessor
     linear_aec_fbank = 'linear-aec-fbank'
modelscope/models/nlp/__init__.py
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
     from .space_for_dialog_modeling import SpaceForDialogModeling
     from .space_for_dialog_state_tracking import SpaceForDialogStateTracking
     from .task_model import SingleBackboneTaskModelBase
+    from .bart_for_text_error_correction import BartForTextErrorCorrection
 
 else:
     _import_structure = {
@@ -46,6 +47,7 @@ else:
         'space_for_dialog_modeling': ['SpaceForDialogModeling'],
         'space_for_dialog_state_tracking': ['SpaceForDialogStateTracking'],
         'task_model': ['SingleBackboneTaskModelBase'],
+        'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
     }
 
     import sys
modelscope/models/nlp/bart_for_text_error_correction.py (new file)
@@ -0,0 +1,93 @@
+import os.path as osp
+from typing import Any, Dict
+
+import torch.cuda
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import ModelFile, Tasks
+
+__all__ = ['BartForTextErrorCorrection']
+
+
+@MODELS.register_module(Tasks.text_error_correction, module_name=Models.bart)
+class BartForTextErrorCorrection(TorchModel):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        """Initialize the text error correction model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir=model_dir, *args, **kwargs)
+        ckpt_name = ModelFile.TORCH_MODEL_FILE
+        local_model = osp.join(model_dir, ckpt_name)
+        bart_vocab_dir = model_dir
+
+        # fairseq is imported lazily so the dependency is only required
+        # when this model is actually used
+        from fairseq import checkpoint_utils, utils
+
+        # run on GPU if one is available
+        if torch.cuda.is_available():
+            self._device = torch.device('cuda')
+        else:
+            self._device = torch.device('cpu')
+        # fp16 only makes sense on GPU
+        self.use_fp16 = kwargs.get('use_fp16',
+                                   False) and torch.cuda.is_available()
+
+        overrides = {
+            'data': bart_vocab_dir,
+            'beam': 2,
+        }
+        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+            utils.split_paths(local_model), arg_overrides=overrides)
+
+        # move the ensemble to the target device and prepare it for inference
+        for model in models:
+            model.eval()
+            model.to(self._device)
+            if self.use_fp16:
+                model.half()
+            model.prepare_for_inference_(cfg)
+        self.models = models
+
+        # initialize the sequence generator
+        self.generator = task.build_generator(models, 'translation')
+        self.task = task
+
+    def forward(self, input: Dict[str, Dict]) -> Dict[str, Any]:
+        """Return the result produced by the model.
+
+        Args:
+            input (Dict[str, Tensor]): the preprocessed data
+                Example (one sentence):
+                    {'net_input':
+                        {'src_tokens': tensor([2478, 242, 24, 4]),
+                         'src_lengths': tensor([4])}
+                    }
+
+        Returns:
+            Dict[str, Tensor]: results
+                Example:
+                    {
+                        'predictions': Tensor([1377, 4959, 2785, 6392...]),
+                        # tokens still need to be decoded by the tokenizer
+                    }
+        """
+        import fairseq.utils
+        # add a batch dimension for a single unbatched sample
+        if len(input['net_input']['src_tokens'].size()) == 1:
+            input['net_input']['src_tokens'] = input['net_input'][
+                'src_tokens'].view(1, -1)
+
+        if torch.cuda.is_available():
+            input = fairseq.utils.move_to_cuda(input, device=self._device)
+        sample = input
+        translations = self.task.inference_step(self.generator, self.models,
+                                                sample)
+        # take the tokens of the 1-best hypothesis
+        preds = translations[0][0]['tokens']
+        return {'predictions': preds}
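
For reviewers, a minimal sketch of how this class is driven directly, assuming a model snapshot already on disk (the path below is a placeholder); the decoding step mirrors the pipeline's postprocess further down in this review.

    from modelscope.models.nlp import BartForTextErrorCorrection
    from modelscope.preprocessors import TextErrorCorrectionPreprocessor

    cache_path = '/path/to/nlp_bart_text-error-correction_chinese'  # placeholder
    model = BartForTextErrorCorrection(cache_path)
    preprocessor = TextErrorCorrectionPreprocessor(cache_path)

    sample = preprocessor('随着中国经济突飞猛近，建造工业与日俱增')
    result = model.forward(sample)  # {'predictions': tensor of output token ids}

    # decode the ids the same way the pipeline's postprocess does
    vocab = preprocessor.vocab
    pred_str = vocab.string(
        result['predictions'], '@@', extra_symbols_to_ignore={vocab.pad()})
    print(''.join(pred_str.split()))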
modelscope/outputs.py
@@ -352,10 +352,15 @@ TASK_OUTPUTS = {
     # "text": "this is the text generated by a model."
     # }
     Tasks.visual_question_answering: [OutputKeys.TEXT],
 
     # auto_speech_recognition result for a single sample
     # {
     #   "text": "每天都要快乐喔"
     # }
-    Tasks.auto_speech_recognition: [OutputKeys.TEXT]
+    Tasks.auto_speech_recognition: [OutputKeys.TEXT],
+
+    # text_error_correction result for a single sample
+    # {
+    #   "output": "我想吃苹果"
+    # }
+    Tasks.text_error_correction: [OutputKeys.OUTPUT]
 }
modelscope/pipelines/builder.py
@@ -51,6 +51,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                             'damo/nlp_space_dialog-modeling'),
     Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
                                   'damo/nlp_space_dialog-state-tracking'),
+    Tasks.text_error_correction:
+    (Pipelines.text_error_correction,
+     'damo/nlp_bart_text-error-correction_chinese'),
     Tasks.image_captioning: (Pipelines.image_captioning,
                              'damo/ofa_image-caption_coco_large_en'),
     Tasks.image_generation:
modelscope/pipelines/nlp/__init__.py
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
     from .translation_pipeline import TranslationPipeline
     from .word_segmentation_pipeline import WordSegmentationPipeline
     from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
+    from .text_error_correction_pipeline import TextErrorCorrectionPipeline
 
 else:
     _import_structure = {
@@ -37,6 +38,7 @@ else:
         'named_entity_recognition_pipeline':
         ['NamedEntityRecognitionPipeline'],
         'translation_pipeline': ['TranslationPipeline'],
+        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
     }
 
     import sys
modelscope/pipelines/nlp/text_error_correction_pipeline.py (new file)
@@ -0,0 +1,69 @@
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.models.nlp import BartForTextErrorCorrection
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline, Tensor
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import TextErrorCorrectionPreprocessor
+from modelscope.utils.constant import Tasks
+
+__all__ = ['TextErrorCorrectionPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.text_error_correction, module_name=Pipelines.text_error_correction)
+class TextErrorCorrectionPipeline(Pipeline):
+
+    def __init__(
+            self,
+            model: Union[BartForTextErrorCorrection, str],
+            preprocessor: Optional[TextErrorCorrectionPreprocessor] = None,
+            **kwargs):
+        """Use `model` and `preprocessor` to create a text error correction
+        pipeline for prediction.
+
+        Args:
+            model (BartForTextErrorCorrection or str): a model instance or a
+                model id to load from the hub
+            preprocessor (TextErrorCorrectionPreprocessor): a preprocessor
+                instance; built from `model.model_dir` when omitted
+        """
+        model = model if isinstance(
+            model,
+            BartForTextErrorCorrection) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
+        self.vocab = preprocessor.vocab
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        # inference only: disable gradient tracking
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self, inputs: Dict[str, Tensor],
+                    **postprocess_params) -> Dict[str, str]:
+        """
+        Args:
+            inputs (Dict[str, Tensor]):
+                Example:
+                    {
+                        'predictions': Tensor([1377, 4959, 2785, 6392...]),
+                        # tokens still need to be decoded by the tokenizer
+                    }
+
+        Returns:
+            Dict[str, str]:
+                Example:
+                    {
+                        'output': '随着中国经济突飞猛进，建造工业与日俱增'
+                    }
+        """
+        # detokenize: strip BPE markers and padding, then drop the spaces
+        # the preprocessor inserted between characters
+        pred_str = self.vocab.string(
+            inputs['predictions'],
+            '@@',
+            extra_symbols_to_ignore={self.vocab.pad()})
+        return {OutputKeys.OUTPUT: ''.join(pred_str.split())}
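
A minimal sketch of the user-facing call path once this lands, mirroring `test_run_with_model_name` in the tests below; the expected result shape follows the postprocess example above.

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    corrector = pipeline(
        task=Tasks.text_error_correction,
        model='damo/nlp_bart_text-error-correction_chinese')
    print(corrector('随着中国经济突飞猛近，建造工业与日俱增'))
    # -> {'output': '随着中国经济突飞猛进，建造工业与日俱增'}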
modelscope/preprocessors/__init__.py
@@ -21,7 +21,8 @@ if TYPE_CHECKING:
                       TokenClassificationPreprocessor, NLIPreprocessor,
                       SentimentClassificationPreprocessor,
                       SentenceSimilarityPreprocessor, FillMaskPreprocessor,
-                      ZeroShotClassificationPreprocessor, NERPreprocessor)
+                      ZeroShotClassificationPreprocessor, NERPreprocessor,
+                      TextErrorCorrectionPreprocessor)
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
                         DialogStateTrackingPreprocessor)
@@ -49,7 +50,8 @@ else:
             'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
             'NLIPreprocessor', 'SentimentClassificationPreprocessor',
             'SentenceSimilarityPreprocessor', 'FillMaskPreprocessor',
-            'ZeroShotClassificationPreprocessor', 'NERPreprocessor'
+            'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
+            'TextErrorCorrectionPreprocessor'
         ],
         'space': [
             'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
modelscope/preprocessors/nlp.py
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
 import uuid
 from typing import Any, Dict, Union
@@ -17,8 +18,9 @@ __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
     'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
     'NLIPreprocessor', 'SentimentClassificationPreprocessor',
-    'SentenceSimilarityPreprocessor', 'FillMaskPreprocessor',
-    'ZeroShotClassificationPreprocessor', 'NERPreprocessor'
+    'FillMaskPreprocessor', 'SentenceSimilarityPreprocessor',
+    'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
+    'TextErrorCorrectionPreprocessor'
 ]
@@ -431,3 +433,42 @@ class NERPreprocessor(Preprocessor):
             'label_mask': label_mask,
             'offset_mapping': offset_mapping
         }
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.text_error_correction)
+class TextErrorCorrectionPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data using the vocabulary shipped in `model_dir`.
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+        # fairseq is imported lazily so the dependency is only required
+        # when this preprocessor is actually used
+        from fairseq.data import Dictionary
+        self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
+
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    '随着中国经济突飞猛近，建造工业与日俱增'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+                Example:
+                    {'net_input':
+                        {'src_tokens': tensor([1, 2, 3, 4]),
+                         'src_lengths': tensor([4])}
+                    }
+        """
+        # split the sentence into space-separated characters to match
+        # the character-level source dictionary
+        text = ' '.join([x for x in data])
+        inputs = self.vocab.encode_line(
+            text, append_eos=True, add_if_not_exist=False)
+        lengths = inputs.size()
+        sample = dict()
+        sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
+        return sample
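
A minimal sketch of the preprocessor contract, assuming a model directory containing `dict.src.txt` (the path is a placeholder): the sentence is split into characters and encoded with the fairseq source dictionary.

    from modelscope.preprocessors import TextErrorCorrectionPreprocessor

    preprocessor = TextErrorCorrectionPreprocessor('/path/to/model_dir')  # placeholder
    sample = preprocessor('随着中国经济突飞猛近')
    # sample['net_input']['src_tokens'] is a 1-D LongTensor of character ids
    # (with EOS appended); 'src_lengths' holds the sequence length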
modelscope/utils/constant.py
@@ -67,6 +67,7 @@ class NLPTasks(object):
     question_answering = 'question-answering'
     zero_shot_classification = 'zero-shot-classification'
     backbone = 'backbone'
+    text_error_correction = 'text-error-correction'
 
 
 class AudioTasks(object):
requirements/nlp.txt
@@ -1,4 +1,5 @@
 en_core_web_sm>=2.3.5
+fairseq>=0.10.2
 pai-easynlp
 # rough-score was just recently updated from 0.0.4 to 0.0.7
 # which introduced compatability issues that are being investigated
tests/pipelines/test_text_error_correction.py (new file)
@@ -0,0 +1,55 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import BartForTextErrorCorrection
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import TextErrorCorrectionPipeline
+from modelscope.preprocessors import TextErrorCorrectionPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class TextErrorCorrectionTest(unittest.TestCase):
+    model_id = 'damo/nlp_bart_text-error-correction_chinese'
+    input = '随着中国经济突飞猛近，建造工业与日俱增'
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_direct_download(self):
+        cache_path = snapshot_download(self.model_id)
+        model = BartForTextErrorCorrection(cache_path)
+        preprocessor = TextErrorCorrectionPreprocessor(cache_path)
+        pipeline1 = TextErrorCorrectionPipeline(model, preprocessor)
+        pipeline2 = pipeline(
+            Tasks.text_error_correction,
+            model=model,
+            preprocessor=preprocessor)
+        print(
+            f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}'
+        )
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.text_error_correction,
+            model=model,
+            preprocessor=preprocessor)
+        print(pipeline_ins(self.input))
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.text_error_correction, model=self.model_id)
+        print(pipeline_ins(self.input))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.text_error_correction)
+        print(pipeline_ins(self.input))
+
+
+if __name__ == '__main__':
+    unittest.main()