Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9980774master
| @@ -55,7 +55,9 @@ class Models(object): | |||||
| space_modeling = 'space-modeling' | space_modeling = 'space-modeling' | ||||
| star = 'star' | star = 'star' | ||||
| tcrf = 'transformer-crf' | tcrf = 'transformer-crf' | ||||
| transformer_softmax = 'transformer-softmax' | |||||
| lcrf = 'lstm-crf' | lcrf = 'lstm-crf' | ||||
| gcnncrf = 'gcnn-crf' | |||||
| bart = 'bart' | bart = 'bart' | ||||
| gpt3 = 'gpt3' | gpt3 = 'gpt3' | ||||
| plug = 'plug' | plug = 'plug' | ||||
| @@ -82,6 +84,7 @@ class Models(object): | |||||
| class TaskModels(object): | class TaskModels(object): | ||||
| # nlp task | # nlp task | ||||
| text_classification = 'text-classification' | text_classification = 'text-classification' | ||||
| token_classification = 'token-classification' | |||||
| information_extraction = 'information-extraction' | information_extraction = 'information-extraction' | ||||
| @@ -92,6 +95,8 @@ class Heads(object): | |||||
| bert_mlm = 'bert-mlm' | bert_mlm = 'bert-mlm' | ||||
| # roberta mlm | # roberta mlm | ||||
| roberta_mlm = 'roberta-mlm' | roberta_mlm = 'roberta-mlm' | ||||
| # token cls | |||||
| token_classification = 'token-classification' | |||||
| information_extraction = 'information-extraction' | information_extraction = 'information-extraction' | ||||
| @@ -167,6 +172,7 @@ class Pipelines(object): | |||||
| # nlp tasks | # nlp tasks | ||||
| sentence_similarity = 'sentence-similarity' | sentence_similarity = 'sentence-similarity' | ||||
| word_segmentation = 'word-segmentation' | word_segmentation = 'word-segmentation' | ||||
| part_of_speech = 'part-of-speech' | |||||
| named_entity_recognition = 'named-entity-recognition' | named_entity_recognition = 'named-entity-recognition' | ||||
| text_generation = 'text-generation' | text_generation = 'text-generation' | ||||
| sentiment_analysis = 'sentiment-analysis' | sentiment_analysis = 'sentiment-analysis' | ||||
| @@ -272,6 +278,7 @@ class Preprocessors(object): | |||||
| sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | ||||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | ||||
| text_error_correction = 'text-error-correction' | text_error_correction = 'text-error-correction' | ||||
| sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | |||||
| word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | ||||
| fill_mask = 'fill-mask' | fill_mask = 'fill-mask' | ||||
| faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | ||||
| @@ -5,40 +5,39 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from .backbones import SbertModel | from .backbones import SbertModel | ||||
| from .heads import SequenceClassificationHead | |||||
| from .bart_for_text_error_correction import BartForTextErrorCorrection | |||||
| from .bert_for_sequence_classification import BertForSequenceClassification | from .bert_for_sequence_classification import BertForSequenceClassification | ||||
| from .bert_for_document_segmentation import BertForDocumentSegmentation | from .bert_for_document_segmentation import BertForDocumentSegmentation | ||||
| from .csanmt_for_translation import CsanmtForTranslation | from .csanmt_for_translation import CsanmtForTranslation | ||||
| from .masked_language import ( | |||||
| StructBertForMaskedLM, | |||||
| VecoForMaskedLM, | |||||
| BertForMaskedLM, | |||||
| DebertaV2ForMaskedLM, | |||||
| ) | |||||
| from .heads import SequenceClassificationHead | |||||
| from .gpt3 import GPT3ForTextGeneration | |||||
| from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM, | |||||
| BertForMaskedLM, DebertaV2ForMaskedLM) | |||||
| from .nncrf_for_named_entity_recognition import ( | from .nncrf_for_named_entity_recognition import ( | ||||
| TransformerCRFForNamedEntityRecognition, | TransformerCRFForNamedEntityRecognition, | ||||
| LSTMCRFForNamedEntityRecognition) | LSTMCRFForNamedEntityRecognition) | ||||
| from .token_classification import SbertForTokenClassification | |||||
| from .palm_v2 import PalmForTextGeneration | |||||
| from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering | |||||
| from .star_text_to_sql import StarForTextToSql | |||||
| from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification | from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification | ||||
| from .space import SpaceForDialogIntent | from .space import SpaceForDialogIntent | ||||
| from .space import SpaceForDialogModeling | from .space import SpaceForDialogModeling | ||||
| from .space import SpaceForDialogStateTracking | from .space import SpaceForDialogStateTracking | ||||
| from .star_text_to_sql import StarForTextToSql | |||||
| from .task_models import (InformationExtractionModel, | from .task_models import (InformationExtractionModel, | ||||
| SingleBackboneTaskModelBase) | |||||
| from .bart_for_text_error_correction import BartForTextErrorCorrection | |||||
| from .gpt3 import GPT3ForTextGeneration | |||||
| from .plug import PlugForTextGeneration | |||||
| from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering | |||||
| SequenceClassificationModel, | |||||
| SingleBackboneTaskModelBase, | |||||
| TokenClassificationModel) | |||||
| from .token_classification import SbertForTokenClassification | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'star_text_to_sql': ['StarForTextToSql'], | |||||
| 'backbones': ['SbertModel'], | 'backbones': ['SbertModel'], | ||||
| 'heads': ['SequenceClassificationHead'], | |||||
| 'csanmt_for_translation': ['CsanmtForTranslation'], | |||||
| 'bart_for_text_error_correction': ['BartForTextErrorCorrection'], | |||||
| 'bert_for_sequence_classification': ['BertForSequenceClassification'], | 'bert_for_sequence_classification': ['BertForSequenceClassification'], | ||||
| 'bert_for_document_segmentation': ['BertForDocumentSegmentation'], | 'bert_for_document_segmentation': ['BertForDocumentSegmentation'], | ||||
| 'csanmt_for_translation': ['CsanmtForTranslation'], | |||||
| 'heads': ['SequenceClassificationHead'], | |||||
| 'gpt3': ['GPT3ForTextGeneration'], | |||||
| 'masked_language': [ | 'masked_language': [ | ||||
| 'StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM', | 'StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM', | ||||
| 'DebertaV2ForMaskedLM' | 'DebertaV2ForMaskedLM' | ||||
| @@ -48,7 +47,8 @@ else: | |||||
| 'LSTMCRFForNamedEntityRecognition' | 'LSTMCRFForNamedEntityRecognition' | ||||
| ], | ], | ||||
| 'palm_v2': ['PalmForTextGeneration'], | 'palm_v2': ['PalmForTextGeneration'], | ||||
| 'token_classification': ['SbertForTokenClassification'], | |||||
| 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], | |||||
| 'star_text_to_sql': ['StarForTextToSql'], | |||||
| 'sequence_classification': | 'sequence_classification': | ||||
| ['VecoForSequenceClassification', 'SbertForSequenceClassification'], | ['VecoForSequenceClassification', 'SbertForSequenceClassification'], | ||||
| 'space': [ | 'space': [ | ||||
| @@ -57,12 +57,9 @@ else: | |||||
| ], | ], | ||||
| 'task_models': [ | 'task_models': [ | ||||
| 'InformationExtractionModel', 'SequenceClassificationModel', | 'InformationExtractionModel', 'SequenceClassificationModel', | ||||
| 'SingleBackboneTaskModelBase' | |||||
| 'SingleBackboneTaskModelBase', 'TokenClassificationModel' | |||||
| ], | ], | ||||
| 'bart_for_text_error_correction': ['BartForTextErrorCorrection'], | |||||
| 'gpt3': ['GPT3ForTextGeneration'], | |||||
| 'plug': ['PlugForTextGeneration'], | |||||
| 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], | |||||
| 'token_classification': ['SbertForTokenClassification'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -19,7 +19,6 @@ class SequenceClassificationHead(TorchHead): | |||||
| super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
| config = self.config | config = self.config | ||||
| self.num_labels = config.num_labels | self.num_labels = config.num_labels | ||||
| self.config = config | |||||
| classifier_dropout = ( | classifier_dropout = ( | ||||
| config['classifier_dropout'] if config.get('classifier_dropout') | config['classifier_dropout'] if config.get('classifier_dropout') | ||||
| is not None else config['hidden_dropout_prob']) | is not None else config['hidden_dropout_prob']) | ||||
| @@ -0,0 +1,42 @@ | |||||
| from typing import Dict | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from torch import nn | |||||
| from modelscope.metainfo import Heads | |||||
| from modelscope.models.base import TorchHead | |||||
| from modelscope.models.builder import HEADS | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import Tasks | |||||
@HEADS.register_module(
    Tasks.token_classification, module_name=Heads.token_classification)
class TokenClassificationHead(TorchHead):
    """Classification head that scores every position of a sequence output.

    Dropout is applied to the incoming features, then a linear layer projects
    the last dimension from ``hidden_size`` to ``num_labels``.
    """

    def __init__(self, **kwargs):
        """Build the dropout and classifier layers from the head config.

        The config is read from ``self.config`` after the base initializer
        has run. Keys used: ``num_labels``, ``hidden_size``,
        ``hidden_dropout_prob`` and, optionally, ``classifier_dropout``.
        """
        super().__init__(**kwargs)
        config = self.config
        self.num_labels = config.num_labels
        # Prefer the dedicated classifier dropout; fall back to the generic
        # hidden dropout when it is missing or explicitly None.
        classifier_dropout = (
            config['classifier_dropout'] if config.get('classifier_dropout')
            is not None else config['hidden_dropout_prob'])
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config['hidden_size'],
                                    config['num_labels'])

    def forward(self, inputs=None):
        """Compute per-position logits.

        Args:
            inputs: either the feature tensor itself, or a dict holding it
                under the key ``'sequence_output'``.

        Returns:
            A dict mapping ``OutputKeys.LOGITS`` to the classifier output.
        """
        if isinstance(inputs, dict):
            assert inputs.get('sequence_output') is not None
            sequence_output = inputs.get('sequence_output')
        else:
            sequence_output = inputs
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return {OutputKeys.LOGITS: logits}

    def compute_loss(self, outputs: Dict[str, torch.Tensor],
                     labels) -> Dict[str, torch.Tensor]:
        """Compute the cross-entropy loss between the logits and ``labels``.

        Args:
            outputs: dict containing the logits under ``OutputKeys.LOGITS``.
            labels: target tensor accepted by ``F.cross_entropy``.

        Returns:
            A dict mapping ``OutputKeys.LOSS`` to the loss tensor.
        """
        logits = outputs[OutputKeys.LOGITS]
        # The classifier puts the class dimension last, while
        # F.cross_entropy expects it at index 1. When the labels are class
        # indices (one dimension fewer than the logits), flatten both so
        # every position becomes one sample. 2-D logits and probability
        # targets are passed through untouched.
        if logits.dim() > 2 and labels.dim() == logits.dim() - 1:
            logits = logits.reshape(-1, logits.size(-1))
            labels = labels.reshape(-1)
        return {OutputKeys.LOSS: F.cross_entropy(logits, labels)}
| @@ -85,7 +85,7 @@ class SbertConfig(PretrainedConfig): | |||||
| If adv_bound is not provided, 2 * sigma will be used as the adv_bound factor | If adv_bound is not provided, 2 * sigma will be used as the adv_bound factor | ||||
| """ | """ | ||||
| model_type = 'sbert' | |||||
| model_type = 'structbert' | |||||
| def __init__(self, | def __init__(self, | ||||
| vocab_size=30522, | vocab_size=30522, | ||||
| @@ -7,12 +7,14 @@ if TYPE_CHECKING: | |||||
| from .information_extraction import InformationExtractionModel | from .information_extraction import InformationExtractionModel | ||||
| from .sequence_classification import SequenceClassificationModel | from .sequence_classification import SequenceClassificationModel | ||||
| from .task_model import SingleBackboneTaskModelBase | from .task_model import SingleBackboneTaskModelBase | ||||
| from .token_classification import TokenClassificationModel | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'information_extraction': ['InformationExtractionModel'], | 'information_extraction': ['InformationExtractionModel'], | ||||
| 'sequence_classification': ['SequenceClassificationModel'], | 'sequence_classification': ['SequenceClassificationModel'], | ||||
| 'task_model': ['SingleBackboneTaskModelBase'], | 'task_model': ['SingleBackboneTaskModelBase'], | ||||
| 'token_classification': ['TokenClassificationModel'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,83 @@ | |||||
| from typing import Any, Dict | |||||
| import numpy as np | |||||
| import torch | |||||
| from modelscope.metainfo import TaskModels | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.models.nlp.task_models.task_model import \ | |||||
| SingleBackboneTaskModelBase | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.hub import parse_label_mapping | |||||
| from modelscope.utils.tensor_utils import (torch_nested_detach, | |||||
| torch_nested_numpify) | |||||
| __all__ = ['TokenClassificationModel'] | |||||
@MODELS.register_module(
    Tasks.token_classification, module_name=TaskModels.token_classification)
class TokenClassificationModel(SingleBackboneTaskModelBase):
    """Task model combining one backbone with a token classification head."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the token classification model from `model_dir`.

        Args:
            model_dir (str): the model path.
            **kwargs: may carry ``base_model_prefix`` and ``num_labels``.
                When ``num_labels`` is absent it is derived from the label
                mapping parsed from ``model_dir``.
        """
        super().__init__(model_dir, *args, **kwargs)
        if 'base_model_prefix' in kwargs:
            self._base_model_prefix = kwargs['base_model_prefix']

        backbone_cfg = self.cfg.backbone
        head_cfg = self.cfg.head

        # Resolve the number of labels: an explicit kwarg wins, otherwise
        # fall back to the size of the label mapping found in the model dir.
        num_labels = kwargs.get('num_labels')
        if num_labels is None:
            label2id = parse_label_mapping(model_dir)
            if label2id is not None and len(label2id) > 0:
                num_labels = len(label2id)
                # Inverse mapping so predictions can be turned back into
                # label names by downstream consumers.
                self.id2label = {
                    idx: label
                    for label, idx in label2id.items()
                }
        head_cfg['num_labels'] = num_labels

        self.build_backbone(backbone_cfg)
        self.build_head(head_cfg)

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """Run the backbone and the head, adding the loss if labels exist.

        Labels are popped from ``input`` (under ``OutputKeys.LABEL`` or
        ``OutputKeys.LABELS``) before the remaining entries are forwarded
        to the base class.

        Args:
            input: the model inputs, optionally including labels.

        Returns:
            The head outputs, updated with the loss dict when labels were
            supplied.
        """
        labels = None
        if OutputKeys.LABEL in input:
            labels = input.pop(OutputKeys.LABEL)
        elif OutputKeys.LABELS in input:
            labels = input.pop(OutputKeys.LABELS)

        outputs = super().forward(input)
        sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
        outputs = self.head.forward(sequence_output)
        # The labels were popped above, so test the value itself rather than
        # looking it up as a key of `input`.
        if labels is not None:
            loss = self.compute_loss(outputs, labels)
            outputs.update(loss)
        return outputs

    def extract_logits(self, outputs):
        """Return the logits from `outputs`, moved to CPU and detached."""
        return outputs[OutputKeys.LOGITS].cpu().detach()

    def extract_backbone_outputs(self, outputs):
        """Split backbone outputs into (sequence_output, pooled_output).

        The sequence output is obtained through the backbone's
        ``extract_sequence_outputs`` when it provides one, else it is None.
        The pooled output is always None here.
        """
        sequence_output = None
        pooled_output = None
        if hasattr(self.backbone, 'extract_sequence_outputs'):
            sequence_output = self.backbone.extract_sequence_outputs(outputs)
        return sequence_output, pooled_output

    def compute_loss(self, outputs, labels):
        """Delegate the loss computation to the head."""
        loss = self.head.compute_loss(outputs, labels)
        return loss

    def postprocess(self, input, **kwargs):
        """Convert raw outputs into numpy predictions and logits.

        Predictions are the argmax over the last dimension of the first
        entry of the logits only.
        """
        logits = self.extract_logits(input)
        pred = torch.argmax(logits[0], dim=-1)
        pred = torch_nested_numpify(torch_nested_detach(pred))
        logits = torch_nested_numpify(torch_nested_detach(logits))
        res = {OutputKeys.PREDICTIONS: pred, OutputKeys.LOGITS: logits}
        return res
| @@ -91,6 +91,7 @@ class TokenClassification(TorchModel): | |||||
| @MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) | @MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) | ||||
| @MODELS.register_module(Tasks.part_of_speech, module_name=Models.structbert) | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.token_classification, module_name=Models.structbert) | Tasks.token_classification, module_name=Models.structbert) | ||||
| class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel): | class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel): | ||||
| @@ -359,26 +359,20 @@ TASK_OUTPUTS = { | |||||
| # word segmentation result for single sample | # word segmentation result for single sample | ||||
| # { | # { | ||||
| # "output": "今天 天气 不错 , 适合 出去 游玩" | # "output": "今天 天气 不错 , 适合 出去 游玩" | ||||
| # } | |||||
| Tasks.word_segmentation: [OutputKeys.OUTPUT], | |||||
| # part-of-speech result for single sample | |||||
| # [ | |||||
| # {'word': '诸葛', 'label': 'PROPN'}, | |||||
| # {'word': '亮', 'label': 'PROPN'}, | |||||
| # {'word': '发明', 'label': 'VERB'}, | |||||
| # {'word': '八', 'label': 'NUM'}, | |||||
| # {'word': '阵', 'label': 'NOUN'}, | |||||
| # {'word': '图', 'label': 'PART'}, | |||||
| # {'word': '以', 'label': 'ADV'}, | |||||
| # {'word': '利', 'label': 'VERB'}, | |||||
| # {'word': '立营', 'label': 'VERB'}, | |||||
| # {'word': '练兵', 'label': 'VERB'}, | |||||
| # {'word': '.', 'label': 'PUNCT'} | |||||
| # "labels": [ | |||||
| # {'word': '今天', 'label': 'PROPN'}, | |||||
| # {'word': '天气', 'label': 'PROPN'}, | |||||
| # {'word': '不错', 'label': 'VERB'}, | |||||
| # {'word': ',', 'label': 'NUM'}, | |||||
| # {'word': '适合', 'label': 'NOUN'}, | |||||
| # {'word': '出去', 'label': 'PART'}, | |||||
| # {'word': '游玩', 'label': 'ADV'}, | |||||
| # ] | # ] | ||||
| # TODO @wenmeng.zwm support list of result check | |||||
| Tasks.part_of_speech: [OutputKeys.WORD, OutputKeys.LABEL], | |||||
| # } | |||||
| Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS], | |||||
| Tasks.part_of_speech: [OutputKeys.OUTPUT, OutputKeys.LABELS], | |||||
| # TODO @wenmeng.zwm support list of result check | |||||
| # named entity recognition result for single sample | # named entity recognition result for single sample | ||||
| # { | # { | ||||
| # "output": [ | # "output": [ | ||||
| @@ -20,6 +20,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.word_segmentation: | Tasks.word_segmentation: | ||||
| (Pipelines.word_segmentation, | (Pipelines.word_segmentation, | ||||
| 'damo/nlp_structbert_word-segmentation_chinese-base'), | 'damo/nlp_structbert_word-segmentation_chinese-base'), | ||||
| Tasks.token_classification: | |||||
| (Pipelines.part_of_speech, | |||||
| 'damo/nlp_structbert_part-of-speech_chinese-base'), | |||||
| Tasks.named_entity_recognition: | Tasks.named_entity_recognition: | ||||
| (Pipelines.named_entity_recognition, | (Pipelines.named_entity_recognition, | ||||
| 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), | 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), | ||||
| @@ -9,21 +9,21 @@ if TYPE_CHECKING: | |||||
| from .dialog_modeling_pipeline import DialogModelingPipeline | from .dialog_modeling_pipeline import DialogModelingPipeline | ||||
| from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline | from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline | ||||
| from .document_segmentation_pipeline import DocumentSegmentationPipeline | from .document_segmentation_pipeline import DocumentSegmentationPipeline | ||||
| from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline | |||||
| from .fill_mask_pipeline import FillMaskPipeline | from .fill_mask_pipeline import FillMaskPipeline | ||||
| from .information_extraction_pipeline import InformationExtractionPipeline | from .information_extraction_pipeline import InformationExtractionPipeline | ||||
| from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | ||||
| from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline | from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline | ||||
| from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline | from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline | ||||
| from .sequence_classification_pipeline import SequenceClassificationPipeline | from .sequence_classification_pipeline import SequenceClassificationPipeline | ||||
| from .summarization_pipeline import SummarizationPipeline | |||||
| from .text_classification_pipeline import TextClassificationPipeline | |||||
| from .text_error_correction_pipeline import TextErrorCorrectionPipeline | |||||
| from .text_generation_pipeline import TextGenerationPipeline | from .text_generation_pipeline import TextGenerationPipeline | ||||
| from .token_classification_pipeline import TokenClassificationPipeline | |||||
| from .translation_pipeline import TranslationPipeline | from .translation_pipeline import TranslationPipeline | ||||
| from .word_segmentation_pipeline import WordSegmentationPipeline | from .word_segmentation_pipeline import WordSegmentationPipeline | ||||
| from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | ||||
| from .summarization_pipeline import SummarizationPipeline | |||||
| from .text_classification_pipeline import TextClassificationPipeline | |||||
| from .text_error_correction_pipeline import TextErrorCorrectionPipeline | |||||
| from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline | |||||
| from .relation_extraction_pipeline import RelationExtractionPipeline | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -34,25 +34,25 @@ else: | |||||
| 'dialog_modeling_pipeline': ['DialogModelingPipeline'], | 'dialog_modeling_pipeline': ['DialogModelingPipeline'], | ||||
| 'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'], | 'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'], | ||||
| 'document_segmentation_pipeline': ['DocumentSegmentationPipeline'], | 'document_segmentation_pipeline': ['DocumentSegmentationPipeline'], | ||||
| 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'], | |||||
| 'fill_mask_pipeline': ['FillMaskPipeline'], | 'fill_mask_pipeline': ['FillMaskPipeline'], | ||||
| 'named_entity_recognition_pipeline': | |||||
| ['NamedEntityRecognitionPipeline'], | |||||
| 'information_extraction_pipeline': ['InformationExtractionPipeline'], | 'information_extraction_pipeline': ['InformationExtractionPipeline'], | ||||
| 'single_sentence_classification_pipeline': | |||||
| ['SingleSentenceClassificationPipeline'], | |||||
| 'pair_sentence_classification_pipeline': | 'pair_sentence_classification_pipeline': | ||||
| ['PairSentenceClassificationPipeline'], | ['PairSentenceClassificationPipeline'], | ||||
| 'sequence_classification_pipeline': ['SequenceClassificationPipeline'], | 'sequence_classification_pipeline': ['SequenceClassificationPipeline'], | ||||
| 'single_sentence_classification_pipeline': | |||||
| ['SingleSentenceClassificationPipeline'], | |||||
| 'summarization_pipeline': ['SummarizationPipeline'], | |||||
| 'text_classification_pipeline': ['TextClassificationPipeline'], | |||||
| 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], | |||||
| 'text_generation_pipeline': ['TextGenerationPipeline'], | 'text_generation_pipeline': ['TextGenerationPipeline'], | ||||
| 'token_classification_pipeline': ['TokenClassificationPipeline'], | |||||
| 'translation_pipeline': ['TranslationPipeline'], | |||||
| 'word_segmentation_pipeline': ['WordSegmentationPipeline'], | 'word_segmentation_pipeline': ['WordSegmentationPipeline'], | ||||
| 'zero_shot_classification_pipeline': | 'zero_shot_classification_pipeline': | ||||
| ['ZeroShotClassificationPipeline'], | ['ZeroShotClassificationPipeline'], | ||||
| 'named_entity_recognition_pipeline': | |||||
| ['NamedEntityRecognitionPipeline'], | |||||
| 'translation_pipeline': ['TranslationPipeline'], | |||||
| 'summarization_pipeline': ['SummarizationPipeline'], | |||||
| 'text_classification_pipeline': ['TextClassificationPipeline'], | |||||
| 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], | |||||
| 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'], | |||||
| 'relation_extraction_pipeline': ['RelationExtractionPipeline'] | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,92 @@ | |||||
| from typing import Any, Dict, Optional, Union | |||||
| import torch | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Pipeline, Tensor | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import (Preprocessor, | |||||
| TokenClassificationPreprocessor) | |||||
| from modelscope.utils.constant import Tasks | |||||
| __all__ = ['TokenClassificationPipeline'] | |||||
@PIPELINES.register_module(
    Tasks.token_classification, module_name=Pipelines.part_of_speech)
class TokenClassificationPipeline(Pipeline):
    """Pipeline that tags each token and groups tokens into labelled words."""

    def __init__(self,
                 model: Union[Model, str],
                 preprocessor: Optional[Preprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create a token classification pipeline.

        Args:
            model (str or Model): A model instance or a model local dir or a
                model id in the model hub.
            preprocessor (Preprocessor): an optional preprocessor instance.
                When None, a ``TokenClassificationPreprocessor`` is built
                from the model dir, using the ``sequence_length`` kwarg
                (default 128).
        """
        assert isinstance(model, (str, Model)), \
            'model must be a single str or Model'
        model = model if isinstance(model,
                                    Model) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = TokenClassificationPreprocessor(
                model.model_dir,
                sequence_length=kwargs.pop('sequence_length', 128))
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

        # Default to None so that a model without the mapping reaches the
        # explanatory assertion below instead of raising AttributeError.
        self.id2label = getattr(model, 'id2label', None)
        assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
                                          'as a parameter or make sure the preprocessor has the attribute.'

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Run the model without gradients and carry the raw text along.

        The text is popped from `inputs` before the model call and added
        back to the returned dict under ``OutputKeys.TEXT``.
        """
        text = inputs.pop(OutputKeys.TEXT)
        with torch.no_grad():
            return {
                **self.model(inputs, **forward_params), OutputKeys.TEXT: text
            }

    def postprocess(self, inputs: Dict[str, Any],
                    **postprocess_params) -> Dict[str, Any]:
        """Turn per-token predictions into segmented, tagged words.

        Args:
            inputs (Dict[str, Any]): must contain ``'predictions'`` (label
                ids, one per position) and ``'text'`` (the raw text).

        Returns:
            Dict[str, Any]: ``OutputKeys.OUTPUT`` holds the words joined by
            spaces, ``OutputKeys.LABELS`` a list of word/label dicts.
        """
        # Drop the first and last predictions before aligning with the text
        # (presumably the special start/end tokens -- confirm against the
        # preprocessor).
        labels = [self.id2label[pred] for pred in inputs['predictions']][1:-1]
        assert len(inputs['text']) == len(labels)

        chunks = []
        tags = []
        chunk = ''
        for token, label in zip(inputs['text'], labels):
            chunk += token
            # Labels starting with 'B' or 'I' keep the current word open;
            # any other label closes it and supplies its tag.
            if label[0] not in ('B', 'I'):
                chunks.append(chunk)
                chunk = ''
                tags.append(label.split('-')[-1])

        if chunk:
            # The text ended inside an open word: flush it using the tag of
            # the last label seen.
            chunks.append(chunk)
            tags.append(label.split('-')[-1])

        seg_result = ' '.join(chunks)
        pos_result = [{
            OutputKeys.WORD: word,
            OutputKeys.LABEL: tag
        } for word, tag in zip(chunks, tags)]
        outputs = {
            OutputKeys.OUTPUT: seg_result,
            OutputKeys.LABELS: pos_result
        }
        return outputs
| @@ -15,15 +15,14 @@ if TYPE_CHECKING: | |||||
| ImageDenoisePreprocessor) | ImageDenoisePreprocessor) | ||||
| from .kws import WavToLists | from .kws import WavToLists | ||||
| from .multi_modal import (OfaPreprocessor, MPlugPreprocessor) | from .multi_modal import (OfaPreprocessor, MPlugPreprocessor) | ||||
| from .nlp import (Tokenize, SequenceClassificationPreprocessor, | |||||
| TextGenerationPreprocessor, | |||||
| TokenClassificationPreprocessor, | |||||
| SingleSentenceClassificationPreprocessor, | |||||
| PairSentenceClassificationPreprocessor, | |||||
| FillMaskPreprocessor, ZeroShotClassificationPreprocessor, | |||||
| NERPreprocessor, TextErrorCorrectionPreprocessor, | |||||
| FaqQuestionAnsweringPreprocessor, | |||||
| RelationExtractionPreprocessor) | |||||
| from .nlp import ( | |||||
| Tokenize, SequenceClassificationPreprocessor, | |||||
| TextGenerationPreprocessor, TokenClassificationPreprocessor, | |||||
| SingleSentenceClassificationPreprocessor, | |||||
| PairSentenceClassificationPreprocessor, FillMaskPreprocessor, | |||||
| ZeroShotClassificationPreprocessor, NERPreprocessor, | |||||
| TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor, | |||||
| SequenceLabelingPreprocessor, RelationExtractionPreprocessor) | |||||
| from .slp import DocumentSegmentationPreprocessor | from .slp import DocumentSegmentationPreprocessor | ||||
| from .space import (DialogIntentPredictionPreprocessor, | from .space import (DialogIntentPredictionPreprocessor, | ||||
| DialogModelingPreprocessor, | DialogModelingPreprocessor, | ||||
| @@ -52,7 +51,7 @@ else: | |||||
| 'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | 'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | ||||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | ||||
| 'TextErrorCorrectionPreprocessor', | 'TextErrorCorrectionPreprocessor', | ||||
| 'FaqQuestionAnsweringPreprocessor', | |||||
| 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | |||||
| 'RelationExtractionPreprocessor' | 'RelationExtractionPreprocessor' | ||||
| ], | ], | ||||
| 'slp': ['DocumentSegmentationPreprocessor'], | 'slp': ['DocumentSegmentationPreprocessor'], | ||||
| @@ -5,9 +5,11 @@ import uuid | |||||
| from typing import Any, Dict, Iterable, Optional, Tuple, Union | from typing import Any, Dict, Iterable, Optional, Tuple, Union | ||||
| import numpy as np | import numpy as np | ||||
| import torch | |||||
| from transformers import AutoTokenizer, BertTokenizerFast | from transformers import AutoTokenizer, BertTokenizerFast | ||||
| from modelscope.metainfo import Models, Preprocessors | from modelscope.metainfo import Models, Preprocessors | ||||
| from modelscope.models.nlp.structbert import SbertTokenizerFast | |||||
| from modelscope.outputs import OutputKeys | from modelscope.outputs import OutputKeys | ||||
| from modelscope.utils.config import ConfigFields | from modelscope.utils.config import ConfigFields | ||||
| from modelscope.utils.constant import Fields, InputFields, ModeKeys | from modelscope.utils.constant import Fields, InputFields, ModeKeys | ||||
| @@ -23,7 +25,7 @@ __all__ = [ | |||||
| 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | ||||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | ||||
| 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor', | 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor', | ||||
| 'RelationExtractionPreprocessor' | |||||
| 'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor' | |||||
| ] | ] | ||||
| @@ -627,6 +629,112 @@ class NERPreprocessor(Preprocessor): | |||||
| } | } | ||||
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer)
class SequenceLabelingPreprocessor(Preprocessor):
    """Tokenizer preprocessor for sequence labeling tasks.

    Turns one raw sentence into ``input_ids``, ``attention_mask``, a
    ``label_mask`` that marks the first sub-token of every word, and an
    ``offset_mapping`` of character spans into the original text.

    NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Load the tokenizer stored under ``model_dir``.

        Args:
            model_dir (str): Local model path. Its text also selects the
                tokenizer class ('lstm' / 'gcnn' / 'structbert' substrings).
            sequence_length (int, optional, keyword): Maximum number of
                tokens, special tokens included. Defaults to 512.
        """
        super().__init__(*args, **kwargs)

        self.model_dir: str = model_dir
        self.sequence_length = kwargs.pop('sequence_length', 512)

        # `__call__` reads this flag to decide whether the leading/trailing
        # special tokens are stripped; it must exist on every instance.
        # NOTE(review): assumes only LSTM-based models consume inputs without
        # [CLS]/[SEP] -- confirm whether 'gcnn' models need the same.
        self.is_transformer_based_model = 'lstm' not in model_dir

        if 'lstm' in model_dir or 'gcnn' in model_dir:
            self.tokenizer = BertTokenizerFast.from_pretrained(
                model_dir, use_fast=False)
        elif 'structbert' in model_dir:
            self.tokenizer = SbertTokenizerFast.from_pretrained(
                model_dir, use_fast=False)
        else:
            # NOTE(review): with `use_fast=False` AutoTokenizer presumably
            # returns a slow tokenizer, while `_encode_text` relies on
            # offsets and `word_ids()` -- verify for non-split models.
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_dir, use_fast=False)
        self.is_split_into_words = self.tokenizer.init_kwargs.get(
            'is_split_into_words', False)

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence
                Example:
                    'you are so handsome.'

        Returns:
            Dict[str, Any]: the preprocessed data with the keys ``text``,
                ``input_ids``, ``attention_mask``, ``label_mask`` and
                ``offset_mapping``.
        """
        text = data
        if self.is_split_into_words:
            encoded = self._encode_characters(text)
        else:
            encoded = self._encode_text(text)
        input_ids, attention_mask, label_mask, offset_mapping = encoded

        if not self.is_transformer_based_model:
            # Drop the first and last position (the special tokens added
            # by both encoding paths).
            input_ids = input_ids[1:-1]
            attention_mask = attention_mask[1:-1]
            label_mask = label_mask[1:-1]
        return {
            'text': text,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label_mask': label_mask,
            'offset_mapping': offset_mapping
        }

    def _encode_characters(self, text: str):
        """Encode ``text`` one character at a time and add CLS/SEP ids."""
        input_ids = []
        label_mask = []
        offset_mapping = []
        for offset, token in enumerate(text):
            subtoken_ids = self.tokenizer.encode(
                token, add_special_tokens=False)
            if len(subtoken_ids) == 0:
                # Characters the tokenizer drops still need one position.
                subtoken_ids = [self.tokenizer.unk_token_id]
            extra = len(subtoken_ids) - 1
            input_ids.extend(subtoken_ids)
            # Only the first sub-token of a character carries a label.
            label_mask.extend([1] + [0] * extra)
            offset_mapping.extend([(offset, offset + 1)]
                                  + [(offset + 1, offset + 1)] * extra)

        # Reserve two positions for the CLS and SEP tokens.
        budget = self.sequence_length - 2
        if len(input_ids) >= budget:
            input_ids = input_ids[:budget]
            label_mask = label_mask[:budget]
            offset_mapping = offset_mapping[:budget]
        input_ids = [self.tokenizer.cls_token_id
                     ] + input_ids + [self.tokenizer.sep_token_id]
        label_mask = [0] + label_mask + [0]
        attention_mask = [1] * len(input_ids)
        return input_ids, attention_mask, label_mask, offset_mapping

    def _encode_text(self, text: str):
        """Encode ``text`` in one tokenizer call, merging offsets per word."""
        encodings = self.tokenizer(
            text,
            add_special_tokens=True,
            padding=True,
            truncation=True,
            max_length=self.sequence_length,
            return_offsets_mapping=True)
        input_ids = encodings['input_ids']
        attention_mask = encodings['attention_mask']
        word_ids = encodings.word_ids()
        token_offsets = encodings['offset_mapping']
        label_mask = []
        offset_mapping = []
        for i, word_id in enumerate(word_ids):
            if word_id is None:
                label_mask.append(0)
            elif i > 0 and word_id == word_ids[i - 1]:
                # Continuation of the previous word: extend its span.
                # `i > 0` keeps index -1 from wrapping to the last token.
                label_mask.append(0)
                offset_mapping[-1] = (offset_mapping[-1][0],
                                      token_offsets[i][1])
            else:
                label_mask.append(1)
                offset_mapping.append(token_offsets[i])
        return input_ids, attention_mask, label_mask, offset_mapping
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.nlp, module_name=Preprocessors.re_tokenizer) | Fields.nlp, module_name=Preprocessors.re_tokenizer) | ||||
| class RelationExtractionPreprocessor(Preprocessor): | class RelationExtractionPreprocessor(Preprocessor): | ||||
| @@ -77,19 +77,26 @@ def auto_load(model: Union[str, List[str]]): | |||||
| def get_model_type(model_dir): | def get_model_type(model_dir): | ||||
| """Get the model type from the configuration. | """Get the model type from the configuration. | ||||
| This method will try to get the 'model.type' or 'model.model_type' field from the configuration.json file. | |||||
| If this file does not exist, the method will try to get the 'model_type' field from the config.json. | |||||
| This method will try to get the model type from 'model.backbone.type', | |||||
| 'model.type' or 'model.model_type' field in the configuration.json file. If | |||||
| this file does not exist, the method will try to get the 'model_type' field | |||||
| from the config.json. | |||||
| @param model_dir: The local model dir to use. | |||||
| @return: The model type string, returns None if nothing is found. | |||||
| @param model_dir: The local model dir to use. @return: The model type | |||||
| string, returns None if nothing is found. | |||||
| """ | """ | ||||
| try: | try: | ||||
| configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION) | configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION) | ||||
| config_file = osp.join(model_dir, 'config.json') | config_file = osp.join(model_dir, 'config.json') | ||||
| if osp.isfile(configuration_file): | if osp.isfile(configuration_file): | ||||
| cfg = Config.from_file(configuration_file) | cfg = Config.from_file(configuration_file) | ||||
| return cfg.model.model_type if hasattr(cfg.model, 'model_type') and not hasattr(cfg.model, 'type') \ | |||||
| else cfg.model.type | |||||
| if hasattr(cfg.model, 'backbone'): | |||||
| return cfg.model.backbone.type | |||||
| elif hasattr(cfg.model, | |||||
| 'model_type') and not hasattr(cfg.model, 'type'): | |||||
| return cfg.model.model_type | |||||
| else: | |||||
| return cfg.model.type | |||||
| elif osp.isfile(config_file): | elif osp.isfile(config_file): | ||||
| cfg = Config.from_file(config_file) | cfg = Config.from_file(config_file) | ||||
| return cfg.model_type if hasattr(cfg, 'model_type') else None | return cfg.model_type if hasattr(cfg, 'model_type') else None | ||||
| @@ -123,13 +130,24 @@ def parse_label_mapping(model_dir): | |||||
| if hasattr(config, ConfigFields.model) and hasattr( | if hasattr(config, ConfigFields.model) and hasattr( | ||||
| config[ConfigFields.model], 'label2id'): | config[ConfigFields.model], 'label2id'): | ||||
| label2id = config[ConfigFields.model].label2id | label2id = config[ConfigFields.model].label2id | ||||
| elif hasattr(config, ConfigFields.model) and hasattr( | |||||
| config[ConfigFields.model], 'id2label'): | |||||
| id2label = config[ConfigFields.model].id2label | |||||
| label2id = {label: id for id, label in id2label.items()} | |||||
| elif hasattr(config, ConfigFields.preprocessor) and hasattr( | elif hasattr(config, ConfigFields.preprocessor) and hasattr( | ||||
| config[ConfigFields.preprocessor], 'label2id'): | config[ConfigFields.preprocessor], 'label2id'): | ||||
| label2id = config[ConfigFields.preprocessor].label2id | label2id = config[ConfigFields.preprocessor].label2id | ||||
| elif hasattr(config, ConfigFields.preprocessor) and hasattr( | |||||
| config[ConfigFields.preprocessor], 'id2label'): | |||||
| id2label = config[ConfigFields.preprocessor].id2label | |||||
| label2id = {label: id for id, label in id2label.items()} | |||||
| if label2id is None: | if label2id is None: | ||||
| config_path = os.path.join(model_dir, 'config.json') | config_path = os.path.join(model_dir, 'config.json') | ||||
| config = Config.from_file(config_path) | config = Config.from_file(config_path) | ||||
| if hasattr(config, 'label2id'): | if hasattr(config, 'label2id'): | ||||
| label2id = config.label2id | label2id = config.label2id | ||||
| elif hasattr(config, 'id2label'): | |||||
| id2label = config.id2label | |||||
| label2id = {label: id for id, label in id2label.items()} | |||||
| return label2id | return label2id | ||||
| @@ -0,0 +1,55 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import shutil | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import TokenClassificationModel | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.pipelines.nlp import TokenClassificationPipeline | |||||
| from modelscope.preprocessors import TokenClassificationPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
class PartOfSpeechTest(unittest.TestCase):
    """Smoke tests that run the token-classification pipeline on one sentence."""

    model_id = 'damo/nlp_structbert_part-of-speech_chinese-base'
    sentence = '今天天气不错,适合出去游玩'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        """Build the pipeline from a downloaded snapshot, directly and via `pipeline`."""
        model_dir = snapshot_download(self.model_id)
        preprocessor = TokenClassificationPreprocessor(model_dir)
        pos_model = TokenClassificationModel.from_pretrained(model_dir)
        direct_pipeline = TokenClassificationPipeline(
            pos_model, preprocessor=preprocessor)
        factory_pipeline = pipeline(
            Tasks.token_classification,
            model=pos_model,
            preprocessor=preprocessor)
        direct_result = direct_pipeline(input=self.sentence)
        print(f'sentence: {self.sentence}\n'
              f'pipeline1:{direct_result}')
        print()
        factory_result = factory_pipeline(input=self.sentence)
        print(f'pipeline2: {factory_result}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        """Build the pipeline from a model object loaded by its hub id."""
        hub_model = Model.from_pretrained(self.model_id)
        preprocessor = TokenClassificationPreprocessor(hub_model.model_dir)
        tagger = pipeline(
            task=Tasks.token_classification,
            model=hub_model,
            preprocessor=preprocessor)
        result = tagger(input=self.sentence)
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_name(self):
        """Build the pipeline from the model id string alone."""
        tagger = pipeline(
            task=Tasks.token_classification, model=self.model_id)
        result = tagger(input=self.sentence)
        print(result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        """Build the pipeline with only the task given, no model argument."""
        tagger = pipeline(task=Tasks.token_classification)
        result = tagger(input=self.sentence)
        print(result)
| if __name__ == '__main__': | |||||
| unittest.main() | |||||