diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index c5067c39..7944d1ed 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -70,8 +70,10 @@ class Models(object): space_T_en = 'space-T-en' space_T_cn = 'space-T-cn' tcrf = 'transformer-crf' + tcrf_wseg = 'transformer-crf-for-word-segmentation' transformer_softmax = 'transformer-softmax' lcrf = 'lstm-crf' + lcrf_wseg = 'lstm-crf-for-word-segmentation' gcnncrf = 'gcnn-crf' bart = 'bart' gpt3 = 'gpt3' @@ -219,8 +221,12 @@ class Pipelines(object): domain_classification = 'domain-classification' sentence_similarity = 'sentence-similarity' word_segmentation = 'word-segmentation' + multilingual_word_segmentation = 'multilingual-word-segmentation' + word_segmentation_thai = 'word-segmentation-thai' part_of_speech = 'part-of-speech' named_entity_recognition = 'named-entity-recognition' + named_entity_recognition_thai = 'named-entity-recognition-thai' + named_entity_recognition_viet = 'named-entity-recognition-viet' text_generation = 'text-generation' text2text_generation = 'text2text-generation' sentiment_analysis = 'sentiment-analysis' @@ -343,6 +349,8 @@ class Preprocessors(object): text2text_translate_preprocessor = 'text2text-translate-preprocessor' token_cls_tokenizer = 'token-cls-tokenizer' ner_tokenizer = 'ner-tokenizer' + thai_ner_tokenizer = 'thai-ner-tokenizer' + viet_ner_tokenizer = 'viet-ner-tokenizer' nli_tokenizer = 'nli-tokenizer' sen_cls_tokenizer = 'sen-cls-tokenizer' dialog_intent_preprocessor = 'dialog-intent-preprocessor' @@ -355,6 +363,7 @@ class Preprocessors(object): text_ranking = 'text-ranking' sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' + thai_wseg_tokenizer = 'thai-wseg-tokenizer' fill_mask = 'fill-mask' fill_mask_ponet = 'fill-mask-ponet' faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 5ae93caa..d4562f10 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -5,14 +5,26 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .bart import BartForTextErrorCorrection + from .bert import ( + BertForMaskedLM, + BertForTextRanking, + BertForSentenceEmbedding, + BertForSequenceClassification, + BertForTokenClassification, + BertForDocumentSegmentation, + BertModel, + BertConfig, + ) from .csanmt import CsanmtForTranslation - from .heads import SequenceClassificationHead + from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model + from .gpt_neo import GPTNeoModel from .gpt3 import GPT3ForTextGeneration + from .heads import SequenceClassificationHead from .palm_v2 import PalmForTextGeneration - from .space_T_en import StarForTextToSql - from .space_T_cn import TableQuestionAnswering - from .space import SpaceForDialogIntent, SpaceForDialogModeling, SpaceForDST from .ponet import PoNetForMaskedLM, PoNetModel, PoNetConfig + from .space import SpaceForDialogIntent, SpaceForDialogModeling, SpaceForDST + from .space_T_cn import TableQuestionAnswering + from .space_T_en import StarForTextToSql from .structbert import ( SbertForFaqQuestionAnswering, SbertForMaskedLM, @@ -22,19 +34,7 @@ if TYPE_CHECKING: SbertModel, SbertTokenizerFast, ) - from .bert import ( - BertForMaskedLM, - BertForTextRanking, - BertForSentenceEmbedding, - BertForSequenceClassification, - BertForTokenClassification, - BertForDocumentSegmentation, - 
BertModel, - BertConfig, - ) - from .veco import VecoModel, VecoConfig, VecoForTokenClassification, \ - VecoForSequenceClassification, VecoForMaskedLM, VecoTokenizer, VecoTokenizerFast - from .deberta_v2 import DebertaV2ForMaskedLM, DebertaV2Model + from .T5 import T5ForConditionalGeneration from .task_models import ( FeatureExtractionModel, InformationExtractionModel, @@ -45,9 +45,11 @@ if TYPE_CHECKING: TokenClassificationModel, TransformerCRFForNamedEntityRecognition, ) + from .veco import (VecoConfig, VecoForMaskedLM, + VecoForSequenceClassification, + VecoForTokenClassification, VecoModel, VecoTokenizer, + VecoTokenizerFast) - from .T5 import T5ForConditionalGeneration - from .gpt_neo import GPTNeoModel else: _import_structure = { 'backbones': ['SbertModel'], @@ -65,9 +67,13 @@ else: 'SbertModel', ], 'veco': [ - 'VecoModel', 'VecoConfig', 'VecoForTokenClassification', - 'VecoForSequenceClassification', 'VecoForMaskedLM', - 'VecoTokenizer', 'VecoTokenizerFast' + 'VecoConfig', + 'VecoForMaskedLM', + 'VecoForSequenceClassification', + 'VecoForTokenClassification', + 'VecoModel', + 'VecoTokenizer', + 'VecoTokenizerFast', ], 'bert': [ 'BertForMaskedLM', @@ -90,11 +96,13 @@ else: 'FeatureExtractionModel', 'InformationExtractionModel', 'LSTMCRFForNamedEntityRecognition', + 'LSTMCRFForWordSegmentation', 'SequenceClassificationModel', 'SingleBackboneTaskModelBase', 'TaskModelForTextGeneration', 'TokenClassificationModel', 'TransformerCRFForNamedEntityRecognition', + 'TransformerCRFForWordSegmentation', ], 'sentence_embedding': ['SentenceEmbedding'], 'T5': ['T5ForConditionalGeneration'], diff --git a/modelscope/models/nlp/task_models/__init__.py b/modelscope/models/nlp/task_models/__init__.py index e733efe2..b8722a36 100644 --- a/modelscope/models/nlp/task_models/__init__.py +++ b/modelscope/models/nlp/task_models/__init__.py @@ -8,8 +8,13 @@ if TYPE_CHECKING: from .feature_extraction import FeatureExtractionModel from .fill_mask import FillMaskModel from .nncrf_for_named_entity_recognition import ( + LSTMCRFForNamedEntityRecognition, TransformerCRFForNamedEntityRecognition, - LSTMCRFForNamedEntityRecognition) + ) + from .nncrf_for_word_segmentation import ( + LSTMCRFForWordSegmentation, + TransformerCRFForWordSegmentation, + ) from .sequence_classification import SequenceClassificationModel from .task_model import SingleBackboneTaskModelBase from .token_classification import TokenClassificationModel @@ -24,6 +29,8 @@ else: 'TransformerCRFForNamedEntityRecognition', 'LSTMCRFForNamedEntityRecognition' ], + 'nncrf_for_word_segmentation': + ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'], 'sequence_classification': ['SequenceClassificationModel'], 'task_model': ['SingleBackboneTaskModelBase'], 'token_classification': ['TokenClassificationModel'], diff --git a/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py b/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py new file mode 100644 index 00000000..2a3f6cf4 --- /dev/null +++ b/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py @@ -0,0 +1,639 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. +# The CRF implementation borrows mostly from AllenNLP CRF module (https://github.com/allenai/allennlp) +# and pytorch-crf (https://github.com/kmkurn/pytorch-crf) with some modifications. 
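+#
+# Data flow implemented below: input_ids/attention_mask -> backbone encoder ->
+# per-token emission logits -> CRF Viterbi decode -> per-token label ids.
+# Pipeline postprocessing then maps the label ids back onto the original text
+# through offset_mapping.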
+
+import os
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModel
+
+from modelscope.metainfo import Models
+from modelscope.models import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.outputs import TokenClassifierWithPredictionsOutput
+from modelscope.utils.constant import ModelFile, Tasks
+
+__all__ = ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation']
+
+
+class SequenceLabelingForWordSegmentation(TorchModel):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super().__init__(model_dir, *args, **kwargs)
+        self.model = self.init_model(model_dir, *args, **kwargs)
+
+        model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
+        self.model.load_state_dict(
+            torch.load(model_ckpt, map_location=torch.device('cpu')))
+
+    def init_model(self, model_dir, *args, **kwargs):
+        raise NotImplementedError
+
+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        input_tensor = {
+            'input_ids': input['input_ids'],
+            'attention_mask': input['attention_mask'],
+            'label_mask': input['label_mask'],
+        }
+        output = {
+            'offset_mapping': input['offset_mapping'],
+            **input_tensor,
+            **self.model(input_tensor)
+        }
+        return output
+
+    def postprocess(self, input: Dict[str, Any], **kwargs):
+        predicts = self.model.decode(input)
+        offset_len = len(input['offset_mapping'])
+        # torch.narrow returns a view truncated to offset_len; it does not
+        # copy or resize the underlying storage.
+        predictions = torch.narrow(predicts, 1, 0, offset_len)
+        return TokenClassifierWithPredictionsOutput(
+            loss=None,
+            logits=None,
+            hidden_states=None,
+            attentions=None,
+            offset_mapping=input['offset_mapping'],
+            predictions=predictions,
+        )
+
+
+@MODELS.register_module(Tasks.word_segmentation, module_name=Models.tcrf_wseg)
+class TransformerCRFForWordSegmentation(SequenceLabelingForWordSegmentation):
+    """This model wraps the TransformerCRF model to register into model sets.
+    """
+
+    def init_model(self, model_dir, *args, **kwargs):
+        self.config = AutoConfig.from_pretrained(model_dir)
+        num_labels = self.config.num_labels
+
+        model = TransformerCRF(model_dir, num_labels)
+        return model
+
+
+@MODELS.register_module(Tasks.word_segmentation, module_name=Models.lcrf_wseg)
+class LSTMCRFForWordSegmentation(SequenceLabelingForWordSegmentation):
+    """This model wraps the LSTMCRF model to register into model sets.
+    """
+
+    def init_model(self, model_dir, *args, **kwargs):
+        self.config = AutoConfig.from_pretrained(model_dir)
+        vocab_size = self.config.vocab_size
+        embed_width = self.config.embed_width
+        num_labels = self.config.num_labels
+        lstm_hidden_size = self.config.lstm_hidden_size
+
+        model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size)
+        return model
+
+
+class TransformerCRF(nn.Module):
+    """A transformer-based sequence labeling model for word segmentation.
+
+    This model uses a backbone from the transformers library as its encoder.
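+
+    A minimal, illustrative call (tensor names and shapes are assumptions;
+    ``num_labels`` comes from the model config)::
+
+        model = TransformerCRF(model_dir, num_labels)
+        out = model({'input_ids': ids, 'attention_mask': mask,
+                     'label_mask': mask})
+        # out['logits']: (batch_size, seq_length, num_labels)
+        tags = model.decode({'logits': out['logits'], 'label_mask': mask})
+        # tags: (batch_size, seq_length) label ids from Viterbi decoding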
+    """
+
+    def __init__(self, model_dir, num_labels, **kwargs):
+        super(TransformerCRF, self).__init__()
+
+        self.encoder = AutoModel.from_pretrained(model_dir)
+        self.linear = nn.Linear(self.encoder.config.hidden_size, num_labels)
+        self.crf = CRF(num_labels, batch_first=True)
+
+    def forward(self, inputs):
+        embed = self.encoder(
+            inputs['input_ids'], attention_mask=inputs['attention_mask'])[0]
+        logits = self.linear(embed)
+
+        if 'label_mask' in inputs:
+            mask = inputs['label_mask']
+            masked_lengths = mask.sum(-1).long()
+            masked_logits = torch.zeros_like(logits)
+            for i in range(len(mask)):
+                masked_logits[
+                    i, :masked_lengths[i], :] = logits[i].masked_select(
+                        mask[i].unsqueeze(-1)).view(masked_lengths[i], -1)
+            logits = masked_logits
+
+        outputs = {'logits': logits}
+        return outputs
+
+    def decode(self, inputs):
+        seq_lens = inputs['label_mask'].sum(-1).long()
+        mask = torch.arange(
+            inputs['label_mask'].shape[1],
+            device=seq_lens.device)[None, :] < seq_lens[:, None]
+        predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0)
+
+        return predicts
+
+
+class LSTMCRF(nn.Module):
+    """
+    A standard BiLSTM-CRF model for fast prediction.
+    """
+
+    def __init__(self,
+                 vocab_size,
+                 embed_width,
+                 num_labels,
+                 lstm_hidden_size=100,
+                 **kwargs):
+        super(LSTMCRF, self).__init__()
+        self.embedding = Embedding(vocab_size, embed_width)
+        self.lstm = nn.LSTM(
+            embed_width,
+            lstm_hidden_size,
+            num_layers=1,
+            bidirectional=True,
+            batch_first=True)
+        self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels)
+        self.crf = CRF(num_labels, batch_first=True)
+
+    def forward(self, inputs):
+        embedding = self.embedding(inputs['input_ids'])
+        lstm_output, _ = self.lstm(embedding)
+        logits = self.ffn(lstm_output)
+
+        if 'label_mask' in inputs:
+            mask = inputs['label_mask']
+            masked_lengths = mask.sum(-1).long()
+            masked_logits = torch.zeros_like(logits)
+            for i in range(len(mask)):
+                masked_logits[
+                    i, :masked_lengths[i], :] = logits[i].masked_select(
+                        mask[i].unsqueeze(-1)).view(masked_lengths[i], -1)
+            logits = masked_logits
+
+        outputs = {'logits': logits}
+        return outputs
+
+    def decode(self, inputs):
+        seq_lens = inputs['label_mask'].sum(-1).long()
+        mask = torch.arange(
+            inputs['label_mask'].shape[1],
+            device=seq_lens.device)[None, :] < seq_lens[:, None]
+        # Return the (batch_size, seq_length) tensor directly, matching
+        # TransformerCRF.decode, so that postprocess can narrow it.
+        predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0)
+
+        return predicts
+
+
+class CRF(nn.Module):
+    """Conditional random field.
+    This module implements a conditional random field [LMP01]_. The forward computation
+    of this class computes the log likelihood of the given sequence of tags and
+    emission score tensor. This class also has a `~CRF.decode` method which finds
+    the best tag sequence given an emission score tensor using the `Viterbi algorithm`_.
+    Args:
+        num_tags: Number of tags.
+        batch_first: Whether the first dimension corresponds to the size of a minibatch.
+    Attributes:
+        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
+            ``(num_tags,)``.
+        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
+            ``(num_tags,)``.
+        transitions (`~torch.nn.Parameter`): Transition score tensor of size
+            ``(num_tags, num_tags)``.
+    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
+        "Conditional random fields: Probabilistic models for segmenting and
+        labeling sequence data". *Proc. 18th International Conf. on Machine
+        Learning*. Morgan Kaufmann. pp. 282–289.
+    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
+
+    """
+
+    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
+        if num_tags <= 0:
+            raise ValueError(f'invalid number of tags: {num_tags}')
+        super().__init__()
+        self.num_tags = num_tags
+        self.batch_first = batch_first
+        self.start_transitions = nn.Parameter(torch.empty(num_tags))
+        self.end_transitions = nn.Parameter(torch.empty(num_tags))
+        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Initialize the transition parameters.
+        The parameters will be initialized randomly from a uniform distribution
+        between -0.1 and 0.1.
+        """
+        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
+        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
+        nn.init.uniform_(self.transitions, -0.1, 0.1)
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}(num_tags={self.num_tags})'
+
+    def forward(self,
+                emissions: torch.Tensor,
+                tags: torch.LongTensor,
+                mask: Optional[torch.ByteTensor] = None,
+                reduction: str = 'mean') -> torch.Tensor:
+        """Compute the conditional log likelihood of a sequence of tags given emission scores.
+        Args:
+            emissions (`~torch.Tensor`): Emission score tensor of size
+                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
+                ``(batch_size, seq_length, num_tags)`` otherwise.
+            tags (`~torch.LongTensor`): Sequence of tags tensor of size
+                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
+                ``(batch_size, seq_length)`` otherwise.
+            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
+                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
+            reduction: Specifies the reduction to apply to the output:
+                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
+                ``sum``: the output will be summed over batches. ``mean``: the output will be
+                averaged over batches. ``token_mean``: the output will be averaged over tokens.
+        Returns:
+            `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
+            reduction is ``none``, ``()`` otherwise.
+        """
+        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
+            raise ValueError(f'invalid reduction: {reduction}')
+        if mask is None:
+            mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
+        if mask.dtype != torch.uint8:
+            mask = mask.byte()
+        self._validate(emissions, tags=tags, mask=mask)
+
+        if self.batch_first:
+            emissions = emissions.transpose(0, 1)
+            tags = tags.transpose(0, 1)
+            mask = mask.transpose(0, 1)
+
+        # shape: (batch_size,)
+        numerator = self._compute_score(emissions, tags, mask)
+        # shape: (batch_size,)
+        denominator = self._compute_normalizer(emissions, mask)
+        # shape: (batch_size,)
+        llh = numerator - denominator
+
+        if reduction == 'none':
+            return llh
+        if reduction == 'sum':
+            return llh.sum()
+        if reduction == 'mean':
+            return llh.mean()
+        return llh.sum() / mask.float().sum()
+
+    def decode(self,
+               emissions: torch.Tensor,
+               mask: Optional[torch.ByteTensor] = None,
+               nbest: Optional[int] = None,
+               pad_tag: Optional[int] = None) -> List[List[List[int]]]:
+        """Find the most likely tag sequence using Viterbi algorithm.
+        Args:
+            emissions (`~torch.Tensor`): Emission score tensor of size
+                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
+                ``(batch_size, seq_length, num_tags)`` otherwise.
+ mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + nbest (`int`): Number of most probable paths for each sequence + pad_tag (`int`): Tag at padded positions. Often input varies in length and + the length will be padded to the maximum length in the batch. Tags at + the padded positions will be assigned with a padding tag, i.e. `pad_tag` + Returns: + A PyTorch tensor of the best tag sequence for each batch of shape + (nbest, batch_size, seq_length) + """ + if nbest is None: + nbest = 1 + if mask is None: + mask = torch.ones( + emissions.shape[:2], + dtype=torch.uint8, + device=emissions.device) + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, mask=mask) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + mask = mask.transpose(0, 1) + + if nbest == 1: + return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0) + return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) + + def _validate(self, + emissions: torch.Tensor, + tags: Optional[torch.LongTensor] = None, + mask: Optional[torch.ByteTensor] = None) -> None: + if emissions.dim() != 3: + raise ValueError( + f'emissions must have dimension of 3, got {emissions.dim()}') + if emissions.size(2) != self.num_tags: + raise ValueError( + f'expected last dimension of emissions is {self.num_tags}, ' + f'got {emissions.size(2)}') + + if tags is not None: + if emissions.shape[:2] != tags.shape: + raise ValueError( + 'the first two dimensions of emissions and tags must match, ' + f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}' + ) + + if mask is not None: + if emissions.shape[:2] != mask.shape: + raise ValueError( + 'the first two dimensions of emissions and mask must match, ' + f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}' + ) + no_empty_seq = not self.batch_first and mask[0].all() + no_empty_seq_bf = self.batch_first and mask[:, 0].all() + if not no_empty_seq and not no_empty_seq_bf: + raise ValueError('mask of the first timestep must all be on') + + def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor, + mask: torch.ByteTensor) -> torch.Tensor: + # emissions: (seq_length, batch_size, num_tags) + # tags: (seq_length, batch_size) + # mask: (seq_length, batch_size) + seq_length, batch_size = tags.shape + mask = mask.float() + + # Start transition score and first emission + # shape: (batch_size,) + score = self.start_transitions[tags[0]] + score += emissions[0, torch.arange(batch_size), tags[0]] + + for i in range(1, seq_length): + # Transition score to next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += self.transitions[tags[i - 1], tags[i]] * mask[i] + + # Emission score for next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] + + # End transition score + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + # shape: (batch_size,) + last_tags = tags[seq_ends, torch.arange(batch_size)] + # shape: (batch_size,) + score += self.end_transitions[last_tags] + + return score + + def _compute_normalizer(self, emissions: torch.Tensor, + mask: torch.ByteTensor) -> torch.Tensor: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + seq_length = emissions.size(0) + + # Start transition score and first emission; score has size of + # (batch_size, num_tags) where for 
each batch, the j-th column stores + # the score that the first timestep has tag j + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + + for i in range(1, seq_length): + # Broadcast score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emissions = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the sum of scores of all + # possible tag sequences so far that end with transitioning from tag i to tag j + # and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emissions + + # Sum over all possible current tags, but we're in score space, so a sum + # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of + # all possible tag sequences so far, that end in tag i + # shape: (batch_size, num_tags) + next_score = torch.logsumexp(next_score, dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(1), next_score, score) + + # End transition score + # shape: (batch_size, num_tags) + score += self.end_transitions + + # Sum (log-sum-exp) over all possible tags + # shape: (batch_size,) + return torch.logsumexp(score, dim=1) + + def _viterbi_decode(self, + emissions: torch.FloatTensor, + mask: torch.ByteTensor, + pad_tag: Optional[int] = None) -> List[List[int]]: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + # return: (batch_size, seq_length) + if pad_tag is None: + pad_tag = 0 + + device = emissions.device + seq_length, batch_size = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + history_idx = torch.zeros((seq_length, batch_size, self.num_tags), + dtype=torch.long, + device=device) + oor_idx = torch.zeros((batch_size, self.num_tags), + dtype=torch.long, + device=device) + oor_tag = torch.full((seq_length, batch_size), + pad_tag, + dtype=torch.long, + device=device) + + # - score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # - history_idx saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + # - oor_idx saves the best tags candidate transitioned from at the positions + # where mask is 0, i.e. 
out of range (oor) + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + # Broadcast viterbi score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emission = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the score of the best + # tag sequence so far that ends with transitioning from tag i to tag j and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + + # Find the maximum score over all possible current tag + # shape: (batch_size, num_tags) + next_score, indices = next_score.max(dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(-1), next_score, score) + indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx) + history_idx[i - 1] = indices + + # End transition score + # shape: (batch_size, num_tags) + end_score = score + self.end_transitions + _, end_tag = end_score.max(dim=1) + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + + # insert the best tag at each sequence end (last position with mask == 1) + history_idx = history_idx.transpose(1, 0).contiguous() + history_idx.scatter_( + 1, + seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags), + end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags)) + history_idx = history_idx.transpose(1, 0).contiguous() + + # The most probable path for each sequence + best_tags_arr = torch.zeros((seq_length, batch_size), + dtype=torch.long, + device=device) + best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device) + for idx in range(seq_length - 1, -1, -1): + best_tags = torch.gather(history_idx[idx], 1, best_tags) + best_tags_arr[idx] = best_tags.data.view(batch_size) + + return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1) + + def _viterbi_decode_nbest( + self, + emissions: torch.FloatTensor, + mask: torch.ByteTensor, + nbest: int, + pad_tag: Optional[int] = None) -> List[List[List[int]]]: + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + # return: (nbest, batch_size, seq_length) + if pad_tag is None: + pad_tag = 0 + + device = emissions.device + seq_length, batch_size = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + history_idx = torch.zeros( + (seq_length, batch_size, self.num_tags, nbest), + dtype=torch.long, + device=device) + oor_idx = torch.zeros((batch_size, self.num_tags, nbest), + dtype=torch.long, + device=device) + oor_tag = torch.full((seq_length, batch_size, nbest), + pad_tag, + dtype=torch.long, + device=device) + + # + score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # + history_idx saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + # - oor_idx saves the best tags candidate transitioned from at the positions + # where mask is 0, i.e. 
out of range (oor) + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + if i == 1: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[i].unsqueeze(1) + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + else: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2) + # shape: (batch_size, num_tags, nbest, num_tags) + next_score = broadcast_score + self.transitions.unsqueeze( + 1) + broadcast_emission + + # Find the top `nbest` maximum score over all possible current tag + # shape: (batch_size, nbest, num_tags) + next_score, indices = next_score.view(batch_size, -1, + self.num_tags).topk( + nbest, dim=1) + + if i == 1: + score = score.unsqueeze(-1).expand(-1, -1, nbest) + indices = indices * nbest + + # convert to shape: (batch_size, num_tags, nbest) + next_score = next_score.transpose(2, 1) + indices = indices.transpose(2, 1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags, nbest) + score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), + next_score, score) + indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, + oor_idx) + history_idx[i - 1] = indices + + # End transition score shape: (batch_size, num_tags, nbest) + end_score = score + self.end_transitions.unsqueeze(-1) + _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + + # insert the best tag at each sequence end (last position with mask == 1) + history_idx = history_idx.transpose(1, 0).contiguous() + history_idx.scatter_( + 1, + seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), + end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest)) + history_idx = history_idx.transpose(1, 0).contiguous() + + # The most probable path for each sequence + best_tags_arr = torch.zeros((seq_length, batch_size, nbest), + dtype=torch.long, + device=device) + best_tags = torch.arange(nbest, dtype=torch.long, device=device) \ + .view(1, -1).expand(batch_size, -1) + for idx in range(seq_length - 1, -1, -1): + best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, + best_tags) + best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest + + return torch.where(mask.unsqueeze(-1), best_tags_arr, + oor_tag).permute(2, 1, 0) + + +class Embedding(nn.Module): + + def __init__(self, vocab_size, embed_width): + super(Embedding, self).__init__() + + self.embedding = nn.Embedding(vocab_size, embed_width) + + def forward(self, input_ids): + return self.embedding(input_ids) diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 73bd0d8c..7b726308 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -16,7 +16,9 @@ if TYPE_CHECKING: from .feature_extraction_pipeline import FeatureExtractionPipeline from .fill_mask_pipeline import FillMaskPipeline from .information_extraction_pipeline import InformationExtractionPipeline - from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline + from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline, \ + NamedEntityRecognitionThaiPipeline, \ + NamedEntityRecognitionVietPipeline from .text_ranking_pipeline import TextRankingPipeline from 
.sentence_embedding_pipeline import SentenceEmbeddingPipeline
     from .text_classification_pipeline import TextClassificationPipeline
@@ -29,6 +31,8 @@ if TYPE_CHECKING:
     from .translation_pipeline import TranslationPipeline
     from .word_segmentation_pipeline import WordSegmentationPipeline
     from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
+    from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \
+        WordSegmentationThaiPipeline
 
 else:
     _import_structure = {
@@ -46,8 +50,11 @@
         'feature_extraction_pipeline': ['FeatureExtractionPipeline'],
         'fill_mask_pipeline': ['FillMaskPipeline'],
         'information_extraction_pipeline': ['InformationExtractionPipeline'],
-        'named_entity_recognition_pipeline':
-        ['NamedEntityRecognitionPipeline'],
+        'named_entity_recognition_pipeline': [
+            'NamedEntityRecognitionPipeline',
+            'NamedEntityRecognitionThaiPipeline',
+            'NamedEntityRecognitionVietPipeline'
+        ],
         'text_ranking_pipeline': ['TextRankingPipeline'],
         'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'],
         'summarization_pipeline': ['SummarizationPipeline'],
@@ -64,6 +71,10 @@
         'word_segmentation_pipeline': ['WordSegmentationPipeline'],
         'zero_shot_classification_pipeline':
         ['ZeroShotClassificationPipeline'],
+        'multilingual_word_segmentation_pipeline': [
+            'MultilingualWordSegmentationPipeline',
+            'WordSegmentationThaiPipeline'
+        ],
     }
 
 import sys
diff --git a/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py b/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py
new file mode 100644
index 00000000..56c3a041
--- /dev/null
+++ b/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py
@@ -0,0 +1,125 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import (Preprocessor,
+                                      TokenClassificationPreprocessor,
+                                      WordSegmentationPreprocessorThai)
+from modelscope.utils.constant import Tasks
+
+__all__ = [
+    'MultilingualWordSegmentationPipeline', 'WordSegmentationThaiPipeline'
+]
+
+
+@PIPELINES.register_module(
+    Tasks.word_segmentation,
+    module_name=Pipelines.multilingual_word_segmentation)
+class MultilingualWordSegmentationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an NLP word segmentation pipeline for prediction.
+
+        Args:
+            model (str or Model): Supply either a local model dir which supports the word segmentation task,
+            a model id from the model hub, or a torch model instance.
+            preprocessor (Preprocessor): An optional preprocessor instance; please make sure the preprocessor
+            fits the model if supplied.
+            sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value.
+
+        To view other examples, please check tests/pipelines/test_multilingual_word_segmentation.py.
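+
+        Example (illustrative; substitute a word segmentation model id from
+        the model hub)::
+
+            from modelscope.pipelines import pipeline
+            pipe = pipeline(task='word-segmentation', model='<ws-model-id>')
+            pipe(input='some text')
+            # -> {'output': [...segmented tokens...], 'labels': []}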
+ """ + + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = TokenClassificationPreprocessor( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.tokenizer = preprocessor.tokenizer + self.config = model.config + assert len(self.config.id2label) > 0 + self.id2label = self.config.id2label + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + text = inputs.pop(OutputKeys.TEXT) + with torch.no_grad(): + return { + **super().forward(inputs, **forward_params), OutputKeys.TEXT: + text + } + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + text = inputs['text'] + offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] + labels = [ + self.id2label[x] + for x in inputs['predictions'].squeeze(0).cpu().numpy() + ] + entities = [] + entity = {} + for label, offsets in zip(labels, offset_mapping): + if label[0] in 'BS': + if entity: + entity['span'] = text[entity['start']:entity['end']] + entities.append(entity) + entity = { + 'type': label[2:], + 'start': offsets[0], + 'end': offsets[1] + } + if label[0] in 'IES': + if entity: + entity['end'] = offsets[1] + if label[0] in 'ES': + if entity: + entity['span'] = text[entity['start']:entity['end']] + entities.append(entity) + entity = {} + if entity: + entity['span'] = text[entity['start']:entity['end']] + entities.append(entity) + + word_segments = [entity['span'] for entity in entities] + outputs = {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} + + return outputs + + +@PIPELINES.register_module( + Tasks.word_segmentation, module_name=Pipelines.word_segmentation_thai) +class WordSegmentationThaiPipeline(MultilingualWordSegmentationPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = WordSegmentationPreprocessorThai( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + outputs = super().postprocess(inputs, **postprocess_params) + word_segments = outputs[OutputKeys.OUTPUT] + word_segments = [seg.replace(' ', '') for seg in word_segments] + + return {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} diff --git a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py index 8d8c4542..fdcf9e0f 100644 --- a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py +++ b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py @@ -9,13 +9,17 @@ from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import (Preprocessor, +from modelscope.preprocessors import (NERPreprocessorThai, NERPreprocessorViet, + Preprocessor, TokenClassificationPreprocessor) from modelscope.utils.constant import Tasks from modelscope.utils.tensor_utils import (torch_nested_detach, torch_nested_numpify) -__all__ = ['NamedEntityRecognitionPipeline'] +__all__ = [ + 'NamedEntityRecognitionPipeline', 
'NamedEntityRecognitionThaiPipeline', + 'NamedEntityRecognitionVietPipeline' +] @PIPELINES.register_module( @@ -126,3 +130,39 @@ class NamedEntityRecognitionPipeline(Pipeline): else: outputs = {OutputKeys.OUTPUT: chunks} return outputs + + +@PIPELINES.register_module( + Tasks.named_entity_recognition, + module_name=Pipelines.named_entity_recognition_thai) +class NamedEntityRecognitionThaiPipeline(NamedEntityRecognitionPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = NERPreprocessorThai( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + +@PIPELINES.register_module( + Tasks.named_entity_recognition, + module_name=Pipelines.named_entity_recognition_viet) +class NamedEntityRecognitionVietPipeline(NamedEntityRecognitionPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = NERPreprocessorViet( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 76c6d877..e568098f 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -28,7 +28,8 @@ if TYPE_CHECKING: SentencePiecePreprocessor, DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, DialogStateTrackingPreprocessor, ConversationalTextToSqlPreprocessor, - TableQuestionAnsweringPreprocessor) + TableQuestionAnsweringPreprocessor, NERPreprocessorViet, + NERPreprocessorThai, WordSegmentationPreprocessorThai) from .video import ReadVideoData, MovieSceneSegmentationPreprocessor else: @@ -58,6 +59,8 @@ else: 'WordSegmentationBlankSetToLabelPreprocessor', 'ZeroShotClassificationPreprocessor', 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', + 'NERPreprocessorViet', 'NERPreprocessorThai', + 'WordSegmentationPreprocessorThai', 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', 'DialogStateTrackingPreprocessor', 'ConversationalTextToSqlPreprocessor', diff --git a/modelscope/preprocessors/nlp/__init__.py b/modelscope/preprocessors/nlp/__init__.py index ea7b6bf4..d9c55fe1 100644 --- a/modelscope/preprocessors/nlp/__init__.py +++ b/modelscope/preprocessors/nlp/__init__.py @@ -20,6 +20,8 @@ if TYPE_CHECKING: from .text2text_generation_preprocessor import Text2TextGenerationPreprocessor from .token_classification_preprocessor import TokenClassificationPreprocessor, \ WordSegmentationBlankSetToLabelPreprocessor + from .token_classification_thai_preprocessor import WordSegmentationPreprocessorThai, NERPreprocessorThai + from .token_classification_viet_preprocessor import NERPreprocessorViet from .zero_shot_classification_reprocessor import ZeroShotClassificationPreprocessor from .space import (DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, @@ -60,10 +62,20 @@ else: 'text_error_correction': [ 'TextErrorCorrectionPreprocessor', ], + 'token_classification_thai_preprocessor': [ + 'NERPreprocessorThai', + 'WordSegmentationPreprocessorThai', + ], + 'token_classification_viet_preprocessor': [ + 'NERPreprocessorViet', + 
], 'space': [ - 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', - 'DialogStateTrackingPreprocessor', 'InputFeatures', - 'MultiWOZBPETextField', 'IntentBPETextField' + 'DialogIntentPredictionPreprocessor', + 'DialogModelingPreprocessor', + 'DialogStateTrackingPreprocessor', + 'InputFeatures', + 'MultiWOZBPETextField', + 'IntentBPETextField', ], 'space_T_en': ['ConversationalTextToSqlPreprocessor'], 'space_T_cn': ['TableQuestionAnsweringPreprocessor'], diff --git a/modelscope/preprocessors/nlp/token_classification_thai_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_thai_preprocessor.py new file mode 100644 index 00000000..a356cea7 --- /dev/null +++ b/modelscope/preprocessors/nlp/token_classification_thai_preprocessor.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict, Tuple, Union + +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .token_classification_preprocessor import TokenClassificationPreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.thai_ner_tokenizer) +class NERPreprocessorThai(TokenClassificationPreprocessor): + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + from pythainlp import word_tokenize + + segmented_data = ' '.join([ + w.strip(' ') for w in word_tokenize(text=data, engine='newmm') + if w.strip(' ') != '' + ]) + output = super().__call__(segmented_data) + + return output + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.thai_wseg_tokenizer) +class WordSegmentationPreprocessorThai(TokenClassificationPreprocessor): + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + import regex + data = regex.findall(r'\X', data) + data = ' '.join([char for char in data]) + + output = super().__call__(data) + + return output diff --git a/modelscope/preprocessors/nlp/token_classification_viet_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_viet_preprocessor.py new file mode 100644 index 00000000..f8970d1a --- /dev/null +++ b/modelscope/preprocessors/nlp/token_classification_viet_preprocessor.py @@ -0,0 +1,33 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
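+#
+# pyvi's ViTokenizer joins the syllables of a multi-syllable word with
+# underscores (e.g. 'Hà Nội' -> 'Hà_Nội'). NERPreprocessorViet below splits
+# those underscores back out, so the model consumes space-separated syllables
+# that still align with the original text.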
+ +from typing import Any, Dict, Tuple, Union + +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields, ModeKeys +from modelscope.utils.type_assert import type_assert +from .token_classification_preprocessor import TokenClassificationPreprocessor + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.viet_ner_tokenizer) +class NERPreprocessorViet(TokenClassificationPreprocessor): + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + from pyvi import ViTokenizer + + seg_words = [ + t.strip(' ') for t in ViTokenizer.tokenize(data).split(' ') + if t.strip(' ') != '' + ] + raw_words = [] + for w in seg_words: + raw_words.extend(w.split('_')) + segmented_data = ' '.join(raw_words) + output = super().__call__(segmented_data) + + return output diff --git a/requirements/nlp.txt b/requirements/nlp.txt index a5f3cbd9..9a4abd71 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -4,6 +4,8 @@ megatron_util pai-easynlp # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. protobuf>=3.19.0,<3.21.0 +pythainlp +pyvi # rough-score was just recently updated from 0.0.4 to 0.0.7 # which introduced compatability issues that are being investigated rouge_score<=0.0.4 diff --git a/tests/pipelines/test_multilingual_named_entity_recognition.py b/tests/pipelines/test_multilingual_named_entity_recognition.py new file mode 100644 index 00000000..6f72c83c --- /dev/null +++ b/tests/pipelines/test_multilingual_named_entity_recognition.py @@ -0,0 +1,102 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition, + TransformerCRFForNamedEntityRecognition) +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import (NamedEntityRecognitionThaiPipeline, + NamedEntityRecognitionVietPipeline) +from modelscope.preprocessors import NERPreprocessorThai, NERPreprocessorViet +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class MultilingualNamedEntityRecognitionTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.named_entity_recognition + self.model_id = 'damo/nlp_xlmr_named-entity-recognition_thai-ecommerce-title' + + thai_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_thai-ecommerce-title' + thai_sentence = 'เครื่องชั่งดิจิตอลแบบตั้งพื้น150kg.' 
+ + viet_tcrf_model_id = 'damo/nlp_xlmr_named-entity-recognition_viet-ecommerce-title' + viet_sentence = 'Nón vành dễ thương cho bé gái' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_tcrf_by_direct_model_download_thai(self): + cache_path = snapshot_download(self.thai_tcrf_model_id) + tokenizer = NERPreprocessorThai(cache_path) + model = TransformerCRFForNamedEntityRecognition( + cache_path, tokenizer=tokenizer) + pipeline1 = NamedEntityRecognitionThaiPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(f'thai_sentence: {self.thai_sentence}\n' + f'pipeline1:{pipeline1(input=self.thai_sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.thai_sentence)}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_tcrf_with_model_from_modelhub_thai(self): + model = Model.from_pretrained(self.thai_tcrf_model_id) + tokenizer = NERPreprocessorThai(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.thai_sentence)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_tcrf_with_model_name_thai(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.thai_tcrf_model_id) + print(pipeline_ins(input=self.thai_sentence)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_tcrf_by_direct_model_download_viet(self): + cache_path = snapshot_download(self.viet_tcrf_model_id) + tokenizer = NERPreprocessorViet(cache_path) + model = TransformerCRFForNamedEntityRecognition( + cache_path, tokenizer=tokenizer) + pipeline1 = NamedEntityRecognitionVietPipeline( + model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(f'viet_sentence: {self.viet_sentence}\n' + f'pipeline1:{pipeline1(input=self.viet_sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.viet_sentence)}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_tcrf_with_model_from_modelhub_viet(self): + model = Model.from_pretrained(self.viet_tcrf_model_id) + tokenizer = NERPreprocessorViet(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.viet_sentence)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_tcrf_with_model_name_viet(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.viet_tcrf_model_id) + print(pipeline_ins(input=self.viet_sentence)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_multilingual_word_segmentation.py b/tests/pipelines/test_multilingual_word_segmentation.py new file mode 100644 index 00000000..25b4b241 --- /dev/null +++ b/tests/pipelines/test_multilingual_word_segmentation.py @@ -0,0 +1,57 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
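+# These tests exercise the Thai word segmentation pipeline three ways:
+# direct checkpoint download, Model.from_pretrained, and pipeline creation
+# by model name.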
+import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import TransformerCRFForWordSegmentation +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import WordSegmentationThaiPipeline +from modelscope.preprocessors import WordSegmentationPreprocessorThai +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.regress_test_utils import MsRegressTool +from modelscope.utils.test_utils import test_level + + +class WordSegmentationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.word_segmentation + self.model_id = 'damo/nlp_xlmr_word-segmentation_thai' + + sentence = 'รถคันเก่าก็ยังเก็บเอาไว้ยังไม่ได้ขาย' + regress_tool = MsRegressTool(baseline=False) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = WordSegmentationPreprocessorThai(cache_path) + model = TransformerCRFForWordSegmentation.from_pretrained(cache_path) + pipeline1 = WordSegmentationThaiPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.word_segmentation, model=model, preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print(f'pipeline2: {pipeline2(input=self.sentence)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = WordSegmentationPreprocessorThai(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.word_segmentation, model=model, preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.word_segmentation, model=self.model_id) + print(pipeline_ins(input=self.sentence)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main()