Add NLP support for Southeast Asian low-resource languages, including:
1. Preprocessing for Thai and Vietnamese NER
2. Word segmentation model and pipeline based on the XLM-R + CRF architecture
3. Preprocessing for Thai word segmentation
Unit tests for the corresponding pipelines are also added.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10492404
master
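For reference, a minimal usage sketch of the new pipelines (not part of the diff; the model ids are the ones exercised by the unit tests at the end of this review):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Thai word segmentation (XLM-R encoder + CRF head)
    ws_th = pipeline(Tasks.word_segmentation,
                     model='damo/nlp_xlmr_word-segmentation_thai')
    print(ws_th(input='รถคันเก่าก็ยังเก็บเอาไว้ยังไม่ได้ขาย'))

    # Thai and Vietnamese NER on e-commerce titles
    ner_th = pipeline(Tasks.named_entity_recognition,
                      model='damo/nlp_xlmr_named-entity-recognition_thai-ecommerce-title')
    print(ner_th(input='เครื่องชั่งดิจิตอลแบบตั้งพื้น150kg.'))
    ner_vi = pipeline(Tasks.named_entity_recognition,
                      model='damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title')
    print(ner_vi(input='Nón vành dễ thương cho bé gái'))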
| @@ -70,8 +70,10 @@ class Models(object): | |||
| space_T_en = 'space-T-en' | |||
| space_T_cn = 'space-T-cn' | |||
| tcrf = 'transformer-crf' | |||
| tcrf_wseg = 'transformer-crf-for-word-segmentation' | |||
| transformer_softmax = 'transformer-softmax' | |||
| lcrf = 'lstm-crf' | |||
| lcrf_wseg = 'lstm-crf-for-word-segmentation' | |||
| gcnncrf = 'gcnn-crf' | |||
| bart = 'bart' | |||
| gpt3 = 'gpt3' | |||
| @@ -219,8 +221,12 @@ class Pipelines(object): | |||
| domain_classification = 'domain-classification' | |||
| sentence_similarity = 'sentence-similarity' | |||
| word_segmentation = 'word-segmentation' | |||
| multilingual_word_segmentation = 'multilingual-word-segmentation' | |||
| word_segmentation_thai = 'word-segmentation-thai' | |||
| part_of_speech = 'part-of-speech' | |||
| named_entity_recognition = 'named-entity-recognition' | |||
| named_entity_recognition_thai = 'named-entity-recognition-thai' | |||
| named_entity_recognition_viet = 'named-entity-recognition-viet' | |||
| text_generation = 'text-generation' | |||
| text2text_generation = 'text2text-generation' | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| @@ -343,6 +349,8 @@ class Preprocessors(object): | |||
| text2text_translate_preprocessor = 'text2text-translate-preprocessor' | |||
| token_cls_tokenizer = 'token-cls-tokenizer' | |||
| ner_tokenizer = 'ner-tokenizer' | |||
| thai_ner_tokenizer = 'thai-ner-tokenizer' | |||
| viet_ner_tokenizer = 'viet-ner-tokenizer' | |||
| nli_tokenizer = 'nli-tokenizer' | |||
| sen_cls_tokenizer = 'sen-cls-tokenizer' | |||
| dialog_intent_preprocessor = 'dialog-intent-preprocessor' | |||
| @@ -355,6 +363,7 @@ class Preprocessors(object): | |||
| text_ranking = 'text-ranking' | |||
| sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | |||
| word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | |||
| thai_wseg_tokenizer = 'thai-wseg-tokenizer' | |||
| fill_mask = 'fill-mask' | |||
| fill_mask_ponet = 'fill-mask-ponet' | |||
| faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | |||
| @@ -5,14 +5,26 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .bart import BartForTextErrorCorrection | |||
| from .bert import ( | |||
| BertForMaskedLM, | |||
| BertForTextRanking, | |||
| BertForSentenceEmbedding, | |||
| BertForSequenceClassification, | |||
| BertForTokenClassification, | |||
| BertForDocumentSegmentation, | |||
| BertModel, | |||
| BertConfig, | |||
| ) | |||
| from .csanmt import CsanmtForTranslation | |||
| from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model | |||
| from .gpt_neo import GPTNeoModel | |||
| from .gpt3 import GPT3ForTextGeneration | |||
| from .heads import SequenceClassificationHead | |||
| from .palm_v2 import PalmForTextGeneration | |||
| from .ponet import PoNetForMaskedLM, PoNetModel, PoNetConfig | |||
| from .space import SpaceForDialogIntent, SpaceForDialogModeling, SpaceForDST | |||
| from .space_T_cn import TableQuestionAnswering | |||
| from .space_T_en import StarForTextToSql | |||
| from .structbert import ( | |||
| SbertForFaqQuestionAnswering, | |||
| SbertForMaskedLM, | |||
| @@ -22,19 +34,7 @@ if TYPE_CHECKING: | |||
| SbertModel, | |||
| SbertTokenizerFast, | |||
| ) | |||
| from .T5 import T5ForConditionalGeneration | |||
| from .task_models import ( | |||
| FeatureExtractionModel, | |||
| InformationExtractionModel, | |||
| @@ -45,9 +45,11 @@ if TYPE_CHECKING: | |||
| TokenClassificationModel, | |||
| TransformerCRFForNamedEntityRecognition, | |||
| ) | |||
| from .veco import (VecoConfig, VecoForMaskedLM, | |||
| VecoForSequenceClassification, | |||
| VecoForTokenClassification, VecoModel, VecoTokenizer, | |||
| VecoTokenizerFast) | |||
| else: | |||
| _import_structure = { | |||
| 'backbones': ['SbertModel'], | |||
| @@ -65,9 +67,13 @@ else: | |||
| 'SbertModel', | |||
| ], | |||
| 'veco': [ | |||
| 'VecoConfig', | |||
| 'VecoForMaskedLM', | |||
| 'VecoForSequenceClassification', | |||
| 'VecoForTokenClassification', | |||
| 'VecoModel', | |||
| 'VecoTokenizer', | |||
| 'VecoTokenizerFast', | |||
| ], | |||
| 'bert': [ | |||
| 'BertForMaskedLM', | |||
| @@ -90,11 +96,13 @@ else: | |||
| 'FeatureExtractionModel', | |||
| 'InformationExtractionModel', | |||
| 'LSTMCRFForNamedEntityRecognition', | |||
| 'LSTMCRFForWordSegmentation', | |||
| 'SequenceClassificationModel', | |||
| 'SingleBackboneTaskModelBase', | |||
| 'TaskModelForTextGeneration', | |||
| 'TokenClassificationModel', | |||
| 'TransformerCRFForNamedEntityRecognition', | |||
| 'TransformerCRFForWordSegmentation', | |||
| ], | |||
| 'sentence_embedding': ['SentenceEmbedding'], | |||
| 'T5': ['T5ForConditionalGeneration'], | |||
| @@ -8,8 +8,13 @@ if TYPE_CHECKING: | |||
| from .feature_extraction import FeatureExtractionModel | |||
| from .fill_mask import FillMaskModel | |||
| from .nncrf_for_named_entity_recognition import ( | |||
| LSTMCRFForNamedEntityRecognition, | |||
| TransformerCRFForNamedEntityRecognition, | |||
| ) | |||
| from .nncrf_for_word_segmentation import ( | |||
| LSTMCRFForWordSegmentation, | |||
| TransformerCRFForWordSegmentation, | |||
| ) | |||
| from .sequence_classification import SequenceClassificationModel | |||
| from .task_model import SingleBackboneTaskModelBase | |||
| from .token_classification import TokenClassificationModel | |||
| @@ -24,6 +29,8 @@ else: | |||
| 'TransformerCRFForNamedEntityRecognition', | |||
| 'LSTMCRFForNamedEntityRecognition' | |||
| ], | |||
| 'nncrf_for_word_segmentation': | |||
| ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'], | |||
| 'sequence_classification': ['SequenceClassificationModel'], | |||
| 'task_model': ['SingleBackboneTaskModelBase'], | |||
| 'token_classification': ['TokenClassificationModel'], | |||
| @@ -0,0 +1,639 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. | |||
| # The CRF implementation borrows mostly from AllenNLP CRF module (https://github.com/allenai/allennlp) | |||
| # and pytorch-crf (https://github.com/kmkurn/pytorch-crf) with some modifications. | |||
| import os | |||
| from typing import Any, Dict, List, Optional | |||
| import torch | |||
| import torch.nn as nn | |||
| from transformers import AutoConfig, AutoModel | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.outputs import TokenClassifierWithPredictionsOutput | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| __all__ = ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'] | |||
| class SequenceLabelingForWordSegmentation(TorchModel): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model = self.init_model(model_dir, *args, **kwargs) | |||
| model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) | |||
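| # load the checkpoint on CPU so initialization also works on GPU-less hosts | |||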
| self.model.load_state_dict( | |||
| torch.load(model_ckpt, map_location=torch.device('cpu'))) | |||
| def init_model(self, model_dir, *args, **kwargs): | |||
| raise NotImplementedError | |||
| def train(self): | |||
| return self.model.train() | |||
| def eval(self): | |||
| return self.model.eval() | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| input_tensor = { | |||
| 'input_ids': input['input_ids'], | |||
| 'attention_mask': input['attention_mask'], | |||
| 'label_mask': input['label_mask'], | |||
| } | |||
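| # carry offset_mapping through so postprocess can map label positions | |||
| # back to character spans in the raw text | |||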
| output = { | |||
| 'offset_mapping': input['offset_mapping'], | |||
| **input_tensor, | |||
| **self.model(input_tensor) | |||
| } | |||
| return output | |||
| def postprocess(self, input: Dict[str, Any], **kwargs): | |||
| predicts = self.model.decode(input) | |||
| offset_len = len(input['offset_mapping']) | |||
| predictions = torch.narrow( | |||
| predicts, 1, 0, | |||
| offset_len)  # unlike index_select, narrow only moves the view offset, no copy or resize | |||
| return TokenClassifierWithPredictionsOutput( | |||
| loss=None, | |||
| logits=None, | |||
| hidden_states=None, | |||
| attentions=None, | |||
| offset_mapping=input['offset_mapping'], | |||
| predictions=predictions, | |||
| ) | |||
| @MODELS.register_module(Tasks.word_segmentation, module_name=Models.tcrf_wseg) | |||
| class TransformerCRFForWordSegmentation(SequenceLabelingForWordSegmentation): | |||
| """This model wraps the TransformerCRF model to register into model sets. | |||
| """ | |||
| def init_model(self, model_dir, *args, **kwargs): | |||
| self.config = AutoConfig.from_pretrained(model_dir) | |||
| num_labels = self.config.num_labels | |||
| model = TransformerCRF(model_dir, num_labels) | |||
| return model | |||
| @MODELS.register_module(Tasks.word_segmentation, module_name=Models.lcrf_wseg) | |||
| class LSTMCRFForWordSegmentation(SequenceLabelingForWordSegmentation): | |||
| """This model wraps the LSTMCRF model to register into model sets. | |||
| """ | |||
| def init_model(self, model_dir, *args, **kwargs): | |||
| self.config = AutoConfig.from_pretrained(model_dir) | |||
| vocab_size = self.config.vocab_size | |||
| embed_width = self.config.embed_width | |||
| num_labels = self.config.num_labels | |||
| lstm_hidden_size = self.config.lstm_hidden_size | |||
| model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size) | |||
| return model | |||
| class TransformerCRF(nn.Module): | |||
| """A transformer based model to NER tasks. | |||
| This model will use transformers' backbones as its backbone. | |||
| """ | |||
| def __init__(self, model_dir, num_labels, **kwargs): | |||
| super(TransformerCRF, self).__init__() | |||
| self.encoder = AutoModel.from_pretrained(model_dir) | |||
| self.linear = nn.Linear(self.encoder.config.hidden_size, num_labels) | |||
| self.crf = CRF(num_labels, batch_first=True) | |||
| def forward(self, inputs): | |||
| embed = self.encoder( | |||
| inputs['input_ids'], attention_mask=inputs['attention_mask'])[0] | |||
| logits = self.linear(embed) | |||
| if 'label_mask' in inputs: | |||
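| # compact the logits: gather the positions where label_mask == 1 | |||
| # (typically the first sub-token of each word) to the front of each | |||
| # row, so the CRF sees exactly one emission per label slot | |||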
| mask = inputs['label_mask'] | |||
| masked_lengths = mask.sum(-1).long() | |||
| masked_logits = torch.zeros_like(logits) | |||
| for i in range(len(mask)): | |||
| masked_logits[ | |||
| i, :masked_lengths[i], :] = logits[i].masked_select( | |||
| mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) | |||
| logits = masked_logits | |||
| outputs = {'logits': logits} | |||
| return outputs | |||
| def decode(self, inputs): | |||
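| # label_mask was compacted in forward(), so its row sums are the true | |||
| # sequence lengths; rebuild a contiguous mask from them | |||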
| seq_lens = inputs['label_mask'].sum(-1).long() | |||
| mask = torch.arange( | |||
| inputs['label_mask'].shape[1], | |||
| device=seq_lens.device)[None, :] < seq_lens[:, None] | |||
| predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) | |||
| return predicts | |||
| class LSTMCRF(nn.Module): | |||
| """ | |||
| A standard BiLSTM-CRF model for fast prediction. | |||
| """ | |||
| def __init__(self, | |||
| vocab_size, | |||
| embed_width, | |||
| num_labels, | |||
| lstm_hidden_size=100, | |||
| **kwargs): | |||
| super(LSTMCRF, self).__init__() | |||
| self.embedding = Embedding(vocab_size, embed_width) | |||
| self.lstm = nn.LSTM( | |||
| embed_width, | |||
| lstm_hidden_size, | |||
| num_layers=1, | |||
| bidirectional=True, | |||
| batch_first=True) | |||
| self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels) | |||
| self.crf = CRF(num_labels, batch_first=True) | |||
| def forward(self, inputs): | |||
| embedding = self.embedding(inputs['input_ids']) | |||
| lstm_output, _ = self.lstm(embedding) | |||
| logits = self.ffn(lstm_output) | |||
| if 'label_mask' in inputs: | |||
| mask = inputs['label_mask'] | |||
| masked_lengths = mask.sum(-1).long() | |||
| masked_logits = torch.zeros_like(logits) | |||
| for i in range(len(mask)): | |||
| masked_logits[ | |||
| i, :masked_lengths[i], :] = logits[i].masked_select( | |||
| mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) | |||
| logits = masked_logits | |||
| outputs = {'logits': logits} | |||
| return outputs | |||
| def decode(self, inputs): | |||
| seq_lens = inputs['label_mask'].sum(-1).long() | |||
| mask = torch.arange( | |||
| inputs['label_mask'].shape[1], | |||
| device=seq_lens.device)[None, :] < seq_lens[:, None] | |||
| predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) | |||
| # return the tensor directly, matching TransformerCRF.decode and the | |||
| # torch.narrow call in postprocess above | |||
| return predicts | |||
| class CRF(nn.Module): | |||
| """Conditional random field. | |||
| This module implements a conditional random field [LMP01]_. The forward computation | |||
| of this class computes the log likelihood of the given sequence of tags and | |||
| emission score tensor. This class also has `~CRF.decode` method which finds | |||
| the best tag sequence given an emission score tensor using `Viterbi algorithm`_. | |||
| Args: | |||
| num_tags: Number of tags. | |||
| batch_first: Whether the first dimension corresponds to the size of a minibatch. | |||
| Attributes: | |||
| start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size | |||
| ``(num_tags,)``. | |||
| end_transitions (`~torch.nn.Parameter`): End transition score tensor of size | |||
| ``(num_tags,)``. | |||
| transitions (`~torch.nn.Parameter`): Transition score tensor of size | |||
| ``(num_tags, num_tags)``. | |||
| .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). | |||
| "Conditional random fields: Probabilistic models for segmenting and | |||
| labeling sequence data". *Proc. 18th International Conf. on Machine | |||
| Learning*. Morgan Kaufmann. pp. 282–289. | |||
| .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm | |||
| """ | |||
| def __init__(self, num_tags: int, batch_first: bool = False) -> None: | |||
| if num_tags <= 0: | |||
| raise ValueError(f'invalid number of tags: {num_tags}') | |||
| super().__init__() | |||
| self.num_tags = num_tags | |||
| self.batch_first = batch_first | |||
| self.start_transitions = nn.Parameter(torch.empty(num_tags)) | |||
| self.end_transitions = nn.Parameter(torch.empty(num_tags)) | |||
| self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) | |||
| self.reset_parameters() | |||
| def reset_parameters(self) -> None: | |||
| """Initialize the transition parameters. | |||
| The parameters will be initialized randomly from a uniform distribution | |||
| between -0.1 and 0.1. | |||
| """ | |||
| nn.init.uniform_(self.start_transitions, -0.1, 0.1) | |||
| nn.init.uniform_(self.end_transitions, -0.1, 0.1) | |||
| nn.init.uniform_(self.transitions, -0.1, 0.1) | |||
| def __repr__(self) -> str: | |||
| return f'{self.__class__.__name__}(num_tags={self.num_tags})' | |||
| def forward(self, | |||
| emissions: torch.Tensor, | |||
| tags: torch.LongTensor, | |||
| mask: Optional[torch.ByteTensor] = None, | |||
| reduction: str = 'mean') -> torch.Tensor: | |||
| """Compute the conditional log likelihood of a sequence of tags given emission scores. | |||
| Args: | |||
| emissions (`~torch.Tensor`): Emission score tensor of size | |||
| ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, | |||
| ``(batch_size, seq_length, num_tags)`` otherwise. | |||
| tags (`~torch.LongTensor`): Sequence of tags tensor of size | |||
| ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, | |||
| ``(batch_size, seq_length)`` otherwise. | |||
| mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` | |||
| if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. | |||
| reduction: Specifies the reduction to apply to the output: | |||
| ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. | |||
| ``sum``: the output will be summed over batches. ``mean``: the output will be | |||
| averaged over batches. ``token_mean``: the output will be averaged over tokens. | |||
| Returns: | |||
| `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if | |||
| reduction is ``none``, ``()`` otherwise. | |||
| """ | |||
| if reduction not in ('none', 'sum', 'mean', 'token_mean'): | |||
| raise ValueError(f'invalid reduction: {reduction}') | |||
| if mask is None: | |||
| mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device) | |||
| if mask.dtype != torch.uint8: | |||
| mask = mask.byte() | |||
| self._validate(emissions, tags=tags, mask=mask) | |||
| if self.batch_first: | |||
| emissions = emissions.transpose(0, 1) | |||
| tags = tags.transpose(0, 1) | |||
| mask = mask.transpose(0, 1) | |||
| # shape: (batch_size,) | |||
| numerator = self._compute_score(emissions, tags, mask) | |||
| # shape: (batch_size,) | |||
| denominator = self._compute_normalizer(emissions, mask) | |||
| # shape: (batch_size,) | |||
| llh = numerator - denominator | |||
| if reduction == 'none': | |||
| return llh | |||
| if reduction == 'sum': | |||
| return llh.sum() | |||
| if reduction == 'mean': | |||
| return llh.mean() | |||
| return llh.sum() / mask.float().sum() | |||
| def decode(self, | |||
| emissions: torch.Tensor, | |||
| mask: Optional[torch.ByteTensor] = None, | |||
| nbest: Optional[int] = None, | |||
| pad_tag: Optional[int] = None) -> List[List[List[int]]]: | |||
| """Find the most likely tag sequence using Viterbi algorithm. | |||
| Args: | |||
| emissions (`~torch.Tensor`): Emission score tensor of size | |||
| ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, | |||
| ``(batch_size, seq_length, num_tags)`` otherwise. | |||
| mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` | |||
| if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. | |||
| nbest (`int`): Number of most probable paths for each sequence | |||
| pad_tag (`int`): Tag at padded positions. Often input varies in length and | |||
| the length will be padded to the maximum length in the batch. Tags at | |||
| the padded positions will be assigned with a padding tag, i.e. `pad_tag` | |||
| Returns: | |||
| A PyTorch tensor of the best tag sequence for each batch of shape | |||
| (nbest, batch_size, seq_length) | |||
| """ | |||
| if nbest is None: | |||
| nbest = 1 | |||
| if mask is None: | |||
| mask = torch.ones( | |||
| emissions.shape[:2], | |||
| dtype=torch.uint8, | |||
| device=emissions.device) | |||
| if mask.dtype != torch.uint8: | |||
| mask = mask.byte() | |||
| self._validate(emissions, mask=mask) | |||
| if self.batch_first: | |||
| emissions = emissions.transpose(0, 1) | |||
| mask = mask.transpose(0, 1) | |||
| if nbest == 1: | |||
| return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0) | |||
| return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) | |||
| def _validate(self, | |||
| emissions: torch.Tensor, | |||
| tags: Optional[torch.LongTensor] = None, | |||
| mask: Optional[torch.ByteTensor] = None) -> None: | |||
| if emissions.dim() != 3: | |||
| raise ValueError( | |||
| f'emissions must have dimension of 3, got {emissions.dim()}') | |||
| if emissions.size(2) != self.num_tags: | |||
| raise ValueError( | |||
| f'expected last dimension of emissions is {self.num_tags}, ' | |||
| f'got {emissions.size(2)}') | |||
| if tags is not None: | |||
| if emissions.shape[:2] != tags.shape: | |||
| raise ValueError( | |||
| 'the first two dimensions of emissions and tags must match, ' | |||
| f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}' | |||
| ) | |||
| if mask is not None: | |||
| if emissions.shape[:2] != mask.shape: | |||
| raise ValueError( | |||
| 'the first two dimensions of emissions and mask must match, ' | |||
| f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}' | |||
| ) | |||
| no_empty_seq = not self.batch_first and mask[0].all() | |||
| no_empty_seq_bf = self.batch_first and mask[:, 0].all() | |||
| if not no_empty_seq and not no_empty_seq_bf: | |||
| raise ValueError('mask of the first timestep must all be on') | |||
| def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor, | |||
| mask: torch.ByteTensor) -> torch.Tensor: | |||
| # emissions: (seq_length, batch_size, num_tags) | |||
| # tags: (seq_length, batch_size) | |||
| # mask: (seq_length, batch_size) | |||
| seq_length, batch_size = tags.shape | |||
| mask = mask.float() | |||
| # Start transition score and first emission | |||
| # shape: (batch_size,) | |||
| score = self.start_transitions[tags[0]] | |||
| score += emissions[0, torch.arange(batch_size), tags[0]] | |||
| for i in range(1, seq_length): | |||
| # Transition score to next tag, only added if next timestep is valid (mask == 1) | |||
| # shape: (batch_size,) | |||
| score += self.transitions[tags[i - 1], tags[i]] * mask[i] | |||
| # Emission score for next tag, only added if next timestep is valid (mask == 1) | |||
| # shape: (batch_size,) | |||
| score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] | |||
| # End transition score | |||
| # shape: (batch_size,) | |||
| seq_ends = mask.long().sum(dim=0) - 1 | |||
| # shape: (batch_size,) | |||
| last_tags = tags[seq_ends, torch.arange(batch_size)] | |||
| # shape: (batch_size,) | |||
| score += self.end_transitions[last_tags] | |||
| return score | |||
| def _compute_normalizer(self, emissions: torch.Tensor, | |||
| mask: torch.ByteTensor) -> torch.Tensor: | |||
| # emissions: (seq_length, batch_size, num_tags) | |||
| # mask: (seq_length, batch_size) | |||
| seq_length = emissions.size(0) | |||
| # Start transition score and first emission; score has size of | |||
| # (batch_size, num_tags) where for each batch, the j-th column stores | |||
| # the score that the first timestep has tag j | |||
| # shape: (batch_size, num_tags) | |||
| score = self.start_transitions + emissions[0] | |||
| for i in range(1, seq_length): | |||
| # Broadcast score for every possible next tag | |||
| # shape: (batch_size, num_tags, 1) | |||
| broadcast_score = score.unsqueeze(2) | |||
| # Broadcast emission score for every possible current tag | |||
| # shape: (batch_size, 1, num_tags) | |||
| broadcast_emissions = emissions[i].unsqueeze(1) | |||
| # Compute the score tensor of size (batch_size, num_tags, num_tags) where | |||
| # for each sample, entry at row i and column j stores the sum of scores of all | |||
| # possible tag sequences so far that end with transitioning from tag i to tag j | |||
| # and emitting | |||
| # shape: (batch_size, num_tags, num_tags) | |||
| next_score = broadcast_score + self.transitions + broadcast_emissions | |||
| # Sum over all possible current tags, but we're in score space, so a sum | |||
| # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of | |||
| # all possible tag sequences so far, that end in tag i | |||
| # shape: (batch_size, num_tags) | |||
| next_score = torch.logsumexp(next_score, dim=1) | |||
| # Set score to the next score if this timestep is valid (mask == 1) | |||
| # shape: (batch_size, num_tags) | |||
| score = torch.where(mask[i].unsqueeze(1), next_score, score) | |||
| # End transition score | |||
| # shape: (batch_size, num_tags) | |||
| score += self.end_transitions | |||
| # Sum (log-sum-exp) over all possible tags | |||
| # shape: (batch_size,) | |||
| return torch.logsumexp(score, dim=1) | |||
| def _viterbi_decode(self, | |||
| emissions: torch.FloatTensor, | |||
| mask: torch.ByteTensor, | |||
| pad_tag: Optional[int] = None) -> List[List[int]]: | |||
| # emissions: (seq_length, batch_size, num_tags) | |||
| # mask: (seq_length, batch_size) | |||
| # return: (batch_size, seq_length) | |||
| if pad_tag is None: | |||
| pad_tag = 0 | |||
| device = emissions.device | |||
| seq_length, batch_size = mask.shape | |||
| # Start transition and first emission | |||
| # shape: (batch_size, num_tags) | |||
| score = self.start_transitions + emissions[0] | |||
| history_idx = torch.zeros((seq_length, batch_size, self.num_tags), | |||
| dtype=torch.long, | |||
| device=device) | |||
| oor_idx = torch.zeros((batch_size, self.num_tags), | |||
| dtype=torch.long, | |||
| device=device) | |||
| oor_tag = torch.full((seq_length, batch_size), | |||
| pad_tag, | |||
| dtype=torch.long, | |||
| device=device) | |||
| # - score is a tensor of size (batch_size, num_tags) where for every batch, | |||
| # value at column j stores the score of the best tag sequence so far that ends | |||
| # with tag j | |||
| # - history_idx saves where the best tags candidate transitioned from; this is used | |||
| # when we trace back the best tag sequence | |||
| # - oor_idx saves the best tags candidate transitioned from at the positions | |||
| # where mask is 0, i.e. out of range (oor) | |||
| # Viterbi algorithm recursive case: we compute the score of the best tag sequence | |||
| # for every possible next tag | |||
| for i in range(1, seq_length): | |||
| # Broadcast viterbi score for every possible next tag | |||
| # shape: (batch_size, num_tags, 1) | |||
| broadcast_score = score.unsqueeze(2) | |||
| # Broadcast emission score for every possible current tag | |||
| # shape: (batch_size, 1, num_tags) | |||
| broadcast_emission = emissions[i].unsqueeze(1) | |||
| # Compute the score tensor of size (batch_size, num_tags, num_tags) where | |||
| # for each sample, entry at row i and column j stores the score of the best | |||
| # tag sequence so far that ends with transitioning from tag i to tag j and emitting | |||
| # shape: (batch_size, num_tags, num_tags) | |||
| next_score = broadcast_score + self.transitions + broadcast_emission | |||
| # Find the maximum score over all possible current tag | |||
| # shape: (batch_size, num_tags) | |||
| next_score, indices = next_score.max(dim=1) | |||
| # Set score to the next score if this timestep is valid (mask == 1) | |||
| # and save the index that produces the next score | |||
| # shape: (batch_size, num_tags) | |||
| score = torch.where(mask[i].unsqueeze(-1), next_score, score) | |||
| indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx) | |||
| history_idx[i - 1] = indices | |||
| # End transition score | |||
| # shape: (batch_size, num_tags) | |||
| end_score = score + self.end_transitions | |||
| _, end_tag = end_score.max(dim=1) | |||
| # shape: (batch_size,) | |||
| seq_ends = mask.long().sum(dim=0) - 1 | |||
| # insert the best tag at each sequence end (last position with mask == 1) | |||
| history_idx = history_idx.transpose(1, 0).contiguous() | |||
| history_idx.scatter_( | |||
| 1, | |||
| seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags), | |||
| end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags)) | |||
| history_idx = history_idx.transpose(1, 0).contiguous() | |||
| # The most probable path for each sequence | |||
| best_tags_arr = torch.zeros((seq_length, batch_size), | |||
| dtype=torch.long, | |||
| device=device) | |||
| best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device) | |||
| for idx in range(seq_length - 1, -1, -1): | |||
| best_tags = torch.gather(history_idx[idx], 1, best_tags) | |||
| best_tags_arr[idx] = best_tags.data.view(batch_size) | |||
| return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1) | |||
| def _viterbi_decode_nbest( | |||
| self, | |||
| emissions: torch.FloatTensor, | |||
| mask: torch.ByteTensor, | |||
| nbest: int, | |||
| pad_tag: Optional[int] = None) -> List[List[List[int]]]: | |||
| # emissions: (seq_length, batch_size, num_tags) | |||
| # mask: (seq_length, batch_size) | |||
| # return: (nbest, batch_size, seq_length) | |||
| if pad_tag is None: | |||
| pad_tag = 0 | |||
| device = emissions.device | |||
| seq_length, batch_size = mask.shape | |||
| # Start transition and first emission | |||
| # shape: (batch_size, num_tags) | |||
| score = self.start_transitions + emissions[0] | |||
| history_idx = torch.zeros( | |||
| (seq_length, batch_size, self.num_tags, nbest), | |||
| dtype=torch.long, | |||
| device=device) | |||
| oor_idx = torch.zeros((batch_size, self.num_tags, nbest), | |||
| dtype=torch.long, | |||
| device=device) | |||
| oor_tag = torch.full((seq_length, batch_size, nbest), | |||
| pad_tag, | |||
| dtype=torch.long, | |||
| device=device) | |||
| # - score is a tensor of size (batch_size, num_tags) where for every batch, | |||
| # value at column j stores the score of the best tag sequence so far that ends | |||
| # with tag j | |||
| # - history_idx saves where the best tags candidate transitioned from; this is used | |||
| # when we trace back the best tag sequence | |||
| # - oor_idx saves the best tags candidate transitioned from at the positions | |||
| # where mask is 0, i.e. out of range (oor) | |||
| # Viterbi algorithm recursive case: we compute the score of the best tag sequence | |||
| # for every possible next tag | |||
| for i in range(1, seq_length): | |||
| if i == 1: | |||
| broadcast_score = score.unsqueeze(-1) | |||
| broadcast_emission = emissions[i].unsqueeze(1) | |||
| # shape: (batch_size, num_tags, num_tags) | |||
| next_score = broadcast_score + self.transitions + broadcast_emission | |||
| else: | |||
| broadcast_score = score.unsqueeze(-1) | |||
| broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2) | |||
| # shape: (batch_size, num_tags, nbest, num_tags) | |||
| next_score = broadcast_score + self.transitions.unsqueeze( | |||
| 1) + broadcast_emission | |||
| # Find the top `nbest` maximum score over all possible current tag | |||
| # shape: (batch_size, nbest, num_tags) | |||
| next_score, indices = next_score.view(batch_size, -1, | |||
| self.num_tags).topk( | |||
| nbest, dim=1) | |||
| if i == 1: | |||
| score = score.unsqueeze(-1).expand(-1, -1, nbest) | |||
| indices = indices * nbest | |||
| # convert to shape: (batch_size, num_tags, nbest) | |||
| next_score = next_score.transpose(2, 1) | |||
| indices = indices.transpose(2, 1) | |||
| # Set score to the next score if this timestep is valid (mask == 1) | |||
| # and save the index that produces the next score | |||
| # shape: (batch_size, num_tags, nbest) | |||
| score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), | |||
| next_score, score) | |||
| indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, | |||
| oor_idx) | |||
| history_idx[i - 1] = indices | |||
| # End transition score shape: (batch_size, num_tags, nbest) | |||
| end_score = score + self.end_transitions.unsqueeze(-1) | |||
| _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) | |||
| # shape: (batch_size,) | |||
| seq_ends = mask.long().sum(dim=0) - 1 | |||
| # insert the best tag at each sequence end (last position with mask == 1) | |||
| history_idx = history_idx.transpose(1, 0).contiguous() | |||
| history_idx.scatter_( | |||
| 1, | |||
| seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), | |||
| end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest)) | |||
| history_idx = history_idx.transpose(1, 0).contiguous() | |||
| # The most probable path for each sequence | |||
| best_tags_arr = torch.zeros((seq_length, batch_size, nbest), | |||
| dtype=torch.long, | |||
| device=device) | |||
| best_tags = torch.arange(nbest, dtype=torch.long, device=device) \ | |||
| .view(1, -1).expand(batch_size, -1) | |||
| for idx in range(seq_length - 1, -1, -1): | |||
| best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, | |||
| best_tags) | |||
| best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest | |||
| return torch.where(mask.unsqueeze(-1), best_tags_arr, | |||
| oor_tag).permute(2, 1, 0) | |||
| class Embedding(nn.Module): | |||
| def __init__(self, vocab_size, embed_width): | |||
| super(Embedding, self).__init__() | |||
| self.embedding = nn.Embedding(vocab_size, embed_width) | |||
| def forward(self, input_ids): | |||
| return self.embedding(input_ids) | |||
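As a quick sanity check on the CRF module added above, a minimal standalone sketch (random emissions, shapes only; the import path follows the task_models package shown earlier and is otherwise an assumption):

    import torch
    # assumed path, per the task_models __init__ hunk above
    from modelscope.models.nlp.task_models.nncrf_for_word_segmentation import CRF

    crf = CRF(num_tags=4, batch_first=True)
    emissions = torch.randn(2, 6, 4)            # (batch, seq_len, num_tags)
    tags = torch.randint(0, 4, (2, 6))          # gold tag ids
    mask = torch.ones(2, 6, dtype=torch.uint8)  # first timestep must be unmasked
    nll = -crf(emissions, tags, mask=mask)      # forward() returns the log likelihood
    best = crf.decode(emissions, mask=mask)     # (1, batch, seq_len) when nbest is None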
| @@ -16,7 +16,9 @@ if TYPE_CHECKING: | |||
| from .feature_extraction_pipeline import FeatureExtractionPipeline | |||
| from .fill_mask_pipeline import FillMaskPipeline | |||
| from .information_extraction_pipeline import InformationExtractionPipeline | |||
| from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline, \ | |||
| NamedEntityRecognitionThaiPipeline, \ | |||
| NamedEntityRecognitionVietPipeline | |||
| from .text_ranking_pipeline import TextRankingPipeline | |||
| from .sentence_embedding_pipeline import SentenceEmbeddingPipeline | |||
| from .text_classification_pipeline import TextClassificationPipeline | |||
| @@ -29,6 +31,8 @@ if TYPE_CHECKING: | |||
| from .translation_pipeline import TranslationPipeline | |||
| from .word_segmentation_pipeline import WordSegmentationPipeline | |||
| from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | |||
| from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ | |||
| WordSegmentationThaiPipeline | |||
| else: | |||
| _import_structure = { | |||
| @@ -46,8 +50,11 @@ else: | |||
| 'feature_extraction_pipeline': ['FeatureExtractionPipeline'], | |||
| 'fill_mask_pipeline': ['FillMaskPipeline'], | |||
| 'information_extraction_pipeline': ['InformationExtractionPipeline'], | |||
| 'named_entity_recognition_pipeline': [ | |||
| 'NamedEntityRecognitionPipeline', | |||
| 'NamedEntityRecognitionThaiPipeline', | |||
| 'NamedEntityRecognitionVietPipeline' | |||
| ], | |||
| 'text_ranking_pipeline': ['TextRankingPipeline'], | |||
| 'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'], | |||
| 'summarization_pipeline': ['SummarizationPipeline'], | |||
| @@ -64,6 +71,10 @@ else: | |||
| 'word_segmentation_pipeline': ['WordSegmentationPipeline'], | |||
| 'zero_shot_classification_pipeline': | |||
| ['ZeroShotClassificationPipeline'], | |||
| 'multilingual_word_segmentation_pipeline': [ | |||
| 'MultilingualWordSegmentationPipeline', | |||
| 'WordSegmentationThaiPipeline' | |||
| ], | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,125 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import (Preprocessor, | |||
| TokenClassificationPreprocessor, | |||
| WordSegmentationPreprocessorThai) | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = [ | |||
| 'MultilingualWordSegmentationPipeline', 'WordSegmentationThaiPipeline' | |||
| ] | |||
| @PIPELINES.register_module( | |||
| Tasks.word_segmentation, | |||
| module_name=Pipelines.multilingual_word_segmentation) | |||
| class MultilingualWordSegmentationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| **kwargs): | |||
| """Use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction | |||
| Args: | |||
| model (str or Model): Supply either a local model dir which supported word segmentation task, or a | |||
| model id from the model hub, or a torch model instance. | |||
| preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||
| the model if supplied. | |||
| sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value. | |||
| To view other examples plese check the tests/pipelines/test_multilingual_word_segmentation.py. | |||
| """ | |||
| model = model if isinstance(model, | |||
| Model) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = TokenClassificationPreprocessor( | |||
| model.model_dir, | |||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.tokenizer = preprocessor.tokenizer | |||
| self.config = model.config | |||
| assert len(self.config.id2label) > 0 | |||
| self.id2label = self.config.id2label | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| text = inputs.pop(OutputKeys.TEXT) | |||
| with torch.no_grad(): | |||
| return { | |||
| **super().forward(inputs, **forward_params), OutputKeys.TEXT: | |||
| text | |||
| } | |||
| def postprocess(self, inputs: Dict[str, Any], | |||
| **postprocess_params) -> Dict[str, str]: | |||
| text = inputs['text'] | |||
| offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] | |||
| labels = [ | |||
| self.id2label[x] | |||
| for x in inputs['predictions'].squeeze(0).cpu().numpy() | |||
| ] | |||
| entities = [] | |||
| entity = {} | |||
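| # merge BIES labels into spans: 'B'/'S' open a span at the token's | |||
| # start offset, 'I'/'E'/'S' extend its end, 'E'/'S' close and emit it | |||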
| for label, offsets in zip(labels, offset_mapping): | |||
| if label[0] in 'BS': | |||
| if entity: | |||
| entity['span'] = text[entity['start']:entity['end']] | |||
| entities.append(entity) | |||
| entity = { | |||
| 'type': label[2:], | |||
| 'start': offsets[0], | |||
| 'end': offsets[1] | |||
| } | |||
| if label[0] in 'IES': | |||
| if entity: | |||
| entity['end'] = offsets[1] | |||
| if label[0] in 'ES': | |||
| if entity: | |||
| entity['span'] = text[entity['start']:entity['end']] | |||
| entities.append(entity) | |||
| entity = {} | |||
| if entity: | |||
| entity['span'] = text[entity['start']:entity['end']] | |||
| entities.append(entity) | |||
| word_segments = [entity['span'] for entity in entities] | |||
| outputs = {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} | |||
| return outputs | |||
| @PIPELINES.register_module( | |||
| Tasks.word_segmentation, module_name=Pipelines.word_segmentation_thai) | |||
| class WordSegmentationThaiPipeline(MultilingualWordSegmentationPipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| **kwargs): | |||
| model = model if isinstance(model, | |||
| Model) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = WordSegmentationPreprocessorThai( | |||
| model.model_dir, | |||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| def postprocess(self, inputs: Dict[str, Any], | |||
| **postprocess_params) -> Dict[str, str]: | |||
| outputs = super().postprocess(inputs, **postprocess_params) | |||
| word_segments = outputs[OutputKeys.OUTPUT] | |||
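| # strip the spaces that WordSegmentationPreprocessorThai inserted | |||
| # between grapheme clusters | |||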
| word_segments = [seg.replace(' ', '') for seg in word_segments] | |||
| return {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} | |||
| @@ -9,13 +9,17 @@ from modelscope.models import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import (NERPreprocessorThai, NERPreprocessorViet, | |||
| Preprocessor, | |||
| TokenClassificationPreprocessor) | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
| torch_nested_numpify) | |||
| __all__ = [ | |||
| 'NamedEntityRecognitionPipeline', 'NamedEntityRecognitionThaiPipeline', | |||
| 'NamedEntityRecognitionVietPipeline' | |||
| ] | |||
| @PIPELINES.register_module( | |||
| @@ -126,3 +130,39 @@ class NamedEntityRecognitionPipeline(Pipeline): | |||
| else: | |||
| outputs = {OutputKeys.OUTPUT: chunks} | |||
| return outputs | |||
| @PIPELINES.register_module( | |||
| Tasks.named_entity_recognition, | |||
| module_name=Pipelines.named_entity_recognition_thai) | |||
| class NamedEntityRecognitionThaiPipeline(NamedEntityRecognitionPipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| **kwargs): | |||
| model = model if isinstance(model, | |||
| Model) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = NERPreprocessorThai( | |||
| model.model_dir, | |||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| @PIPELINES.register_module( | |||
| Tasks.named_entity_recognition, | |||
| module_name=Pipelines.named_entity_recognition_viet) | |||
| class NamedEntityRecognitionVietPipeline(NamedEntityRecognitionPipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| **kwargs): | |||
| model = model if isinstance(model, | |||
| Model) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = NERPreprocessorViet( | |||
| model.model_dir, | |||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| @@ -28,7 +28,8 @@ if TYPE_CHECKING: | |||
| SentencePiecePreprocessor, DialogIntentPredictionPreprocessor, | |||
| DialogModelingPreprocessor, DialogStateTrackingPreprocessor, | |||
| ConversationalTextToSqlPreprocessor, | |||
| TableQuestionAnsweringPreprocessor, NERPreprocessorViet, | |||
| NERPreprocessorThai, WordSegmentationPreprocessorThai) | |||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | |||
| else: | |||
| @@ -58,6 +59,8 @@ else: | |||
| 'WordSegmentationBlankSetToLabelPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', | |||
| 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', | |||
| 'NERPreprocessorViet', 'NERPreprocessorThai', | |||
| 'WordSegmentationPreprocessorThai', | |||
| 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', | |||
| 'DialogStateTrackingPreprocessor', | |||
| 'ConversationalTextToSqlPreprocessor', | |||
| @@ -20,6 +20,8 @@ if TYPE_CHECKING: | |||
| from .text2text_generation_preprocessor import Text2TextGenerationPreprocessor | |||
| from .token_classification_preprocessor import TokenClassificationPreprocessor, \ | |||
| WordSegmentationBlankSetToLabelPreprocessor | |||
| from .token_classification_thai_preprocessor import WordSegmentationPreprocessorThai, NERPreprocessorThai | |||
| from .token_classification_viet_preprocessor import NERPreprocessorViet | |||
| from .zero_shot_classification_preprocessor import ZeroShotClassificationPreprocessor | |||
| from .space import (DialogIntentPredictionPreprocessor, | |||
| DialogModelingPreprocessor, | |||
| @@ -60,10 +62,20 @@ else: | |||
| 'text_error_correction': [ | |||
| 'TextErrorCorrectionPreprocessor', | |||
| ], | |||
| 'token_classification_thai_preprocessor': [ | |||
| 'NERPreprocessorThai', | |||
| 'WordSegmentationPreprocessorThai', | |||
| ], | |||
| 'token_classification_viet_preprocessor': [ | |||
| 'NERPreprocessorViet', | |||
| ], | |||
| 'space': [ | |||
| 'DialogIntentPredictionPreprocessor', | |||
| 'DialogModelingPreprocessor', | |||
| 'DialogStateTrackingPreprocessor', | |||
| 'InputFeatures', | |||
| 'MultiWOZBPETextField', | |||
| 'IntentBPETextField', | |||
| ], | |||
| 'space_T_en': ['ConversationalTextToSqlPreprocessor'], | |||
| 'space_T_cn': ['TableQuestionAnsweringPreprocessor'], | |||
| @@ -0,0 +1,44 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Tuple, Union | |||
| import torch | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||
| from modelscope.utils.constant import Fields, ModeKeys | |||
| from modelscope.utils.type_assert import type_assert | |||
| from .token_classification_preprocessor import TokenClassificationPreprocessor | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.thai_ner_tokenizer) | |||
| class NERPreprocessorThai(TokenClassificationPreprocessor): | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| from pythainlp import word_tokenize | |||
| segmented_data = ' '.join([ | |||
| w.strip(' ') for w in word_tokenize(text=data, engine='newmm') | |||
| if w.strip(' ') != '' | |||
| ]) | |||
| output = super().__call__(segmented_data) | |||
| return output | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.thai_wseg_tokenizer) | |||
| class WordSegmentationPreprocessorThai(TokenClassificationPreprocessor): | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| import regex | |||
| data = regex.findall(r'\X', data) | |||
| data = ' '.join([char for char in data]) | |||
| output = super().__call__(data) | |||
| return output | |||
| @@ -0,0 +1,33 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Tuple, Union | |||
| import torch | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||
| from modelscope.utils.constant import Fields, ModeKeys | |||
| from modelscope.utils.type_assert import type_assert | |||
| from .token_classification_preprocessor import TokenClassificationPreprocessor | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.viet_ner_tokenizer) | |||
| class NERPreprocessorViet(TokenClassificationPreprocessor): | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| from pyvi import ViTokenizer | |||
| seg_words = [ | |||
| t.strip(' ') for t in ViTokenizer.tokenize(data).split(' ') | |||
| if t.strip(' ') != '' | |||
| ] | |||
| raw_words = [] | |||
| for w in seg_words: | |||
| raw_words.extend(w.split('_')) | |||
| segmented_data = ' '.join(raw_words) | |||
| output = super().__call__(segmented_data) | |||
| return output | |||
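pyvi joins the syllables of a multi-syllable word with underscores; the preprocessor above splits them back so the sequence fed to the tokenizer matches the raw syllables. A sketch (the tokenized string in the comment is illustrative):

    from pyvi import ViTokenizer

    text = 'Nón vành dễ thương cho bé gái'
    seg = ViTokenizer.tokenize(text)  # e.g. 'Nón vành dễ_thương cho bé_gái'
    raw_words = []
    for w in seg.split(' '):
        if w.strip(' ') != '':
            raw_words.extend(w.strip(' ').split('_'))
    print(' '.join(raw_words))        # space-separated syllables again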
| @@ -4,6 +4,8 @@ megatron_util | |||
| pai-easynlp | |||
| # protobuf versions beyond 3.20.0 are not compatible with TensorFlow 1.x and are therefore discouraged. | |||
| protobuf>=3.19.0,<3.21.0 | |||
| pythainlp | |||
| pyvi | |||
| # rouge-score was just recently updated from 0.0.4 to 0.0.7, | |||
| # which introduced compatibility issues that are being investigated | |||
| rouge_score<=0.0.4 | |||
| @@ -0,0 +1,102 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition, | |||
| TransformerCRFForNamedEntityRecognition) | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import (NamedEntityRecognitionThaiPipeline, | |||
| NamedEntityRecognitionVietPipeline) | |||
| from modelscope.preprocessors import NERPreprocessorThai, NERPreprocessorViet | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
| class MultilingualNamedEntityRecognitionTest(unittest.TestCase, | |||
| DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.task = Tasks.named_entity_recognition | |||
| self.model_id = 'damo/nlp_xlmr_named-entity-recognition_thai-ecommerce-title' | |||
| thai_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_thai-ecommerce-title' | |||
| thai_sentence = 'เครื่องชั่งดิจิตอลแบบตั้งพื้น150kg.' | |||
| viet_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title' | |||
| viet_sentence = 'Nón vành dễ thương cho bé gái' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_tcrf_by_direct_model_download_thai(self): | |||
| cache_path = snapshot_download(self.thai_tcrf_model_id) | |||
| tokenizer = NERPreprocessorThai(cache_path) | |||
| model = TransformerCRFForNamedEntityRecognition( | |||
| cache_path, tokenizer=tokenizer) | |||
| pipeline1 = NamedEntityRecognitionThaiPipeline( | |||
| model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.named_entity_recognition, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(f'thai_sentence: {self.thai_sentence}\n' | |||
| f'pipeline1:{pipeline1(input=self.thai_sentence)}') | |||
| print() | |||
| print(f'pipeline2: {pipeline2(input=self.thai_sentence)}') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_tcrf_with_model_from_modelhub_thai(self): | |||
| model = Model.from_pretrained(self.thai_tcrf_model_id) | |||
| tokenizer = NERPreprocessorThai(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.thai_sentence)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_tcrf_with_model_name_thai(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, model=self.thai_tcrf_model_id) | |||
| print(pipeline_ins(input=self.thai_sentence)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_tcrf_by_direct_model_download_viet(self): | |||
| cache_path = snapshot_download(self.viet_tcrf_model_id) | |||
| tokenizer = NERPreprocessorViet(cache_path) | |||
| model = TransformerCRFForNamedEntityRecognition( | |||
| cache_path, tokenizer=tokenizer) | |||
| pipeline1 = NamedEntityRecognitionVietPipeline( | |||
| model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.named_entity_recognition, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(f'viet_sentence: {self.viet_sentence}\n' | |||
| f'pipeline1:{pipeline1(input=self.viet_sentence)}') | |||
| print() | |||
| print(f'pipeline2: {pipeline2(input=self.viet_sentence)}') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_tcrf_with_model_from_modelhub_viet(self): | |||
| model = Model.from_pretrained(self.viet_tcrf_model_id) | |||
| tokenizer = NERPreprocessorViet(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.viet_sentence)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_tcrf_with_model_name_viet(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, model=self.viet_tcrf_model_id) | |||
| print(pipeline_ins(input=self.viet_sentence)) | |||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
| def test_demo_compatibility(self): | |||
| self.compatibility_check() | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,57 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import TransformerCRFForWordSegmentation | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import WordSegmentationThaiPipeline | |||
| from modelscope.preprocessors import WordSegmentationPreprocessorThai | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.regress_test_utils import MsRegressTool | |||
| from modelscope.utils.test_utils import test_level | |||
| class WordSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.task = Tasks.word_segmentation | |||
| self.model_id = 'damo/nlp_xlmr_word-segmentation_thai' | |||
| sentence = 'รถคันเก่าก็ยังเก็บเอาไว้ยังไม่ได้ขาย' | |||
| regress_tool = MsRegressTool(baseline=False) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = WordSegmentationPreprocessorThai(cache_path) | |||
| model = TransformerCRFForWordSegmentation.from_pretrained(cache_path) | |||
| pipeline1 = WordSegmentationThaiPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.word_segmentation, model=model, preprocessor=tokenizer) | |||
| print(f'sentence: {self.sentence}\n' | |||
| f'pipeline1:{pipeline1(input=self.sentence)}') | |||
| print(f'pipeline2: {pipeline2(input=self.sentence)}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = WordSegmentationPreprocessorThai(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.word_segmentation, model=model, preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.word_segmentation, model=self.model_id) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
| def test_demo_compatibility(self): | |||
| self.compatibility_check() | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||