diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index ccd36349..371cfd34 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -71,6 +71,7 @@ class Models(object): space_T_en = 'space-T-en' space_T_cn = 'space-T-cn' tcrf = 'transformer-crf' + token_classification_for_ner = 'token-classification-for-ner' tcrf_wseg = 'transformer-crf-for-word-segmentation' transformer_softmax = 'transformer-softmax' lcrf = 'lstm-crf' diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 1d71469a..cfa67700 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -40,11 +40,13 @@ if TYPE_CHECKING: FeatureExtractionModel, InformationExtractionModel, LSTMCRFForNamedEntityRecognition, + LSTMCRFForWordSegmentation, SequenceClassificationModel, SingleBackboneTaskModelBase, TaskModelForTextGeneration, TokenClassificationModel, TransformerCRFForNamedEntityRecognition, + TransformerCRFForWordSegmentation, ) from .veco import (VecoConfig, VecoForMaskedLM, VecoForSequenceClassification, diff --git a/modelscope/models/nlp/heads/token_classification_head.py b/modelscope/models/nlp/heads/token_classification_head.py index 443f93df..443b24e3 100644 --- a/modelscope/models/nlp/heads/token_classification_head.py +++ b/modelscope/models/nlp/heads/token_classification_head.py @@ -14,6 +14,8 @@ from modelscope.utils.constant import Tasks @HEADS.register_module( Tasks.token_classification, module_name=Heads.token_classification) +@HEADS.register_module( + Tasks.named_entity_recognition, module_name=Heads.token_classification) @HEADS.register_module( Tasks.part_of_speech, module_name=Heads.token_classification) class TokenClassificationHead(TorchHead): diff --git a/modelscope/models/nlp/mglm/__init__.py b/modelscope/models/nlp/mglm/__init__.py index 26d1101b..3c96ac4a 100644 --- a/modelscope/models/nlp/mglm/__init__.py +++ b/modelscope/models/nlp/mglm/__init__.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .mglm_for_text_summarization import mGlmForSummarization + from .mglm_for_text_summarization import MGLMForTextSummarization else: _import_structure = { 'mglm_for_text_summarization': ['MGLMForTextSummarization'], diff --git a/modelscope/models/nlp/task_models/__init__.py b/modelscope/models/nlp/task_models/__init__.py index b8722a36..8fce78a1 100644 --- a/modelscope/models/nlp/task_models/__init__.py +++ b/modelscope/models/nlp/task_models/__init__.py @@ -9,10 +9,8 @@ if TYPE_CHECKING: from .fill_mask import FillMaskModel from .nncrf_for_named_entity_recognition import ( LSTMCRFForNamedEntityRecognition, - TransformerCRFForNamedEntityRecognition, - ) - from .nncrf_for_word_segmentation import ( LSTMCRFForWordSegmentation, + TransformerCRFForNamedEntityRecognition, TransformerCRFForWordSegmentation, ) from .sequence_classification import SequenceClassificationModel @@ -26,11 +24,11 @@ else: 'feature_extraction': ['FeatureExtractionModel'], 'fill_mask': ['FillMaskModel'], 'nncrf_for_named_entity_recognition': [ + 'LSTMCRFForNamedEntityRecognition', + 'LSTMCRFForWordSegmentation', 'TransformerCRFForNamedEntityRecognition', - 'LSTMCRFForNamedEntityRecognition' + 'TransformerCRFForWordSegmentation', ], - 'nncrf_for_word_segmentation': - ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'], 'sequence_classification': ['SequenceClassificationModel'], 'task_model': ['SingleBackboneTaskModelBase'], 'token_classification': 
['TokenClassificationModel'], diff --git a/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py b/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py index 017e35e5..79ce365d 100644 --- a/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py +++ b/modelscope/models/nlp/task_models/nncrf_for_named_entity_recognition.py @@ -167,6 +167,14 @@ class TransformerCRFForNamedEntityRecognition( return model +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.tcrf_wseg) +class TransformerCRFForWordSegmentation(TransformerCRFForNamedEntityRecognition + ): + """This model wraps the TransformerCRF model to register into model sets. + """ + pass + + @MODELS.register_module( Tasks.named_entity_recognition, module_name=Models.lcrf) class LSTMCRFForNamedEntityRecognition( @@ -185,6 +193,11 @@ class LSTMCRFForNamedEntityRecognition( return model +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.lcrf_wseg) +class LSTMCRFForWordSegmentation(LSTMCRFForNamedEntityRecognition): + pass + + class TransformerCRF(nn.Module): """A transformer based model to NER tasks. diff --git a/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py b/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py deleted file mode 100644 index 2a3f6cf4..00000000 --- a/modelscope/models/nlp/task_models/nncrf_for_word_segmentation.py +++ /dev/null @@ -1,639 +0,0 @@ -# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. -# The CRF implementation borrows mostly from AllenNLP CRF module (https://github.com/allenai/allennlp) -# and pytorch-crf (https://github.com/kmkurn/pytorch-crf) with some modifications. - -import os -from typing import Any, Dict, List, Optional - -import torch -import torch.nn as nn -from transformers import AutoConfig, AutoModel - -from modelscope.metainfo import Models -from modelscope.models import TorchModel -from modelscope.models.builder import MODELS -from modelscope.outputs import TokenClassifierWithPredictionsOutput -from modelscope.utils.constant import ModelFile, Tasks - -__all__ = ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'] - - -class SequenceLabelingForWordSegmentation(TorchModel): - - def __init__(self, model_dir, *args, **kwargs): - super().__init__(model_dir, *args, **kwargs) - self.model = self.init_model(model_dir, *args, **kwargs) - - model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) - self.model.load_state_dict( - torch.load(model_ckpt, map_location=torch.device('cpu'))) - - def init_model(self, model_dir, *args, **kwargs): - raise NotImplementedError - - def train(self): - return self.model.train() - - def eval(self): - return self.model.eval() - - def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - input_tensor = { - 'input_ids': input['input_ids'], - 'attention_mask': input['attention_mask'], - 'label_mask': input['label_mask'], - } - output = { - 'offset_mapping': input['offset_mapping'], - **input_tensor, - **self.model(input_tensor) - } - return output - - def postprocess(self, input: Dict[str, Any], **kwargs): - predicts = self.model.decode(input) - offset_len = len(input['offset_mapping']) - predictions = torch.narrow( - predicts, 1, 0, - offset_len) # index_select only move loc, not resize - return TokenClassifierWithPredictionsOutput( - loss=None, - logits=None, - hidden_states=None, - attentions=None, - offset_mapping=input['offset_mapping'], - predictions=predictions, - ) - - 
-@MODELS.register_module(Tasks.word_segmentation, module_name=Models.tcrf_wseg) -class TransformerCRFForWordSegmentation(SequenceLabelingForWordSegmentation): - """This model wraps the TransformerCRF model to register into model sets. - """ - - def init_model(self, model_dir, *args, **kwargs): - self.config = AutoConfig.from_pretrained(model_dir) - num_labels = self.config.num_labels - - model = TransformerCRF(model_dir, num_labels) - return model - - -@MODELS.register_module(Tasks.word_segmentation, module_name=Models.lcrf_wseg) -class LSTMCRFForWordSegmentation(SequenceLabelingForWordSegmentation): - """This model wraps the LSTMCRF model to register into model sets. - """ - - def init_model(self, model_dir, *args, **kwargs): - self.config = AutoConfig.from_pretrained(model_dir) - vocab_size = self.config.vocab_size - embed_width = self.config.embed_width - num_labels = self.config.num_labels - lstm_hidden_size = self.config.lstm_hidden_size - - model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size) - return model - - -class TransformerCRF(nn.Module): - """A transformer based model to NER tasks. - - This model will use transformers' backbones as its backbone. - """ - - def __init__(self, model_dir, num_labels, **kwargs): - super(TransformerCRF, self).__init__() - - self.encoder = AutoModel.from_pretrained(model_dir) - self.linear = nn.Linear(self.encoder.config.hidden_size, num_labels) - self.crf = CRF(num_labels, batch_first=True) - - def forward(self, inputs): - embed = self.encoder( - inputs['input_ids'], attention_mask=inputs['attention_mask'])[0] - logits = self.linear(embed) - - if 'label_mask' in inputs: - mask = inputs['label_mask'] - masked_lengths = mask.sum(-1).long() - masked_logits = torch.zeros_like(logits) - for i in range(len(mask)): - masked_logits[ - i, :masked_lengths[i], :] = logits[i].masked_select( - mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) - logits = masked_logits - - outputs = {'logits': logits} - return outputs - - def decode(self, inputs): - seq_lens = inputs['label_mask'].sum(-1).long() - mask = torch.arange( - inputs['label_mask'].shape[1], - device=seq_lens.device)[None, :] < seq_lens[:, None] - predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) - - return predicts - - -class LSTMCRF(nn.Module): - """ - A standard bilstm-crf model for fast prediction. 
- """ - - def __init__(self, - vocab_size, - embed_width, - num_labels, - lstm_hidden_size=100, - **kwargs): - super(LSTMCRF, self).__init__() - self.embedding = Embedding(vocab_size, embed_width) - self.lstm = nn.LSTM( - embed_width, - lstm_hidden_size, - num_layers=1, - bidirectional=True, - batch_first=True) - self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels) - self.crf = CRF(num_labels, batch_first=True) - - def forward(self, inputs): - embedding = self.embedding(inputs['input_ids']) - lstm_output, _ = self.lstm(embedding) - logits = self.ffn(lstm_output) - - if 'label_mask' in inputs: - mask = inputs['label_mask'] - masked_lengths = mask.sum(-1).long() - masked_logits = torch.zeros_like(logits) - for i in range(len(mask)): - masked_logits[ - i, :masked_lengths[i], :] = logits[i].masked_select( - mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) - logits = masked_logits - - outputs = {'logits': logits} - return outputs - - def decode(self, inputs): - seq_lens = inputs['label_mask'].sum(-1).long() - mask = torch.arange( - inputs['label_mask'].shape[1], - device=seq_lens.device)[None, :] < seq_lens[:, None] - predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) - outputs = {'predicts': predicts} - return outputs - - -class CRF(nn.Module): - """Conditional random field. - This module implements a conditional random field [LMP01]_. The forward computation - of this class computes the log likelihood of the given sequence of tags and - emission score tensor. This class also has `~CRF.decode` method which finds - the best tag sequence given an emission score tensor using `Viterbi algorithm`_. - Args: - num_tags: Number of tags. - batch_first: Whether the first dimension corresponds to the size of a minibatch. - Attributes: - start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size - ``(num_tags,)``. - end_transitions (`~torch.nn.Parameter`): End transition score tensor of size - ``(num_tags,)``. - transitions (`~torch.nn.Parameter`): Transition score tensor of size - ``(num_tags, num_tags)``. - .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). - "Conditional random fields: Probabilistic models for segmenting and - labeling sequence data". *Proc. 18th International Conf. on Machine - Learning*. Morgan Kaufmann. pp. 282–289. - .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm - - """ - - def __init__(self, num_tags: int, batch_first: bool = False) -> None: - if num_tags <= 0: - raise ValueError(f'invalid number of tags: {num_tags}') - super().__init__() - self.num_tags = num_tags - self.batch_first = batch_first - self.start_transitions = nn.Parameter(torch.empty(num_tags)) - self.end_transitions = nn.Parameter(torch.empty(num_tags)) - self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) - - self.reset_parameters() - - def reset_parameters(self) -> None: - """Initialize the transition parameters. - The parameters will be initialized randomly from a uniform distribution - between -0.1 and 0.1. - """ - nn.init.uniform_(self.start_transitions, -0.1, 0.1) - nn.init.uniform_(self.end_transitions, -0.1, 0.1) - nn.init.uniform_(self.transitions, -0.1, 0.1) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(num_tags={self.num_tags})' - - def forward(self, - emissions: torch.Tensor, - tags: torch.LongTensor, - mask: Optional[torch.ByteTensor] = None, - reduction: str = 'mean') -> torch.Tensor: - """Compute the conditional log likelihood of a sequence of tags given emission scores. 
- Args: - emissions (`~torch.Tensor`): Emission score tensor of size - ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, - ``(batch_size, seq_length, num_tags)`` otherwise. - tags (`~torch.LongTensor`): Sequence of tags tensor of size - ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, - ``(batch_size, seq_length)`` otherwise. - mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` - if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. - reduction: Specifies the reduction to apply to the output: - ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. - ``sum``: the output will be summed over batches. ``mean``: the output will be - averaged over batches. ``token_mean``: the output will be averaged over tokens. - Returns: - `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if - reduction is ``none``, ``()`` otherwise. - """ - if reduction not in ('none', 'sum', 'mean', 'token_mean'): - raise ValueError(f'invalid reduction: {reduction}') - if mask is None: - mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device) - if mask.dtype != torch.uint8: - mask = mask.byte() - self._validate(emissions, tags=tags, mask=mask) - - if self.batch_first: - emissions = emissions.transpose(0, 1) - tags = tags.transpose(0, 1) - mask = mask.transpose(0, 1) - - # shape: (batch_size,) - numerator = self._compute_score(emissions, tags, mask) - # shape: (batch_size,) - denominator = self._compute_normalizer(emissions, mask) - # shape: (batch_size,) - llh = numerator - denominator - - if reduction == 'none': - return llh - if reduction == 'sum': - return llh.sum() - if reduction == 'mean': - return llh.mean() - return llh.sum() / mask.float().sum() - - def decode(self, - emissions: torch.Tensor, - mask: Optional[torch.ByteTensor] = None, - nbest: Optional[int] = None, - pad_tag: Optional[int] = None) -> List[List[List[int]]]: - """Find the most likely tag sequence using Viterbi algorithm. - Args: - emissions (`~torch.Tensor`): Emission score tensor of size - ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, - ``(batch_size, seq_length, num_tags)`` otherwise. - mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` - if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. - nbest (`int`): Number of most probable paths for each sequence - pad_tag (`int`): Tag at padded positions. Often input varies in length and - the length will be padded to the maximum length in the batch. Tags at - the padded positions will be assigned with a padding tag, i.e. 
`pad_tag` - Returns: - A PyTorch tensor of the best tag sequence for each batch of shape - (nbest, batch_size, seq_length) - """ - if nbest is None: - nbest = 1 - if mask is None: - mask = torch.ones( - emissions.shape[:2], - dtype=torch.uint8, - device=emissions.device) - if mask.dtype != torch.uint8: - mask = mask.byte() - self._validate(emissions, mask=mask) - - if self.batch_first: - emissions = emissions.transpose(0, 1) - mask = mask.transpose(0, 1) - - if nbest == 1: - return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0) - return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) - - def _validate(self, - emissions: torch.Tensor, - tags: Optional[torch.LongTensor] = None, - mask: Optional[torch.ByteTensor] = None) -> None: - if emissions.dim() != 3: - raise ValueError( - f'emissions must have dimension of 3, got {emissions.dim()}') - if emissions.size(2) != self.num_tags: - raise ValueError( - f'expected last dimension of emissions is {self.num_tags}, ' - f'got {emissions.size(2)}') - - if tags is not None: - if emissions.shape[:2] != tags.shape: - raise ValueError( - 'the first two dimensions of emissions and tags must match, ' - f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}' - ) - - if mask is not None: - if emissions.shape[:2] != mask.shape: - raise ValueError( - 'the first two dimensions of emissions and mask must match, ' - f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}' - ) - no_empty_seq = not self.batch_first and mask[0].all() - no_empty_seq_bf = self.batch_first and mask[:, 0].all() - if not no_empty_seq and not no_empty_seq_bf: - raise ValueError('mask of the first timestep must all be on') - - def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor, - mask: torch.ByteTensor) -> torch.Tensor: - # emissions: (seq_length, batch_size, num_tags) - # tags: (seq_length, batch_size) - # mask: (seq_length, batch_size) - seq_length, batch_size = tags.shape - mask = mask.float() - - # Start transition score and first emission - # shape: (batch_size,) - score = self.start_transitions[tags[0]] - score += emissions[0, torch.arange(batch_size), tags[0]] - - for i in range(1, seq_length): - # Transition score to next tag, only added if next timestep is valid (mask == 1) - # shape: (batch_size,) - score += self.transitions[tags[i - 1], tags[i]] * mask[i] - - # Emission score for next tag, only added if next timestep is valid (mask == 1) - # shape: (batch_size,) - score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] - - # End transition score - # shape: (batch_size,) - seq_ends = mask.long().sum(dim=0) - 1 - # shape: (batch_size,) - last_tags = tags[seq_ends, torch.arange(batch_size)] - # shape: (batch_size,) - score += self.end_transitions[last_tags] - - return score - - def _compute_normalizer(self, emissions: torch.Tensor, - mask: torch.ByteTensor) -> torch.Tensor: - # emissions: (seq_length, batch_size, num_tags) - # mask: (seq_length, batch_size) - seq_length = emissions.size(0) - - # Start transition score and first emission; score has size of - # (batch_size, num_tags) where for each batch, the j-th column stores - # the score that the first timestep has tag j - # shape: (batch_size, num_tags) - score = self.start_transitions + emissions[0] - - for i in range(1, seq_length): - # Broadcast score for every possible next tag - # shape: (batch_size, num_tags, 1) - broadcast_score = score.unsqueeze(2) - - # Broadcast emission score for every possible current tag - # shape: (batch_size, 1, num_tags) - 
broadcast_emissions = emissions[i].unsqueeze(1) - - # Compute the score tensor of size (batch_size, num_tags, num_tags) where - # for each sample, entry at row i and column j stores the sum of scores of all - # possible tag sequences so far that end with transitioning from tag i to tag j - # and emitting - # shape: (batch_size, num_tags, num_tags) - next_score = broadcast_score + self.transitions + broadcast_emissions - - # Sum over all possible current tags, but we're in score space, so a sum - # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of - # all possible tag sequences so far, that end in tag i - # shape: (batch_size, num_tags) - next_score = torch.logsumexp(next_score, dim=1) - - # Set score to the next score if this timestep is valid (mask == 1) - # shape: (batch_size, num_tags) - score = torch.where(mask[i].unsqueeze(1), next_score, score) - - # End transition score - # shape: (batch_size, num_tags) - score += self.end_transitions - - # Sum (log-sum-exp) over all possible tags - # shape: (batch_size,) - return torch.logsumexp(score, dim=1) - - def _viterbi_decode(self, - emissions: torch.FloatTensor, - mask: torch.ByteTensor, - pad_tag: Optional[int] = None) -> List[List[int]]: - # emissions: (seq_length, batch_size, num_tags) - # mask: (seq_length, batch_size) - # return: (batch_size, seq_length) - if pad_tag is None: - pad_tag = 0 - - device = emissions.device - seq_length, batch_size = mask.shape - - # Start transition and first emission - # shape: (batch_size, num_tags) - score = self.start_transitions + emissions[0] - history_idx = torch.zeros((seq_length, batch_size, self.num_tags), - dtype=torch.long, - device=device) - oor_idx = torch.zeros((batch_size, self.num_tags), - dtype=torch.long, - device=device) - oor_tag = torch.full((seq_length, batch_size), - pad_tag, - dtype=torch.long, - device=device) - - # - score is a tensor of size (batch_size, num_tags) where for every batch, - # value at column j stores the score of the best tag sequence so far that ends - # with tag j - # - history_idx saves where the best tags candidate transitioned from; this is used - # when we trace back the best tag sequence - # - oor_idx saves the best tags candidate transitioned from at the positions - # where mask is 0, i.e. 
out of range (oor) - - # Viterbi algorithm recursive case: we compute the score of the best tag sequence - # for every possible next tag - for i in range(1, seq_length): - # Broadcast viterbi score for every possible next tag - # shape: (batch_size, num_tags, 1) - broadcast_score = score.unsqueeze(2) - - # Broadcast emission score for every possible current tag - # shape: (batch_size, 1, num_tags) - broadcast_emission = emissions[i].unsqueeze(1) - - # Compute the score tensor of size (batch_size, num_tags, num_tags) where - # for each sample, entry at row i and column j stores the score of the best - # tag sequence so far that ends with transitioning from tag i to tag j and emitting - # shape: (batch_size, num_tags, num_tags) - next_score = broadcast_score + self.transitions + broadcast_emission - - # Find the maximum score over all possible current tag - # shape: (batch_size, num_tags) - next_score, indices = next_score.max(dim=1) - - # Set score to the next score if this timestep is valid (mask == 1) - # and save the index that produces the next score - # shape: (batch_size, num_tags) - score = torch.where(mask[i].unsqueeze(-1), next_score, score) - indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx) - history_idx[i - 1] = indices - - # End transition score - # shape: (batch_size, num_tags) - end_score = score + self.end_transitions - _, end_tag = end_score.max(dim=1) - - # shape: (batch_size,) - seq_ends = mask.long().sum(dim=0) - 1 - - # insert the best tag at each sequence end (last position with mask == 1) - history_idx = history_idx.transpose(1, 0).contiguous() - history_idx.scatter_( - 1, - seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags), - end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags)) - history_idx = history_idx.transpose(1, 0).contiguous() - - # The most probable path for each sequence - best_tags_arr = torch.zeros((seq_length, batch_size), - dtype=torch.long, - device=device) - best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device) - for idx in range(seq_length - 1, -1, -1): - best_tags = torch.gather(history_idx[idx], 1, best_tags) - best_tags_arr[idx] = best_tags.data.view(batch_size) - - return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1) - - def _viterbi_decode_nbest( - self, - emissions: torch.FloatTensor, - mask: torch.ByteTensor, - nbest: int, - pad_tag: Optional[int] = None) -> List[List[List[int]]]: - # emissions: (seq_length, batch_size, num_tags) - # mask: (seq_length, batch_size) - # return: (nbest, batch_size, seq_length) - if pad_tag is None: - pad_tag = 0 - - device = emissions.device - seq_length, batch_size = mask.shape - - # Start transition and first emission - # shape: (batch_size, num_tags) - score = self.start_transitions + emissions[0] - history_idx = torch.zeros( - (seq_length, batch_size, self.num_tags, nbest), - dtype=torch.long, - device=device) - oor_idx = torch.zeros((batch_size, self.num_tags, nbest), - dtype=torch.long, - device=device) - oor_tag = torch.full((seq_length, batch_size, nbest), - pad_tag, - dtype=torch.long, - device=device) - - # + score is a tensor of size (batch_size, num_tags) where for every batch, - # value at column j stores the score of the best tag sequence so far that ends - # with tag j - # + history_idx saves where the best tags candidate transitioned from; this is used - # when we trace back the best tag sequence - # - oor_idx saves the best tags candidate transitioned from at the positions - # where mask is 0, i.e. 
out of range (oor) - - # Viterbi algorithm recursive case: we compute the score of the best tag sequence - # for every possible next tag - for i in range(1, seq_length): - if i == 1: - broadcast_score = score.unsqueeze(-1) - broadcast_emission = emissions[i].unsqueeze(1) - # shape: (batch_size, num_tags, num_tags) - next_score = broadcast_score + self.transitions + broadcast_emission - else: - broadcast_score = score.unsqueeze(-1) - broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2) - # shape: (batch_size, num_tags, nbest, num_tags) - next_score = broadcast_score + self.transitions.unsqueeze( - 1) + broadcast_emission - - # Find the top `nbest` maximum score over all possible current tag - # shape: (batch_size, nbest, num_tags) - next_score, indices = next_score.view(batch_size, -1, - self.num_tags).topk( - nbest, dim=1) - - if i == 1: - score = score.unsqueeze(-1).expand(-1, -1, nbest) - indices = indices * nbest - - # convert to shape: (batch_size, num_tags, nbest) - next_score = next_score.transpose(2, 1) - indices = indices.transpose(2, 1) - - # Set score to the next score if this timestep is valid (mask == 1) - # and save the index that produces the next score - # shape: (batch_size, num_tags, nbest) - score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), - next_score, score) - indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, - oor_idx) - history_idx[i - 1] = indices - - # End transition score shape: (batch_size, num_tags, nbest) - end_score = score + self.end_transitions.unsqueeze(-1) - _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) - - # shape: (batch_size,) - seq_ends = mask.long().sum(dim=0) - 1 - - # insert the best tag at each sequence end (last position with mask == 1) - history_idx = history_idx.transpose(1, 0).contiguous() - history_idx.scatter_( - 1, - seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), - end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest)) - history_idx = history_idx.transpose(1, 0).contiguous() - - # The most probable path for each sequence - best_tags_arr = torch.zeros((seq_length, batch_size, nbest), - dtype=torch.long, - device=device) - best_tags = torch.arange(nbest, dtype=torch.long, device=device) \ - .view(1, -1).expand(batch_size, -1) - for idx in range(seq_length - 1, -1, -1): - best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, - best_tags) - best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest - - return torch.where(mask.unsqueeze(-1), best_tags_arr, - oor_tag).permute(2, 1, 0) - - -class Embedding(nn.Module): - - def __init__(self, vocab_size, embed_width): - super(Embedding, self).__init__() - - self.embedding = nn.Embedding(vocab_size, embed_width) - - def forward(self, input_ids): - return self.embedding(input_ids) diff --git a/modelscope/models/nlp/task_models/token_classification.py b/modelscope/models/nlp/task_models/token_classification.py index 8b523baf..982bce32 100644 --- a/modelscope/models/nlp/task_models/token_classification.py +++ b/modelscope/models/nlp/task_models/token_classification.py @@ -4,7 +4,7 @@ from typing import Any, Dict import numpy as np import torch -from modelscope.metainfo import TaskModels +from modelscope.metainfo import Models, TaskModels from modelscope.models.builder import MODELS from modelscope.models.nlp.task_models.task_model import \ SingleBackboneTaskModelBase @@ -21,6 +21,9 @@ __all__ = ['TokenClassificationModel'] Tasks.token_classification, module_name=TaskModels.token_classification) 
@MODELS.register_module( Tasks.part_of_speech, module_name=TaskModels.token_classification) +@MODELS.register_module( + Tasks.named_entity_recognition, + module_name=Models.token_classification_for_ner) class TokenClassificationModel(SingleBackboneTaskModelBase): def __init__(self, model_dir: str, *args, **kwargs): @@ -59,6 +62,9 @@ class TokenClassificationModel(SingleBackboneTaskModelBase): if labels in input: loss = self.compute_loss(outputs, labels) + # apply label mask to logits + logits = logits[input['label_mask']].unsqueeze(0) + return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index 2c6dd85a..377eff6f 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -490,7 +490,10 @@ TASK_OUTPUTS = { # word segmentation result for single sample # { - # "output": "今天 天气 不错 , 适合 出去 游玩" + # "output": ["今天", "天气", "不错", ",", "适合", "出去", "游玩"] + # } + # { + # 'output': ['รถ', 'คัน', 'เก่า', 'ก็', 'ยัง', 'เก็บ', 'เอา'] # } Tasks.word_segmentation: [OutputKeys.OUTPUT], diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 1206ae08..dc79d387 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -29,11 +29,9 @@ if TYPE_CHECKING: from .text2text_generation_pipeline import Text2TextGenerationPipeline from .token_classification_pipeline import TokenClassificationPipeline from .translation_pipeline import TranslationPipeline - from .word_segmentation_pipeline import WordSegmentationPipeline + from .word_segmentation_pipeline import WordSegmentationPipeline, WordSegmentationThaiPipeline from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline - from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ - WordSegmentationThaiPipeline else: _import_structure = { @@ -69,14 +67,11 @@ else: 'translation_pipeline': ['TranslationPipeline'], 'translation_quality_estimation_pipeline': ['TranslationQualityEstimationPipeline'], - 'word_segmentation_pipeline': ['WordSegmentationPipeline'], + 'word_segmentation_pipeline': + ['WordSegmentationPipeline', 'WordSegmentationThaiPipeline'], 'zero_shot_classification_pipeline': ['ZeroShotClassificationPipeline'], 'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'], - 'multilingual_word_segmentation_pipeline': [ - 'MultilingualWordSegmentationPipeline', - 'WordSegmentationThaiPipeline' - ], } import sys diff --git a/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py b/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py deleted file mode 100644 index 56c3a041..00000000 --- a/modelscope/pipelines/nlp/multilingual_word_segmentation_pipeline.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
- -from typing import Any, Dict, Optional, Union - -import torch - -from modelscope.metainfo import Pipelines -from modelscope.models import Model -from modelscope.outputs import OutputKeys -from modelscope.pipelines.base import Pipeline -from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import (Preprocessor, - TokenClassificationPreprocessor, - WordSegmentationPreprocessorThai) -from modelscope.utils.constant import Tasks - -__all__ = [ - 'MultilingualWordSegmentationPipeline', 'WordSegmentationThaiPipeline' -] - - -@PIPELINES.register_module( - Tasks.word_segmentation, - module_name=Pipelines.multilingual_word_segmentation) -class MultilingualWordSegmentationPipeline(Pipeline): - - def __init__(self, - model: Union[Model, str], - preprocessor: Optional[Preprocessor] = None, - **kwargs): - """Use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction - - Args: - model (str or Model): Supply either a local model dir which supported word segmentation task, or a - model id from the model hub, or a torch model instance. - preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for - the model if supplied. - sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value. - - To view other examples plese check the tests/pipelines/test_multilingual_word_segmentation.py. - """ - - model = model if isinstance(model, - Model) else Model.from_pretrained(model) - if preprocessor is None: - preprocessor = TokenClassificationPreprocessor( - model.model_dir, - sequence_length=kwargs.pop('sequence_length', 512)) - model.eval() - super().__init__(model=model, preprocessor=preprocessor, **kwargs) - self.tokenizer = preprocessor.tokenizer - self.config = model.config - assert len(self.config.id2label) > 0 - self.id2label = self.config.id2label - - def forward(self, inputs: Dict[str, Any], - **forward_params) -> Dict[str, Any]: - text = inputs.pop(OutputKeys.TEXT) - with torch.no_grad(): - return { - **super().forward(inputs, **forward_params), OutputKeys.TEXT: - text - } - - def postprocess(self, inputs: Dict[str, Any], - **postprocess_params) -> Dict[str, str]: - text = inputs['text'] - offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] - labels = [ - self.id2label[x] - for x in inputs['predictions'].squeeze(0).cpu().numpy() - ] - entities = [] - entity = {} - for label, offsets in zip(labels, offset_mapping): - if label[0] in 'BS': - if entity: - entity['span'] = text[entity['start']:entity['end']] - entities.append(entity) - entity = { - 'type': label[2:], - 'start': offsets[0], - 'end': offsets[1] - } - if label[0] in 'IES': - if entity: - entity['end'] = offsets[1] - if label[0] in 'ES': - if entity: - entity['span'] = text[entity['start']:entity['end']] - entities.append(entity) - entity = {} - if entity: - entity['span'] = text[entity['start']:entity['end']] - entities.append(entity) - - word_segments = [entity['span'] for entity in entities] - outputs = {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} - - return outputs - - -@PIPELINES.register_module( - Tasks.word_segmentation, module_name=Pipelines.word_segmentation_thai) -class WordSegmentationThaiPipeline(MultilingualWordSegmentationPipeline): - - def __init__(self, - model: Union[Model, str], - preprocessor: Optional[Preprocessor] = None, - **kwargs): - model = model if isinstance(model, - Model) else Model.from_pretrained(model) - if preprocessor is None: - 
preprocessor = WordSegmentationPreprocessorThai( - model.model_dir, - sequence_length=kwargs.pop('sequence_length', 512)) - super().__init__(model=model, preprocessor=preprocessor, **kwargs) - - def postprocess(self, inputs: Dict[str, Any], - **postprocess_params) -> Dict[str, str]: - outputs = super().postprocess(inputs, **postprocess_params) - word_segments = outputs[OutputKeys.OUTPUT] - word_segments = [seg.replace(' ', '') for seg in word_segments] - - return {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} diff --git a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py index 0e35efcb..ece75e1b 100644 --- a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py +++ b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py @@ -9,6 +9,7 @@ from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES +from modelscope.pipelines.nlp import TokenClassificationPipeline from modelscope.preprocessors import (NERPreprocessorThai, NERPreprocessorViet, Preprocessor, TokenClassificationPreprocessor) @@ -25,7 +26,7 @@ __all__ = [ @PIPELINES.register_module( Tasks.named_entity_recognition, module_name=Pipelines.named_entity_recognition) -class NamedEntityRecognitionPipeline(Pipeline): +class NamedEntityRecognitionPipeline(TokenClassificationPipeline): def __init__(self, model: Union[Model, str], @@ -55,97 +56,12 @@ class NamedEntityRecognitionPipeline(Pipeline): if preprocessor is None: preprocessor = TokenClassificationPreprocessor( model.model_dir, - sequence_length=kwargs.pop('sequence_length', 512)) + sequence_length=kwargs.pop('sequence_length', 128)) model.eval() super().__init__(model=model, preprocessor=preprocessor, **kwargs) - self.tokenizer = preprocessor.tokenizer - self.config = model.config - assert len(self.config.id2label) > 0 - self.id2label = self.config.id2label - - def forward(self, inputs: Dict[str, Any], - **forward_params) -> Dict[str, Any]: - text = inputs.pop(OutputKeys.TEXT) - with torch.no_grad(): - return { - **self.model(**inputs, **forward_params), OutputKeys.TEXT: text - } - - def postprocess(self, inputs: Dict[str, Any], - **postprocess_params) -> Dict[str, str]: - """process the prediction results - - Args: - inputs (Dict[str, Any]): should be tensors from model - - Returns: - Dict[str, str]: the prediction results - """ - text = inputs['text'] - if OutputKeys.PREDICTIONS not in inputs: - logits = inputs[OutputKeys.LOGITS] - predictions = torch.argmax(logits[0], dim=-1) - else: - predictions = inputs[OutputKeys.PREDICTIONS].squeeze( - 0).cpu().numpy() - predictions = torch_nested_numpify(torch_nested_detach(predictions)) - offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] - - labels = [self.id2label[x] for x in predictions] - if len(labels) > len(offset_mapping): - labels = labels[1:-1] - chunks = [] - chunk = {} - for label, offsets in zip(labels, offset_mapping): - if label[0] in 'BS': - if chunk: - chunk['span'] = text[chunk['start']:chunk['end']] - chunks.append(chunk) - chunk = { - 'type': label[2:], - 'start': offsets[0], - 'end': offsets[1] - } - if label[0] in 'I': - if not chunk: - chunk = { - 'type': label[2:], - 'start': offsets[0], - 'end': offsets[1] - } - if label[0] in 'E': - if not chunk: - chunk = { - 'type': label[2:], - 'start': offsets[0], - 'end': offsets[1] - } - if label[0] in 'IES': - if chunk: - chunk['end'] = 
offsets[1] - - if label[0] in 'ES': - if chunk: - chunk['span'] = text[chunk['start']:chunk['end']] - chunks.append(chunk) - chunk = {} - - if chunk: - chunk['span'] = text[chunk['start']:chunk['end']] - chunks.append(chunk) - - # for cws outputs - if len(chunks) > 0 and chunks[0]['type'] == 'cws': - spans = [ - chunk['span'] for chunk in chunks if chunk['span'].strip() - ] - seg_result = ' '.join(spans) - outputs = {OutputKeys.OUTPUT: seg_result} - - # for ner outputs - else: - outputs = {OutputKeys.OUTPUT: chunks} - return outputs + self.id2label = kwargs.get('id2label') + if self.id2label is None and hasattr(self.preprocessor, 'id2label'): + self.id2label = self.preprocessor.id2label @PIPELINES.register_module( diff --git a/modelscope/pipelines/nlp/text_classification_pipeline.py b/modelscope/pipelines/nlp/text_classification_pipeline.py index 771660a5..15a318b4 100644 --- a/modelscope/pipelines/nlp/text_classification_pipeline.py +++ b/modelscope/pipelines/nlp/text_classification_pipeline.py @@ -117,7 +117,12 @@ class TextClassificationPipeline(Pipeline): probs = np.take_along_axis(probs, top_indices, axis=-1).tolist() def map_to_label(id): - return self.id2label[id] + if id in self.id2label: + return self.id2label[id] + elif str(id) in self.id2label: + return self.id2label[str(id)] + else: + raise Exception('id not found in id2label') v_func = np.vectorize(map_to_label) return { diff --git a/modelscope/pipelines/nlp/token_classification_pipeline.py b/modelscope/pipelines/nlp/token_classification_pipeline.py index d2168b8a..90cf6116 100644 --- a/modelscope/pipelines/nlp/token_classification_pipeline.py +++ b/modelscope/pipelines/nlp/token_classification_pipeline.py @@ -64,6 +64,31 @@ class TokenClassificationPipeline(Pipeline): **postprocess_params) -> Dict[str, str]: """process the prediction results + Args: + inputs (Dict[str, Any]): should be tensors from model + + Returns: + Dict[str, str]: the prediction results + """ + chunks = self._chunk_process(inputs, **postprocess_params) + + # for cws outputs + if len(chunks) > 0 and chunks[0]['type'].lower() == 'cws': + spans = [ + chunk['span'] for chunk in chunks if chunk['span'].strip() + ] + seg_result = [span for span in spans] + outputs = {OutputKeys.OUTPUT: seg_result} + + # for ner outputs + else: + outputs = {OutputKeys.OUTPUT: chunks} + return outputs + + def _chunk_process(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results and output as chunks + Args: inputs (Dict[str, Any]): should be tensors from model @@ -71,7 +96,7 @@ class TokenClassificationPipeline(Pipeline): Dict[str, str]: the prediction results """ text = inputs['text'] - if not hasattr(inputs, 'predictions'): + if OutputKeys.PREDICTIONS not in inputs: logits = inputs[OutputKeys.LOGITS] predictions = torch.argmax(logits[0], dim=-1) else: @@ -123,15 +148,4 @@ class TokenClassificationPipeline(Pipeline): chunk['span'] = text[chunk['start']:chunk['end']] chunks.append(chunk) - # for cws outputs - if len(chunks) > 0 and chunks[0]['type'] == 'cws': - spans = [ - chunk['span'] for chunk in chunks if chunk['span'].strip() - ] - seg_result = ' '.join(spans) - outputs = {OutputKeys.OUTPUT: seg_result} - - # for ner outputs - else: - outputs = {OutputKeys.OUTPUT: chunks} - return outputs + return chunks diff --git a/modelscope/pipelines/nlp/word_segmentation_pipeline.py b/modelscope/pipelines/nlp/word_segmentation_pipeline.py index 3d6f8a4a..ac1c4789 100644 --- a/modelscope/pipelines/nlp/word_segmentation_pipeline.py +++ 
b/modelscope/pipelines/nlp/word_segmentation_pipeline.py @@ -9,18 +9,20 @@ from modelscope.models import Model from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES +from modelscope.pipelines.nlp import TokenClassificationPipeline from modelscope.preprocessors import (Preprocessor, - TokenClassificationPreprocessor) + TokenClassificationPreprocessor, + WordSegmentationPreprocessorThai) from modelscope.utils.constant import Tasks from modelscope.utils.tensor_utils import (torch_nested_detach, torch_nested_numpify) -__all__ = ['WordSegmentationPipeline'] +__all__ = ['WordSegmentationPipeline', 'WordSegmentationThaiPipeline'] @PIPELINES.register_module( Tasks.word_segmentation, module_name=Pipelines.word_segmentation) -class WordSegmentationPipeline(Pipeline): +class WordSegmentationPipeline(TokenClassificationPipeline): def __init__(self, model: Union[Model, str], @@ -58,89 +60,38 @@ class WordSegmentationPipeline(Pipeline): self.id2label = kwargs.get('id2label') if self.id2label is None and hasattr(self.preprocessor, 'id2label'): self.id2label = self.preprocessor.id2label - assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ - 'as a parameter or make sure the preprocessor has the attribute.' - def forward(self, inputs: Dict[str, Any], - **forward_params) -> Dict[str, Any]: - text = inputs.pop(OutputKeys.TEXT) - with torch.no_grad(): - return { - **self.model(**inputs, **forward_params), OutputKeys.TEXT: text - } + +@PIPELINES.register_module( + Tasks.word_segmentation, + module_name=Pipelines.multilingual_word_segmentation) +class MultilingualWordSegmentationPipeline(WordSegmentationPipeline): def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]: - """process the prediction results + chunks = self._chunk_process(inputs, **postprocess_params) + word_segments = [entity['span'] for entity in chunks] + return {OutputKeys.OUTPUT: word_segments} - Args: - inputs (Dict[str, Any]): should be tensors from model - Returns: - Dict[str, str]: the prediction results - """ - text = inputs['text'] - if not hasattr(inputs, 'predictions'): - logits = inputs[OutputKeys.LOGITS] - predictions = torch.argmax(logits[0], dim=-1) - else: - predictions = inputs[OutputKeys.PREDICTIONS].squeeze( - 0).cpu().numpy() - predictions = torch_nested_numpify(torch_nested_detach(predictions)) - offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] - - labels = [self.id2label[x] for x in predictions] - if len(labels) > len(offset_mapping): - labels = labels[1:-1] - chunks = [] - chunk = {} - for label, offsets in zip(labels, offset_mapping): - if label[0] in 'BS': - if chunk: - chunk['span'] = text[chunk['start']:chunk['end']] - chunks.append(chunk) - chunk = { - 'type': label[2:], - 'start': offsets[0], - 'end': offsets[1] - } - if label[0] in 'I': - if not chunk: - chunk = { - 'type': label[2:], - 'start': offsets[0], - 'end': offsets[1] - } - if label[0] in 'E': - if not chunk: - chunk = { - 'type': label[2:], - 'start': offsets[0], - 'end': offsets[1] - } - if label[0] in 'IES': - if chunk: - chunk['end'] = offsets[1] - - if label[0] in 'ES': - if chunk: - chunk['span'] = text[chunk['start']:chunk['end']] - chunks.append(chunk) - chunk = {} - - if chunk: - chunk['span'] = text[chunk['start']:chunk['end']] - chunks.append(chunk) - - # for cws outputs - if len(chunks) > 0 and chunks[0]['type'] == 'cws': - spans = [ - chunk['span'] for 
chunk in chunks if chunk['span'].strip() - ] - seg_result = ' '.join(spans) - outputs = {OutputKeys.OUTPUT: seg_result} - - # for ner outputs - else: - outputs = {OutputKeys.OUTPUT: chunks} - return outputs +@PIPELINES.register_module( + Tasks.word_segmentation, module_name=Pipelines.word_segmentation_thai) +class WordSegmentationThaiPipeline(MultilingualWordSegmentationPipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = WordSegmentationPreprocessorThai( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 512)) + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + chunks = self._chunk_process(inputs, **postprocess_params) + word_segments = [entity['span'].replace(' ', '') for entity in chunks] + return {OutputKeys.OUTPUT: word_segments} diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index 93cc20e2..87a6eaff 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -154,4 +154,6 @@ def parse_label_mapping(model_dir): elif hasattr(config, 'id2label'): id2label = config.id2label label2id = {label: id for id, label in id2label.items()} + if label2id is not None: + label2id = {label: int(id) for label, id in label2id.items()} return label2id diff --git a/tests/pipelines/test_addr_similarity.py b/tests/pipelines/test_addr_similarity.py new file mode 100644 index 00000000..57c47b09 --- /dev/null +++ b/tests/pipelines/test_addr_similarity.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import SbertForSequenceClassification +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TextClassificationPipeline +from modelscope.preprocessors import SequenceClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool +from modelscope.utils.test_utils import test_level + + +class AddrSimilarityTest(unittest.TestCase, DemoCompatibilityCheck): + + sentence1 = '阿里巴巴西溪园区' + sentence2 = '文一西路阿里巴巴' + model_id = 'damo/nlp_structbert_address-matching_chinese_base' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + preprocessor = SequenceClassificationPreprocessor(model.model_dir) + + pipeline_ins = pipeline( + task=Tasks.text_classification, + model=model, + preprocessor=preprocessor) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.text_classification, model=self.model_id) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_named_entity_recognition.py b/tests/pipelines/test_named_entity_recognition.py index 
0df44f5b..3317c604 100644 --- a/tests/pipelines/test_named_entity_recognition.py +++ b/tests/pipelines/test_named_entity_recognition.py @@ -23,9 +23,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): chinese_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-large-generic' tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' + addr_model_id = 'damo/nlp_structbert_address-parsing_chinese_base' sentence = '这与温岭市新河镇的一个神秘的传说有关。' sentence_en = 'pizza shovel' sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。' + addr = '浙江省杭州市余杭区文一西路969号亲橙里' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_tcrf_by_direct_model_download(self): @@ -71,6 +73,23 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): preprocessor=tokenizer) print(pipeline_ins(input=self.sentence)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_addrst_with_model_from_modelhub(self): + model = Model.from_pretrained( + 'damo/nlp_structbert_address-parsing_chinese_base') + tokenizer = TokenClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.addr)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_addrst_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.named_entity_recognition, model=self.addr_model_id) + print(pipeline_ins(input=self.addr)) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_lcrf_with_model_from_modelhub(self): model = Model.from_pretrained(self.lcrf_model_id)
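
Usage sketch (not part of the patch; the model id below is an assumption for illustration, not taken from this change): with the registrations above, word segmentation runs through the consolidated WordSegmentationPipeline and, per the updated TASK_OUTPUTS comment in outputs.py, returns the segments as a list of tokens rather than a single space-joined string.

    # Illustrative only -- the model id is assumed, substitute any hub model
    # registered for Tasks.word_segmentation.
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ws = pipeline(
        task=Tasks.word_segmentation,
        model='damo/nlp_structbert_word-segmentation_chinese-base')  # assumed id
    print(ws(input='今天天气不错，适合出去游玩'))
    # Expected shape after this change:
    # {'output': ['今天', '天气', '不错', '，', '适合', '出去', '游玩']}

A Thai model whose configuration selects Pipelines.word_segmentation_thai would be served the same way by the new WordSegmentationThaiPipeline, which additionally strips spaces from each span in postprocess.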