Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9938140master
| @@ -69,6 +69,7 @@ class Models(object): | |||||
| class TaskModels(object): | class TaskModels(object): | ||||
| # nlp task | # nlp task | ||||
| text_classification = 'text-classification' | text_classification = 'text-classification' | ||||
| information_extraction = 'information-extraction' | |||||
| class Heads(object): | class Heads(object): | ||||
| @@ -78,6 +79,7 @@ class Heads(object): | |||||
| bert_mlm = 'bert-mlm' | bert_mlm = 'bert-mlm' | ||||
| # roberta mlm | # roberta mlm | ||||
| roberta_mlm = 'roberta-mlm' | roberta_mlm = 'roberta-mlm' | ||||
| information_extraction = 'information-extraction' | |||||
| class Pipelines(object): | class Pipelines(object): | ||||
| @@ -156,6 +158,7 @@ class Pipelines(object): | |||||
| text_error_correction = 'text-error-correction' | text_error_correction = 'text-error-correction' | ||||
| faq_question_answering = 'faq-question-answering' | faq_question_answering = 'faq-question-answering' | ||||
| conversational_text_to_sql = 'conversational-text-to-sql' | conversational_text_to_sql = 'conversational-text-to-sql' | ||||
| relation_extraction = 'relation-extraction' | |||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| # audio tasks | # audio tasks | ||||
| @@ -248,6 +251,7 @@ class Preprocessors(object): | |||||
| fill_mask = 'fill-mask' | fill_mask = 'fill-mask' | ||||
| faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | ||||
| conversational_text_to_sql = 'conversational-text-to-sql' | conversational_text_to_sql = 'conversational-text-to-sql' | ||||
| re_tokenizer = 're-tokenizer' | |||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| # audio preprocessor | # audio preprocessor | ||||
| @@ -21,7 +21,9 @@ if TYPE_CHECKING: | |||||
| from .space import SpaceForDialogModeling | from .space import SpaceForDialogModeling | ||||
| from .space import SpaceForDialogStateTracking | from .space import SpaceForDialogStateTracking | ||||
| from .star_text_to_sql import StarForTextToSql | from .star_text_to_sql import StarForTextToSql | ||||
| from .task_models.task_model import SingleBackboneTaskModelBase | |||||
| from .task_models import (InformationExtractionModel, | |||||
| SequenceClassificationModel, | |||||
| SingleBackboneTaskModelBase) | |||||
| from .bart_for_text_error_correction import BartForTextErrorCorrection | from .bart_for_text_error_correction import BartForTextErrorCorrection | ||||
| from .gpt3 import GPT3ForTextGeneration | from .gpt3 import GPT3ForTextGeneration | ||||
| from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering | from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering | ||||
| @@ -48,10 +50,13 @@ else: | |||||
| 'SpaceForDialogIntent', 'SpaceForDialogModeling', | 'SpaceForDialogIntent', 'SpaceForDialogModeling', | ||||
| 'SpaceForDialogStateTracking' | 'SpaceForDialogStateTracking' | ||||
| ], | ], | ||||
| 'task_model': ['SingleBackboneTaskModelBase'], | |||||
| 'task_models': [ | |||||
| 'InformationExtractionModel', 'SequenceClassificationModel', | |||||
| 'SingleBackboneTaskModelBase' | |||||
| ], | |||||
| 'bart_for_text_error_correction': ['BartForTextErrorCorrection'], | 'bart_for_text_error_correction': ['BartForTextErrorCorrection'], | ||||
| 'gpt3': ['GPT3ForTextGeneration'], | 'gpt3': ['GPT3ForTextGeneration'], | ||||
| 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'] | |||||
| 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,106 @@ | |||||
| from typing import Dict | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from torch import nn | |||||
| from modelscope.metainfo import Heads | |||||
| from modelscope.models.base import TorchHead | |||||
| from modelscope.models.builder import HEADS | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import Tasks | |||||
| @HEADS.register_module( | |||||
| Tasks.information_extraction, module_name=Heads.information_extraction) | |||||
| class InformationExtractionHead(TorchHead): | |||||
| def __init__(self, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| config = self.config | |||||
| assert config.get('labels') is not None | |||||
| self.labels = config.labels | |||||
| self.s_layer = nn.Linear(config.hidden_size, 2) # head, tail, bce | |||||
| self.o_layer = nn.Linear(2 * config.hidden_size, 2) # head, tail, bce | |||||
| self.p_layer = nn.Linear(config.hidden_size, | |||||
| len(self.labels)) # label, ce | |||||
| self.mha = nn.MultiheadAttention(config.hidden_size, 4) | |||||
| def forward(self, sequence_output, text, offsets, threshold=0.5): | |||||
| # assert batch size == 1 | |||||
| spos = [] | |||||
| s_head_logits, s_tail_logits = self.s_layer(sequence_output).split( | |||||
| 1, dim=-1) # (b, seq_len, 2) | |||||
| s_head_logits = s_head_logits[0, :, 0].sigmoid() # (seq_len) | |||||
| s_tail_logits = s_tail_logits[0, :, 0].sigmoid() # (seq_len) | |||||
| s_masks, subjects = self._get_masks_and_mentions( | |||||
| text, offsets, s_head_logits, s_tail_logits, None, threshold) | |||||
| for s_mask, subject in zip(s_masks, subjects): | |||||
| masked_sequence_output = sequence_output * s_mask.unsqueeze( | |||||
| 0).unsqueeze(-1) # (b, s, h) | |||||
| subjected_sequence_output = self.mha( | |||||
| sequence_output.permute(1, 0, 2), | |||||
| masked_sequence_output.permute(1, 0, 2), | |||||
| masked_sequence_output.permute(1, 0, | |||||
| 2))[0].permute(1, 0, | |||||
| 2) # (b, s, h) | |||||
| cat_sequence_output = torch.cat( | |||||
| (sequence_output, subjected_sequence_output), dim=-1) | |||||
| o_head_logits, o_tail_logits = self.o_layer( | |||||
| cat_sequence_output).split( | |||||
| 1, dim=-1) | |||||
| o_head_logits = o_head_logits[0, :, 0].sigmoid() # (seq_len) | |||||
| o_tail_logits = o_tail_logits[0, :, 0].sigmoid() # (seq_len) | |||||
| so_masks, objects = self._get_masks_and_mentions( | |||||
| text, offsets, o_head_logits, o_tail_logits, s_mask, threshold) | |||||
| for so_mask, object in zip(so_masks, objects): | |||||
| masked_sequence_output = ( | |||||
| sequence_output * so_mask.unsqueeze(0).unsqueeze(-1)).sum( | |||||
| 1) # (b, h) | |||||
| lengths = so_mask.unsqueeze(0).sum(-1, keepdim=True) # (b, 1) | |||||
| pooled_subject_object = masked_sequence_output / lengths # (b, h) | |||||
| label = self.p_layer(pooled_subject_object).sigmoid().squeeze( | |||||
| 0) | |||||
| for i in range(label.size(-1)): | |||||
| if label[i] > threshold: | |||||
| predicate = self.labels[i] | |||||
| spos.append((subject, predicate, object)) | |||||
| return spos | |||||
| def _get_masks_and_mentions(self, | |||||
| text, | |||||
| offsets, | |||||
| heads, | |||||
| tails, | |||||
| init_mask=None, | |||||
| threshold=0.5): | |||||
| ''' | |||||
| text: str | |||||
| heads: tensor (len(heads)) | |||||
| tails: tensor (len(tails)) | |||||
| ''' | |||||
| seq_len = heads.size(-1) | |||||
| potential_heads = [] | |||||
| for i in range(seq_len - 1): | |||||
| if heads[i] > threshold: | |||||
| potential_heads.append(i) | |||||
| potential_heads.append(seq_len - 1) | |||||
| masks = [] | |||||
| mentions = [] | |||||
| for i in range(len(potential_heads) - 1): | |||||
| head_index = potential_heads[i] | |||||
| tail_index, max_val = None, 0 | |||||
| for j in range(head_index, potential_heads[i + 1]): | |||||
| if tails[j] > max_val and tails[j] > threshold: | |||||
| tail_index = j | |||||
| max_val = tails[j] | |||||
| if tail_index is not None: | |||||
| mask = torch.zeros_like( | |||||
| heads) if init_mask is None else init_mask.clone() | |||||
| mask[head_index:tail_index + 1] = 1 | |||||
| masks.append(mask) # (seq_len) | |||||
| char_head = offsets[head_index][0] | |||||
| char_tail = offsets[tail_index][1] | |||||
| mention = text[char_head:char_tail] | |||||
| mentions.append(mention) | |||||
| return masks, mentions | |||||
| @@ -0,0 +1,26 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .information_extraction import InformationExtractionModel | |||||
| from .sequence_classification import SequenceClassificationModel | |||||
| from .task_model import SingleBackboneTaskModelBase | |||||
| else: | |||||
| _import_structure = { | |||||
| 'information_extraction': ['InformationExtractionModel'], | |||||
| 'sequence_classification': ['SequenceClassificationModel'], | |||||
| 'task_model': ['SingleBackboneTaskModelBase'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,49 @@ | |||||
| from typing import Any, Dict | |||||
| import numpy as np | |||||
| import torch | |||||
| from modelscope.metainfo import TaskModels | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.nlp.task_models.task_model import \ | |||||
| SingleBackboneTaskModelBase | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.hub import parse_label_mapping | |||||
| from modelscope.utils.tensor_utils import (torch_nested_detach, | |||||
| torch_nested_numpify) | |||||
| __all__ = ['InformationExtractionModel'] | |||||
| @MODELS.register_module( | |||||
| Tasks.information_extraction, | |||||
| module_name=TaskModels.information_extraction) | |||||
| class InformationExtractionModel(SingleBackboneTaskModelBase): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the information extraction model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| backbone_cfg = self.cfg.backbone | |||||
| head_cfg = self.cfg.head | |||||
| self.build_backbone(backbone_cfg) | |||||
| self.build_head(head_cfg) | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||||
| outputs = super().forward(input) | |||||
| sequence_output, pooled_output = self.extract_backbone_outputs(outputs) | |||||
| outputs = self.head.forward(sequence_output, input['text'], | |||||
| input['offsets']) | |||||
| return {OutputKeys.SPO_LIST: outputs} | |||||
| def extract_backbone_outputs(self, outputs): | |||||
| sequence_output = None | |||||
| pooled_output = None | |||||
| if hasattr(self.backbone, 'extract_sequence_outputs'): | |||||
| sequence_output = self.backbone.extract_sequence_outputs(outputs) | |||||
| return sequence_output, pooled_output | |||||
| @@ -302,8 +302,7 @@ TASK_OUTPUTS = { | |||||
| # "text": "《父老乡亲》是由是由由中国人民解放军海政文工团创作的军旅歌曲,石顺义作词,王锡仁作曲,范琳琳演唱", | # "text": "《父老乡亲》是由是由由中国人民解放军海政文工团创作的军旅歌曲,石顺义作词,王锡仁作曲,范琳琳演唱", | ||||
| # "spo_list": [{"subject": "石顺义", "predicate": "国籍", "object": "中国"}] | # "spo_list": [{"subject": "石顺义", "predicate": "国籍", "object": "中国"}] | ||||
| # } | # } | ||||
| Tasks.relation_extraction: | |||||
| [OutputKeys.UUID, OutputKeys.TEXT, OutputKeys.SPO_LIST], | |||||
| Tasks.relation_extraction: [OutputKeys.SPO_LIST], | |||||
| # translation result for a source sentence | # translation result for a source sentence | ||||
| # { | # { | ||||
| @@ -23,6 +23,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.named_entity_recognition: | Tasks.named_entity_recognition: | ||||
| (Pipelines.named_entity_recognition, | (Pipelines.named_entity_recognition, | ||||
| 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), | 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), | ||||
| Tasks.information_extraction: | |||||
| (Pipelines.relation_extraction, | |||||
| 'damo/nlp_bert_relation-extraction_chinese-base'), | |||||
| Tasks.sentence_similarity: | Tasks.sentence_similarity: | ||||
| (Pipelines.sentence_similarity, | (Pipelines.sentence_similarity, | ||||
| 'damo/nlp_structbert_sentence-similarity_chinese-base'), | 'damo/nlp_structbert_sentence-similarity_chinese-base'), | ||||
| @@ -10,6 +10,7 @@ if TYPE_CHECKING: | |||||
| from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline | from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline | ||||
| from .document_segmentation_pipeline import DocumentSegmentationPipeline | from .document_segmentation_pipeline import DocumentSegmentationPipeline | ||||
| from .fill_mask_pipeline import FillMaskPipeline | from .fill_mask_pipeline import FillMaskPipeline | ||||
| from .information_extraction_pipeline import InformationExtractionPipeline | |||||
| from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | ||||
| from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline | from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline | ||||
| from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline | from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline | ||||
| @@ -22,6 +23,7 @@ if TYPE_CHECKING: | |||||
| from .text_classification_pipeline import TextClassificationPipeline | from .text_classification_pipeline import TextClassificationPipeline | ||||
| from .text_error_correction_pipeline import TextErrorCorrectionPipeline | from .text_error_correction_pipeline import TextErrorCorrectionPipeline | ||||
| from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline | from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline | ||||
| from .relation_extraction_pipeline import RelationExtractionPipeline | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -33,6 +35,7 @@ else: | |||||
| 'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'], | 'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'], | ||||
| 'document_segmentation_pipeline': ['DocumentSegmentationPipeline'], | 'document_segmentation_pipeline': ['DocumentSegmentationPipeline'], | ||||
| 'fill_mask_pipeline': ['FillMaskPipeline'], | 'fill_mask_pipeline': ['FillMaskPipeline'], | ||||
| 'information_extraction_pipeline': ['InformationExtractionPipeline'], | |||||
| 'single_sentence_classification_pipeline': | 'single_sentence_classification_pipeline': | ||||
| ['SingleSentenceClassificationPipeline'], | ['SingleSentenceClassificationPipeline'], | ||||
| 'pair_sentence_classification_pipeline': | 'pair_sentence_classification_pipeline': | ||||
| @@ -48,7 +51,8 @@ else: | |||||
| 'summarization_pipeline': ['SummarizationPipeline'], | 'summarization_pipeline': ['SummarizationPipeline'], | ||||
| 'text_classification_pipeline': ['TextClassificationPipeline'], | 'text_classification_pipeline': ['TextClassificationPipeline'], | ||||
| 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], | 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], | ||||
| 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'] | |||||
| 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'], | |||||
| 'relation_extraction_pipeline': ['RelationExtractionPipeline'] | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,42 @@ | |||||
| from typing import Any, Dict, Optional, Union | |||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Pipeline, Tensor | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import (Preprocessor, | |||||
| RelationExtractionPreprocessor) | |||||
| from modelscope.utils.constant import Tasks | |||||
| __all__ = ['InformationExtractionPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| Tasks.information_extraction, module_name=Pipelines.relation_extraction) | |||||
| class InformationExtractionPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[Model, str], | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| **kwargs): | |||||
| model = model if isinstance(model, | |||||
| Model) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| preprocessor = RelationExtractionPreprocessor( | |||||
| model.model_dir, | |||||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| with torch.no_grad(): | |||||
| return super().forward(inputs, **forward_params) | |||||
| def postprocess(self, inputs: Dict[str, Any], | |||||
| **postprocess_params) -> Dict[str, str]: | |||||
| return inputs | |||||
| @@ -22,7 +22,8 @@ if TYPE_CHECKING: | |||||
| PairSentenceClassificationPreprocessor, | PairSentenceClassificationPreprocessor, | ||||
| FillMaskPreprocessor, ZeroShotClassificationPreprocessor, | FillMaskPreprocessor, ZeroShotClassificationPreprocessor, | ||||
| NERPreprocessor, TextErrorCorrectionPreprocessor, | NERPreprocessor, TextErrorCorrectionPreprocessor, | ||||
| FaqQuestionAnsweringPreprocessor) | |||||
| FaqQuestionAnsweringPreprocessor, | |||||
| RelationExtractionPreprocessor) | |||||
| from .slp import DocumentSegmentationPreprocessor | from .slp import DocumentSegmentationPreprocessor | ||||
| from .space import (DialogIntentPredictionPreprocessor, | from .space import (DialogIntentPredictionPreprocessor, | ||||
| DialogModelingPreprocessor, | DialogModelingPreprocessor, | ||||
| @@ -51,7 +52,8 @@ else: | |||||
| 'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | 'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | ||||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | ||||
| 'TextErrorCorrectionPreprocessor', | 'TextErrorCorrectionPreprocessor', | ||||
| 'FaqQuestionAnsweringPreprocessor' | |||||
| 'FaqQuestionAnsweringPreprocessor', | |||||
| 'RelationExtractionPreprocessor' | |||||
| ], | ], | ||||
| 'slp': ['DocumentSegmentationPreprocessor'], | 'slp': ['DocumentSegmentationPreprocessor'], | ||||
| 'space': [ | 'space': [ | ||||
| @@ -22,7 +22,8 @@ __all__ = [ | |||||
| 'PairSentenceClassificationPreprocessor', | 'PairSentenceClassificationPreprocessor', | ||||
| 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | ||||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | ||||
| 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor' | |||||
| 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor', | |||||
| 'RelationExtractionPreprocessor' | |||||
| ] | ] | ||||
| @@ -622,6 +623,52 @@ class NERPreprocessor(Preprocessor): | |||||
| } | } | ||||
| @PREPROCESSORS.register_module( | |||||
| Fields.nlp, module_name=Preprocessors.re_tokenizer) | |||||
| class RelationExtractionPreprocessor(Preprocessor): | |||||
| """The tokenizer preprocessor used in normal RE task. | |||||
| NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition. | |||||
| """ | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """preprocess the data | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| self.model_dir: str = model_dir | |||||
| self.sequence_length = kwargs.pop('sequence_length', 512) | |||||
| self.tokenizer = AutoTokenizer.from_pretrained( | |||||
| model_dir, use_fast=True) | |||||
| @type_assert(object, str) | |||||
| def __call__(self, data: str) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data (str): a sentence | |||||
| Example: | |||||
| 'you are so handsome.' | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| # preprocess the data for the model input | |||||
| text = data | |||||
| output = self.tokenizer([text], return_tensors='pt') | |||||
| return { | |||||
| 'text': text, | |||||
| 'input_ids': output['input_ids'], | |||||
| 'attention_mask': output['attention_mask'], | |||||
| 'offsets': output[0].offsets | |||||
| } | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.nlp, module_name=Preprocessors.text_error_correction) | Fields.nlp, module_name=Preprocessors.text_error_correction) | ||||
| class TextErrorCorrectionPreprocessor(Preprocessor): | class TextErrorCorrectionPreprocessor(Preprocessor): | ||||
| @@ -99,6 +99,7 @@ class NLPTasks(object): | |||||
| text_error_correction = 'text-error-correction' | text_error_correction = 'text-error-correction' | ||||
| faq_question_answering = 'faq-question-answering' | faq_question_answering = 'faq-question-answering' | ||||
| conversational_text_to_sql = 'conversational-text-to-sql' | conversational_text_to_sql = 'conversational-text-to-sql' | ||||
| information_extraction = 'information-extraction' | |||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| @@ -0,0 +1,57 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| import torch | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import InformationExtractionModel | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.pipelines.nlp import InformationExtractionPipeline | |||||
| from modelscope.preprocessors import RelationExtractionPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class RelationExtractionTest(unittest.TestCase): | |||||
| model_id = 'damo/nlp_bert_relation-extraction_chinese-base' | |||||
| sentence = '高捷,祖籍江苏,本科毕业于东南大学' | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_by_direct_model_download(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| tokenizer = RelationExtractionPreprocessor(cache_path) | |||||
| model = InformationExtractionModel.from_pretrained(cache_path) | |||||
| pipeline1 = InformationExtractionPipeline( | |||||
| model, preprocessor=tokenizer) | |||||
| pipeline2 = pipeline( | |||||
| Tasks.information_extraction, model=model, preprocessor=tokenizer) | |||||
| print(f'sentence: {self.sentence}\n' | |||||
| f'pipeline1:{pipeline1(input=self.sentence)}') | |||||
| print() | |||||
| print(f'pipeline2: {pipeline2(input=self.sentence)}') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained(self.model_id) | |||||
| tokenizer = RelationExtractionPreprocessor(model.model_dir) | |||||
| pipeline_ins = pipeline( | |||||
| task=Tasks.information_extraction, | |||||
| model=model, | |||||
| preprocessor=tokenizer) | |||||
| print(pipeline_ins(input=self.sentence)) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_with_model_name(self): | |||||
| pipeline_ins = pipeline( | |||||
| task=Tasks.information_extraction, model=self.model_id) | |||||
| print(pipeline_ins(input=self.sentence)) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipeline_ins = pipeline(task=Tasks.information_extraction) | |||||
| print(pipeline_ins(input=self.sentence)) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||