| @@ -24,7 +24,8 @@ if TYPE_CHECKING: | |||||
| TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor, | TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor, | ||||
| SequenceLabelingPreprocessor, RelationExtractionPreprocessor, | SequenceLabelingPreprocessor, RelationExtractionPreprocessor, | ||||
| DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor, | DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor, | ||||
| PassageRankingPreprocessor) | |||||
| PassageRankingPreprocessor, | |||||
| WordSegmentationBlankSetToLabelPreprocessor) | |||||
| from .space import (DialogIntentPredictionPreprocessor, | from .space import (DialogIntentPredictionPreprocessor, | ||||
| DialogModelingPreprocessor, | DialogModelingPreprocessor, | ||||
| DialogStateTrackingPreprocessor) | DialogStateTrackingPreprocessor) | ||||
| @@ -56,6 +57,7 @@ else: | |||||
| 'TextErrorCorrectionPreprocessor', | 'TextErrorCorrectionPreprocessor', | ||||
| 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | ||||
| 'RelationExtractionPreprocessor', | 'RelationExtractionPreprocessor', | ||||
| 'WordSegmentationBlankSetToLabelPreprocessor', | |||||
| 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | ||||
| ], | ], | ||||
| 'space': [ | 'space': [ | ||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
# Lazy-import scaffold for this package.
#
# The ``TYPE_CHECKING`` branch is never executed at runtime; it only exposes
# the public names to static analysers. It must therefore list the same names
# that ``_import_structure`` below advertises for each submodule.
if TYPE_CHECKING:
    from .text_error_correction import TextErrorCorrectionPreprocessor
    from .nlp_base import (
        Tokenize, SequenceClassificationPreprocessor,
        TextGenerationPreprocessor, TokenClassificationPreprocessor,
        SingleSentenceClassificationPreprocessor,
        PairSentenceClassificationPreprocessor, FillMaskPreprocessor,
        ZeroShotClassificationPreprocessor, NERPreprocessor,
        SentenceEmbeddingPreprocessor, PassageRankingPreprocessor,
        FaqQuestionAnsweringPreprocessor, SequenceLabelingPreprocessor,
        RelationExtractionPreprocessor,
        WordSegmentationBlankSetToLabelPreprocessor,
        DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor)

else:
    # Maps each submodule name to the public names it provides.
    _import_structure = {
        'nlp_base': [
            'Tokenize', 'SequenceClassificationPreprocessor',
            'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
            'SingleSentenceClassificationPreprocessor',
            'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
            'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
            'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor',
            'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
            'RelationExtractionPreprocessor',
            'WordSegmentationBlankSetToLabelPreprocessor',
            'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor'
        ],
        'text_error_correction': [
            'TextErrorCorrectionPreprocessor',
        ],
    }

    import sys

    # Replace this module's entry in ``sys.modules`` with a LazyImportModule
    # built from ``_import_structure``.
    # NOTE(review): presumably this defers importing the submodules until one
    # of the listed names is first accessed -- confirm against
    # modelscope.utils.import_utils.LazyImportModule.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
| @@ -6,20 +6,19 @@ import uuid | |||||
| from typing import Any, Dict, Iterable, Optional, Tuple, Union | from typing import Any, Dict, Iterable, Optional, Tuple, Union | ||||
| import numpy as np | import numpy as np | ||||
| import torch | |||||
| from transformers import AutoTokenizer, BertTokenizerFast | from transformers import AutoTokenizer, BertTokenizerFast | ||||
| from modelscope.metainfo import Models, Preprocessors | from modelscope.metainfo import Models, Preprocessors | ||||
| from modelscope.models.nlp.structbert import SbertTokenizerFast | from modelscope.models.nlp.structbert import SbertTokenizerFast | ||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.preprocessors.base import Preprocessor | |||||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||||
| from modelscope.utils.config import Config, ConfigFields | from modelscope.utils.config import Config, ConfigFields | ||||
| from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | ||||
| from modelscope.utils.hub import get_model_type, parse_label_mapping | from modelscope.utils.hub import get_model_type, parse_label_mapping | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from modelscope.utils.nlp import import_external_nltk_data | from modelscope.utils.nlp import import_external_nltk_data | ||||
| from modelscope.utils.type_assert import type_assert | from modelscope.utils.type_assert import type_assert | ||||
| from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -30,9 +29,9 @@ __all__ = [ | |||||
| 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | ||||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | ||||
| 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', | 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', | ||||
| 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor', | |||||
| 'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor', | |||||
| 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | |||||
| 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | |||||
| 'RelationExtractionPreprocessor', 'DocumentSegmentationPreprocessor', | |||||
| 'FillMaskPoNetPreprocessor' | |||||
| ] | ] | ||||
| @@ -889,47 +888,6 @@ class RelationExtractionPreprocessor(Preprocessor): | |||||
| } | } | ||||
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.text_error_correction)
class TextErrorCorrectionPreprocessor(Preprocessor):
    """The preprocessor used in text correction task.

    Encodes a raw sentence with the vocabulary stored in ``dict.src.txt``
    under the model directory.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess the data via the vocab file from the `model_dir` path.

        Args:
            model_dir (str): model path; must contain ``dict.src.txt``.
            *args: forwarded unchanged to the base ``Preprocessor``.
            **kwargs: forwarded unchanged to the base ``Preprocessor``.
        """
        # Imported inside the constructor so that importing this module does
        # not itself import fairseq.
        from fairseq.data import Dictionary

        super().__init__(*args, **kwargs)
        self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))

    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence
            Example:
                '随着中国经济突飞猛近,建造工业与日俱增'

        Returns:
            Dict[str, Any]: the preprocessed data
            Example:
            {'net_input':
                {'src_tokens':tensor([1,2,3,4]),
                'src_lengths': tensor([4])}
            }
        """
        # Put a single space between every character of the sentence before
        # handing it to the vocabulary encoder.
        text = ' '.join(data)
        inputs = self.vocab.encode_line(
            text, append_eos=True, add_if_not_exist=False)
        # NOTE(review): ``src_lengths`` is whatever ``inputs.size()`` returns,
        # which may not be the ``tensor([4])`` shown in the example above --
        # confirm what the downstream model expects.
        lengths = inputs.size()
        return {'net_input': {'src_tokens': inputs, 'src_lengths': lengths}}
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor) | Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor) | ||||
| class FaqQuestionAnsweringPreprocessor(Preprocessor): | class FaqQuestionAnsweringPreprocessor(Preprocessor): | ||||
| @@ -0,0 +1,50 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.preprocessors.base import Preprocessor | |||||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||||
| from modelscope.utils.constant import Fields | |||||
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.text_error_correction)
class TextErrorCorrectionPreprocessor(Preprocessor):
    """The preprocessor used in text correction task.

    Encodes a raw sentence with the vocabulary stored in ``dict.src.txt``
    under the model directory.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess the data via the vocab file from the `model_dir` path.

        Args:
            model_dir (str): model path; must contain ``dict.src.txt``.
            *args: forwarded unchanged to the base ``Preprocessor``.
            **kwargs: forwarded unchanged to the base ``Preprocessor``.
        """
        # Imported inside the constructor so that importing this module does
        # not itself import fairseq.
        from fairseq.data import Dictionary

        super().__init__(*args, **kwargs)
        self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))

    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence
            Example:
                '随着中国经济突飞猛近,建造工业与日俱增'

        Returns:
            Dict[str, Any]: the preprocessed data
            Example:
            {'net_input':
                {'src_tokens':tensor([1,2,3,4]),
                'src_lengths': tensor([4])}
            }
        """
        # Put a single space between every character of the sentence before
        # handing it to the vocabulary encoder.
        text = ' '.join(data)
        inputs = self.vocab.encode_line(
            text, append_eos=True, add_if_not_exist=False)
        # NOTE(review): ``src_lengths`` is whatever ``inputs.size()`` returns,
        # which may not be the ``tensor([4])`` shown in the example above --
        # confirm what the downstream model expects.
        lengths = inputs.size()
        return {'net_input': {'src_tokens': inputs, 'src_lengths': lengths}}