Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9938140 (master)
@@ -69,6 +69,7 @@ class Models(object):
 class TaskModels(object):
     # nlp task
     text_classification = 'text-classification'
+    information_extraction = 'information-extraction'


 class Heads(object):
@@ -78,6 +79,7 @@ class Heads(object):
     bert_mlm = 'bert-mlm'
     # roberta mlm
     roberta_mlm = 'roberta-mlm'
+    information_extraction = 'information-extraction'


 class Pipelines(object):
@@ -156,6 +158,7 @@ class Pipelines(object):
     text_error_correction = 'text-error-correction'
     faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
+    relation_extraction = 'relation-extraction'
     document_segmentation = 'document-segmentation'

     # audio tasks
@@ -248,6 +251,7 @@ class Preprocessors(object):
     fill_mask = 'fill-mask'
     faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
     conversational_text_to_sql = 'conversational-text-to-sql'
+    re_tokenizer = 're-tokenizer'
     document_segmentation = 'document-segmentation'

     # audio preprocessor
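For orientation, here is how the new constants line up across the registries touched by this change. Note the naming split: the task is registered as information-extraction, while the default pipeline and preprocessor use relation-extraction naming. A summary sketch, not code from the diff:

    # Tasks.information_extraction       'information-extraction'   task key
    # TaskModels.information_extraction  'information-extraction'   MODELS registry
    # Heads.information_extraction       'information-extraction'   HEADS registry
    # Pipelines.relation_extraction      'relation-extraction'      PIPELINES registry
    # Preprocessors.re_tokenizer         're-tokenizer'             PREPROCESSORS registry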
@@ -21,7 +21,9 @@ if TYPE_CHECKING:
     from .space import SpaceForDialogModeling
     from .space import SpaceForDialogStateTracking
     from .star_text_to_sql import StarForTextToSql
-    from .task_models.task_model import SingleBackboneTaskModelBase
+    from .task_models import (InformationExtractionModel,
+                              SequenceClassificationModel,
+                              SingleBackboneTaskModelBase)
     from .bart_for_text_error_correction import BartForTextErrorCorrection
     from .gpt3 import GPT3ForTextGeneration
     from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering
@@ -48,10 +50,13 @@ else:
             'SpaceForDialogIntent', 'SpaceForDialogModeling',
             'SpaceForDialogStateTracking'
         ],
-        'task_model': ['SingleBackboneTaskModelBase'],
+        'task_models': [
+            'InformationExtractionModel', 'SequenceClassificationModel',
+            'SingleBackboneTaskModelBase'
+        ],
         'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
         'gpt3': ['GPT3ForTextGeneration'],
-        'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering']
+        'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'],
     }

     import sys
@@ -0,0 +1,106 @@
+from typing import Dict
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from modelscope.metainfo import Heads
+from modelscope.models.base import TorchHead
+from modelscope.models.builder import HEADS
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+
+
+@HEADS.register_module(
+    Tasks.information_extraction, module_name=Heads.information_extraction)
+class InformationExtractionHead(TorchHead):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        config = self.config
+        assert config.get('labels') is not None
+        self.labels = config.labels
+        self.s_layer = nn.Linear(config.hidden_size, 2)  # head, tail, bce
+        self.o_layer = nn.Linear(2 * config.hidden_size, 2)  # head, tail, bce
+        self.p_layer = nn.Linear(config.hidden_size,
+                                 len(self.labels))  # label, ce
+        self.mha = nn.MultiheadAttention(config.hidden_size, 4)
+
+    def forward(self, sequence_output, text, offsets, threshold=0.5):
+        # assumes batch size == 1
+        spos = []
+        # subject head/tail pointers over the token sequence
+        s_head_logits, s_tail_logits = self.s_layer(sequence_output).split(
+            1, dim=-1)  # (b, seq_len, 2)
+        s_head_logits = s_head_logits[0, :, 0].sigmoid()  # (seq_len)
+        s_tail_logits = s_tail_logits[0, :, 0].sigmoid()  # (seq_len)
+        s_masks, subjects = self._get_masks_and_mentions(
+            text, offsets, s_head_logits, s_tail_logits, None, threshold)
+        for s_mask, subject in zip(s_masks, subjects):
+            # condition the sequence on the subject span via attention
+            masked_sequence_output = sequence_output * s_mask.unsqueeze(
+                0).unsqueeze(-1)  # (b, s, h)
+            subjected_sequence_output = self.mha(
+                sequence_output.permute(1, 0, 2),
+                masked_sequence_output.permute(1, 0, 2),
+                masked_sequence_output.permute(1, 0,
+                                               2))[0].permute(1, 0,
+                                                              2)  # (b, s, h)
+            cat_sequence_output = torch.cat(
+                (sequence_output, subjected_sequence_output), dim=-1)
+            o_head_logits, o_tail_logits = self.o_layer(
+                cat_sequence_output).split(
+                    1, dim=-1)
+            o_head_logits = o_head_logits[0, :, 0].sigmoid()  # (seq_len)
+            o_tail_logits = o_tail_logits[0, :, 0].sigmoid()  # (seq_len)
+            so_masks, objects = self._get_masks_and_mentions(
+                text, offsets, o_head_logits, o_tail_logits, s_mask, threshold)
+            for so_mask, obj in zip(so_masks, objects):
+                # mean-pool the subject+object span representation
+                masked_sequence_output = (
+                    sequence_output * so_mask.unsqueeze(0).unsqueeze(-1)).sum(
+                        1)  # (b, h)
+                lengths = so_mask.unsqueeze(0).sum(-1, keepdim=True)  # (b, 1)
+                pooled_subject_object = masked_sequence_output / lengths  # (b, h)
+                label = self.p_layer(pooled_subject_object).sigmoid().squeeze(
+                    0)
+                for i in range(label.size(-1)):
+                    if label[i] > threshold:
+                        predicate = self.labels[i]
+                        spos.append((subject, predicate, obj))
+        return spos
+
+    def _get_masks_and_mentions(self,
+                                text,
+                                offsets,
+                                heads,
+                                tails,
+                                init_mask=None,
+                                threshold=0.5):
+        '''
+        text: str, the raw input sentence
+        offsets: per-token (char_start, char_end) pairs from the tokenizer
+        heads: tensor (seq_len), span-start probabilities
+        tails: tensor (seq_len), span-end probabilities
+        init_mask: optional tensor (seq_len) merged into each new mask
+        threshold: minimum probability for a head/tail position
+        '''
+        seq_len = heads.size(-1)
+        potential_heads = []
+        for i in range(seq_len - 1):
+            if heads[i] > threshold:
+                potential_heads.append(i)
+        potential_heads.append(seq_len - 1)
+        masks = []
+        mentions = []
+        for i in range(len(potential_heads) - 1):
+            head_index = potential_heads[i]
+            tail_index, max_val = None, 0
+            # pick the best-scoring tail between this head and the next one
+            for j in range(head_index, potential_heads[i + 1]):
+                if tails[j] > max_val and tails[j] > threshold:
+                    tail_index = j
+                    max_val = tails[j]
+            if tail_index is not None:
+                mask = torch.zeros_like(
+                    heads) if init_mask is None else init_mask.clone()
+                mask[head_index:tail_index + 1] = 1
+                masks.append(mask)  # (seq_len)
+                char_head = offsets[head_index][0]
+                char_tail = offsets[tail_index][1]
+                mention = text[char_head:char_tail]
+                mentions.append(mention)
+        return masks, mentions
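The decoding in _get_masks_and_mentions is the core of the head: head/tail probabilities are thresholded per token, and for each candidate head the best-scoring tail before the next head closes the span. A minimal, self-contained sketch of that scheme on toy tensors (illustrative names and values only, not the registered head itself):

    import torch

    def decode_spans(text, offsets, heads, tails, threshold=0.5):
        # offsets: per-token (char_start, char_end); heads/tails: (seq_len,)
        seq_len = heads.size(-1)
        potential_heads = [
            i for i in range(seq_len - 1) if heads[i] > threshold
        ]
        potential_heads.append(seq_len - 1)  # sentinel, as in the head above
        mentions = []
        for i in range(len(potential_heads) - 1):
            head_index = potential_heads[i]
            tail_index, max_val = None, 0.0
            # best tail between this head and the next candidate head
            for j in range(head_index, potential_heads[i + 1]):
                if tails[j] > max_val and tails[j] > threshold:
                    tail_index, max_val = j, tails[j]
            if tail_index is not None:
                mentions.append(
                    text[offsets[head_index][0]:offsets[tail_index][1]])
        return mentions

    text = 'Alice knows Bob .'
    offsets = [(0, 5), (6, 11), (12, 15), (16, 17)]
    heads = torch.tensor([0.9, 0.1, 0.8, 0.1])  # P(token opens a mention)
    tails = torch.tensor([0.9, 0.1, 0.9, 0.1])  # P(token closes a mention)
    print(decode_spans(text, offsets, heads, tails))  # ['Alice', 'Bob']

Note the sentinel: the last token only delimits the search window, so a mention can never start at the final position under this scheme.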
@@ -0,0 +1,26 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .information_extraction import InformationExtractionModel
+    from .sequence_classification import SequenceClassificationModel
+    from .task_model import SingleBackboneTaskModelBase
+else:
+    _import_structure = {
+        'information_extraction': ['InformationExtractionModel'],
+        'sequence_classification': ['SequenceClassificationModel'],
+        'task_model': ['SingleBackboneTaskModelBase'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
@@ -0,0 +1,49 @@
+from typing import Any, Dict
+
+import numpy as np
+import torch
+
+from modelscope.metainfo import TaskModels
+from modelscope.models.builder import MODELS
+from modelscope.models.nlp.task_models.task_model import \
+    SingleBackboneTaskModelBase
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+from modelscope.utils.hub import parse_label_mapping
+from modelscope.utils.tensor_utils import (torch_nested_detach,
+                                           torch_nested_numpify)
+
+__all__ = ['InformationExtractionModel']
+
+
+@MODELS.register_module(
+    Tasks.information_extraction,
+    module_name=TaskModels.information_extraction)
+class InformationExtractionModel(SingleBackboneTaskModelBase):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the information extraction model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+
+        backbone_cfg = self.cfg.backbone
+        head_cfg = self.cfg.head
+
+        self.build_backbone(backbone_cfg)
+        self.build_head(head_cfg)
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        outputs = super().forward(input)
+        sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
+        outputs = self.head.forward(sequence_output, input['text'],
+                                    input['offsets'])
+        return {OutputKeys.SPO_LIST: outputs}
+
+    def extract_backbone_outputs(self, outputs):
+        sequence_output = None
+        pooled_output = None
+        if hasattr(self.backbone, 'extract_sequence_outputs'):
+            sequence_output = self.backbone.extract_sequence_outputs(outputs)
+        return sequence_output, pooled_output
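The model wires itself up from self.cfg: build_backbone(self.cfg.backbone), build_head(self.cfg.head), and the head above asserts that labels exists and reads hidden_size. A configuration along these lines would satisfy those reads; this is a sketch, and the concrete keys and values of the released model's configuration.json are assumptions not shown in this diff:

    # Hypothetical configuration fragment implied by the reads above.
    example_cfg = {
        'backbone': {
            'type': 'bert',  # assumed backbone type
            'hidden_size': 768,
        },
        'head': {
            'type': 'information-extraction',  # Heads.information_extraction
            'hidden_size': 768,
            'labels': ['国籍', '出生地', '毕业院校'],  # assumed predicate set
        },
    }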
@@ -302,8 +302,7 @@ TASK_OUTPUTS = {
     #     "text": "《父老乡亲》是由中国人民解放军海政文工团创作的军旅歌曲,石顺义作词,王锡仁作曲,范琳琳演唱",
     #     "spo_list": [{"subject": "石顺义", "predicate": "国籍", "object": "中国"}]
     # }
-    Tasks.relation_extraction:
-    [OutputKeys.UUID, OutputKeys.TEXT, OutputKeys.SPO_LIST],
+    Tasks.relation_extraction: [OutputKeys.SPO_LIST],

     # translation result for a source sentence
     # {
@@ -23,6 +23,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.named_entity_recognition:
     (Pipelines.named_entity_recognition,
      'damo/nlp_raner_named-entity-recognition_chinese-base-news'),
+    Tasks.information_extraction:
+    (Pipelines.relation_extraction,
+     'damo/nlp_bert_relation-extraction_chinese-base'),
     Tasks.sentence_similarity:
     (Pipelines.sentence_similarity,
      'damo/nlp_structbert_sentence-similarity_chinese-base'),
@@ -10,6 +10,7 @@ if TYPE_CHECKING:
     from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
     from .document_segmentation_pipeline import DocumentSegmentationPipeline
     from .fill_mask_pipeline import FillMaskPipeline
+    from .information_extraction_pipeline import InformationExtractionPipeline
     from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
     from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline
     from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline
@@ -22,6 +23,7 @@ if TYPE_CHECKING:
     from .text_classification_pipeline import TextClassificationPipeline
     from .text_error_correction_pipeline import TextErrorCorrectionPipeline
     from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
+    from .relation_extraction_pipeline import RelationExtractionPipeline
 else:
     _import_structure = {
@@ -33,6 +35,7 @@ else:
         'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'],
         'document_segmentation_pipeline': ['DocumentSegmentationPipeline'],
         'fill_mask_pipeline': ['FillMaskPipeline'],
+        'information_extraction_pipeline': ['InformationExtractionPipeline'],
         'single_sentence_classification_pipeline':
         ['SingleSentenceClassificationPipeline'],
         'pair_sentence_classification_pipeline':
@@ -48,7 +51,8 @@ else:
         'summarization_pipeline': ['SummarizationPipeline'],
         'text_classification_pipeline': ['TextClassificationPipeline'],
         'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
-        'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline']
+        'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'],
+        'relation_extraction_pipeline': ['RelationExtractionPipeline']
     }

     import sys
@@ -0,0 +1,42 @@
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline, Tensor
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import (Preprocessor,
+                                      RelationExtractionPreprocessor)
+from modelscope.utils.constant import Tasks
+
+__all__ = ['InformationExtractionPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.information_extraction, module_name=Pipelines.relation_extraction)
+class InformationExtractionPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
+        model = model if isinstance(model,
+                                    Model) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = RelationExtractionPreprocessor(
+                model.model_dir,
+                sequence_length=kwargs.pop('sequence_length', 512))
+        model.eval()
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self, inputs: Dict[str, Any],
+                    **postprocess_params) -> Dict[str, str]:
+        return inputs
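Taken together with the DEFAULT_MODEL_FOR_PIPELINE entry above, the pipeline is reachable through the standard entry point. A minimal usage sketch; the output shape follows the SPO tuples the head returns, and the rendering in the comment is illustrative:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    re_pipeline = pipeline(
        Tasks.information_extraction,
        model='damo/nlp_bert_relation-extraction_chinese-base')
    result = re_pipeline(input='高捷,祖籍江苏,本科毕业于东南大学')
    print(result)  # e.g. {'spo_list': [('高捷', '祖籍', '江苏'), ...]}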
@@ -22,7 +22,8 @@ if TYPE_CHECKING:
                       PairSentenceClassificationPreprocessor,
                       FillMaskPreprocessor, ZeroShotClassificationPreprocessor,
                       NERPreprocessor, TextErrorCorrectionPreprocessor,
-                      FaqQuestionAnsweringPreprocessor)
+                      FaqQuestionAnsweringPreprocessor,
+                      RelationExtractionPreprocessor)
     from .slp import DocumentSegmentationPreprocessor
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
@@ -51,7 +52,8 @@ else:
             'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
             'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
             'TextErrorCorrectionPreprocessor',
-            'FaqQuestionAnsweringPreprocessor'
+            'FaqQuestionAnsweringPreprocessor',
+            'RelationExtractionPreprocessor'
         ],
         'slp': ['DocumentSegmentationPreprocessor'],
         'space': [
@@ -22,7 +22,8 @@ __all__ = [
     'PairSentenceClassificationPreprocessor',
     'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
     'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
-    'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor'
+    'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor',
+    'RelationExtractionPreprocessor'
 ]
@@ -622,6 +623,52 @@ class NERPreprocessor(Preprocessor):
         }


+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.re_tokenizer)
+class RelationExtractionPreprocessor(Preprocessor):
+    """The tokenizer preprocessor used in the plain RE task.
+
+    NOTE: This preprocessor may be merged with the
+    TokenClassificationPreprocessor in a future version.
+    """
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data.
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+
+        self.model_dir: str = model_dir
+        self.sequence_length = kwargs.pop('sequence_length', 512)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_dir, use_fast=True)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        # tokenize for the model input; a fast tokenizer is required so that
+        # char offsets can be recovered for span decoding; truncate to the
+        # configured sequence_length (otherwise the parameter is unused)
+        text = data
+        output = self.tokenizer(
+            [text],
+            max_length=self.sequence_length,
+            truncation=True,
+            return_tensors='pt')
+        return {
+            'text': text,
+            'input_ids': output['input_ids'],
+            'attention_mask': output['attention_mask'],
+            'offsets': output[0].offsets
+        }
+
+
 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=Preprocessors.text_error_correction)
 class TextErrorCorrectionPreprocessor(Preprocessor):
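A quick sketch of what the preprocessor emits; the model directory is a hypothetical path, and use_fast=True matters because only fast tokenizers expose the per-token char offsets that the head needs:

    from modelscope.preprocessors import RelationExtractionPreprocessor

    preprocessor = RelationExtractionPreprocessor('/path/to/model_dir')
    features = preprocessor('高捷,祖籍江苏,本科毕业于东南大学')
    print(features['input_ids'].shape)  # torch.Size([1, num_tokens])
    print(features['offsets'][:3])      # e.g. [(0, 0), (0, 1), (1, 2)], char spans per token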
@@ -99,6 +99,7 @@ class NLPTasks(object):
     text_error_correction = 'text-error-correction'
     faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
+    information_extraction = 'information-extraction'
     document_segmentation = 'document-segmentation'
@@ -0,0 +1,57 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+import torch
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import InformationExtractionModel
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import InformationExtractionPipeline
+from modelscope.preprocessors import RelationExtractionPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class RelationExtractionTest(unittest.TestCase):
+    model_id = 'damo/nlp_bert_relation-extraction_chinese-base'
+    sentence = '高捷,祖籍江苏,本科毕业于东南大学'
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = RelationExtractionPreprocessor(cache_path)
+        model = InformationExtractionModel.from_pretrained(cache_path)
+        pipeline1 = InformationExtractionPipeline(
+            model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.information_extraction, model=model, preprocessor=tokenizer)
+        print(f'sentence: {self.sentence}\n'
+              f'pipeline1: {pipeline1(input=self.sentence)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.sentence)}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = RelationExtractionPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.information_extraction,
+            model=model,
+            preprocessor=tokenizer)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.information_extraction, model=self.model_id)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.information_extraction)
+        print(pipeline_ins(input=self.sentence))
+
+
+if __name__ == '__main__':
+    unittest.main()