diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 23e64ffc..28f88cd5 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -18,6 +18,7 @@ class Models(object):
     veco = 'veco'
     translation = 'csanmt-translation'
     space = 'space'
+    tcrf = 'transformer-crf'
 
     # audio models
     sambert_hifigan = 'sambert-hifigan'
@@ -56,6 +57,7 @@ class Pipelines(object):
     # nlp tasks
     sentence_similarity = 'sentence-similarity'
     word_segmentation = 'word-segmentation'
+    named_entity_recognition = 'named-entity-recognition'
     text_generation = 'text-generation'
     sentiment_analysis = 'sentiment-analysis'
     sentiment_classification = 'sentiment-classification'
@@ -113,6 +115,7 @@ class Preprocessors(object):
    bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
    palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
    token_cls_tokenizer = 'token-cls-tokenizer'
+    ner_tokenizer = 'ner-tokenizer'
    nli_tokenizer = 'nli-tokenizer'
    sen_cls_tokenizer = 'sen-cls-tokenizer'
    dialog_intent_preprocessor = 'dialog-intent-preprocessor'
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index 4f69a189..5a6855e9 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -2,6 +2,7 @@ from modelscope.utils.error import TENSORFLOW_IMPORT_WARNING
 
 from .bert_for_sequence_classification import *  # noqa F403
 from .masked_language import *  # noqa F403
+from .nncrf_for_named_entity_recognition import *  # noqa F403
 from .palm_for_text_generation import *  # noqa F403
 from .sbert_for_nli import *  # noqa F403
 from .sbert_for_sentence_similarity import *  # noqa F403
+ """ + nn.init.uniform_(self.start_transitions, -0.1, 0.1) + nn.init.uniform_(self.end_transitions, -0.1, 0.1) + nn.init.uniform_(self.transitions, -0.1, 0.1) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(num_tags={self.num_tags})' + + def forward(self, + emissions: torch.Tensor, + tags: torch.LongTensor, + mask: Optional[torch.ByteTensor] = None, + reduction: str = 'mean') -> torch.Tensor: + """Compute the conditional log likelihood of a sequence of tags given emission scores. + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. + tags (`~torch.LongTensor`): Sequence of tags tensor of size + ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length)`` otherwise. + mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + reduction: Specifies the reduction to apply to the output: + ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. + ``sum``: the output will be summed over batches. ``mean``: the output will be + averaged over batches. ``token_mean``: the output will be averaged over tokens. + Returns: + `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if + reduction is ``none``, ``()`` otherwise. + """ + if reduction not in ('none', 'sum', 'mean', 'token_mean'): + raise ValueError(f'invalid reduction: {reduction}') + if mask is None: + mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device) + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, tags=tags, mask=mask) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + tags = tags.transpose(0, 1) + mask = mask.transpose(0, 1) + + # shape: (batch_size,) + numerator = self._compute_score(emissions, tags, mask) + # shape: (batch_size,) + denominator = self._compute_normalizer(emissions, mask) + # shape: (batch_size,) + llh = numerator - denominator + + if reduction == 'none': + return llh + if reduction == 'sum': + return llh.sum() + if reduction == 'mean': + return llh.mean() + return llh.sum() / mask.float().sum() + + def decode(self, + emissions: torch.Tensor, + mask: Optional[torch.ByteTensor] = None, + nbest: Optional[int] = None, + pad_tag: Optional[int] = None) -> List[List[List[int]]]: + """Find the most likely tag sequence using Viterbi algorithm. + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. + mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + nbest (`int`): Number of most probable paths for each sequence + pad_tag (`int`): Tag at padded positions. Often input varies in length and + the length will be padded to the maximum length in the batch. Tags at + the padded positions will be assigned with a padding tag, i.e. 
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index 9116b002..48a84237 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -127,6 +127,15 @@ TASK_OUTPUTS = {
     #     }
     Tasks.word_segmentation: [OutputKeys.OUTPUT],
 
+    # named entity recognition result for single sample
+    # {
+    #     "output": [
+    #         {"type": "LOC", "start": 2, "end": 5, "span": "温岭市"},
+    #         {"type": "LOC", "start": 5, "end": 8, "span": "新河镇"}
+    #     ]
+    # }
+    Tasks.named_entity_recognition: [OutputKeys.OUTPUT],
+
     # sentence similarity result for single sample
     # {
     #     "scores": 0.9
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 702700eb..cdde302a 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -19,6 +19,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.word_segmentation:
     (Pipelines.word_segmentation,
      'damo/nlp_structbert_word-segmentation_chinese-base'),
+    Tasks.named_entity_recognition:
+    (Pipelines.named_entity_recognition,
+     'damo/nlp_transformercrf_named-entity-recognition_chinese-base-news'),
     Tasks.sentence_similarity:
     (Pipelines.sentence_similarity,
      'damo/nlp_structbert_sentence-similarity_chinese-base'),
diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index c109d49d..59e93dee 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -22,6 +22,7 @@ try:
+    from .named_entity_recognition_pipeline import *  # noqa F403
     from .text_generation_pipeline import *  # noqa F403
     from .word_segmentation_pipeline import *  # noqa F403
     from .zero_shot_classification_pipeline import *  # noqa F403
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'torch'":
         pass
diff --git a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py
new file mode 100644
index 00000000..744bad2d
--- /dev/null
+++ b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py
@@ -0,0 +1,71 @@
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from ...metainfo import Pipelines
+from ...models import Model
+from ...models.nlp import TransformerCRFForNamedEntityRecognition
+from ...outputs import OutputKeys
+from ...preprocessors import NERPreprocessor
+from ...utils.constant import Tasks
+from ..base import Pipeline
+from ..builder import PIPELINES
+
+__all__ = ['NamedEntityRecognitionPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.named_entity_recognition,
+    module_name=Pipelines.named_entity_recognition)
+class NamedEntityRecognitionPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[TransformerCRFForNamedEntityRecognition, str],
+                 preprocessor: Optional[NERPreprocessor] = None,
+                 **kwargs):
+
+        model = model if isinstance(model,
+                                    TransformerCRFForNamedEntityRecognition
+                                    ) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = NERPreprocessor(model.model_dir)
+        model.eval()
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.tokenizer = preprocessor.tokenizer
+        self.config = model.config
+        assert len(self.config.id2label) > 0
+        self.id2label = self.config.id2label
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self, inputs: Dict[str, Any],
+                    **postprocess_params) -> Dict[str, Any]:
+        text = inputs['text']
+        offset_mapping = inputs['offset_mapping']
+        labels = [self.id2label[x] for x in inputs['predicts']]
+        entities = []
+        entity = {}
+        for label, offsets in zip(labels, offset_mapping):
+            # B (begin) or S (singleton) opens a new entity; flush any
+            # entity still left open
+            if label[0] in 'BS':
+                if entity:
+                    entity['span'] = text[entity['start']:entity['end']]
+                    entities.append(entity)
+                entity = {
+                    'type': label[2:],
+                    'start': offsets[0],
+                    'end': offsets[1]
+                }
+            # I (inside), E (end) or S extends the current entity to this token
+            if label[0] in 'IES':
+                if entity:
+                    entity['end'] = offsets[1]
+            # E or S closes the current entity
+            if label[0] in 'ES':
+                if entity:
+                    entity['span'] = text[entity['start']:entity['end']]
+                    entities.append(entity)
+                    entity = {}
+        outputs = {OutputKeys.OUTPUT: entities}
+
+        return outputs
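A worked example of the BIES span assembly in `postprocess` above, using hand-written labels and one-character offsets rather than real model output:

```python
text = '这与温岭市新河镇的一个神秘的传说有关。'
labels = ['O', 'O', 'B-LOC', 'I-LOC', 'E-LOC', 'B-LOC', 'I-LOC', 'E-LOC']
offsets = [(i, i + 1) for i in range(len(labels))]

entities, entity = [], {}
for label, (start, end) in zip(labels, offsets):
    if label[0] in 'BS':
        if entity:                   # flush an entity left open
            entity['span'] = text[entity['start']:entity['end']]
            entities.append(entity)
        entity = {'type': label[2:], 'start': start, 'end': end}
    if label[0] in 'IES':
        if entity:                   # extend the open span
            entity['end'] = end
    if label[0] in 'ES':
        if entity:                   # close and emit the span
            entity['span'] = text[entity['start']:entity['end']]
            entities.append(entity)
            entity = {}

assert entities == [
    {'type': 'LOC', 'start': 2, 'end': 5, 'span': '温岭市'},
    {'type': 'LOC', 'start': 5, 'end': 8, 'span': '新河镇'},
]
```

The result matches the schema documented for `Tasks.named_entity_recognition` in `modelscope/outputs.py`.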
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 910aed6a..59c69b8f 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -17,8 +17,8 @@ __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
     'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
     'NLIPreprocessor', 'SentimentClassificationPreprocessor',
-    'FillMaskPreprocessor', 'SentenceSimilarityPreprocessor',
-    'ZeroShotClassificationPreprocessor'
+    'SentenceSimilarityPreprocessor', 'FillMaskPreprocessor',
+    'ZeroShotClassificationPreprocessor', 'NERPreprocessor'
 ]
 
 
@@ -370,3 +370,68 @@ class ZeroShotClassificationPreprocessor(NLPPreprocessorBase):
             return_tensors='pt',
             truncation_strategy='only_first')
         return features
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.ner_tokenizer)
+class NERPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data with the tokenizer loaded from the `model_dir` path.
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+
+        self.model_dir: str = model_dir
+        self.sequence_length = kwargs.pop('sequence_length', 512)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_dir, use_fast=True)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        # preprocess the data for the model input
+        text = data
+        encodings = self.tokenizer(
+            text,
+            add_special_tokens=True,
+            padding=True,
+            truncation=True,
+            max_length=self.sequence_length,
+            return_offsets_mapping=True)
+        input_ids = encodings['input_ids']
+        attention_mask = encodings['attention_mask']
+        word_ids = encodings.word_ids()
+        label_mask = []
+        offset_mapping = []
+        for i in range(len(word_ids)):
+            # special tokens ([CLS], [SEP]) carry no label
+            if word_ids[i] is None:
+                label_mask.append(0)
+            # trailing sub-words are masked out; merge their character
+            # offsets into the first sub-word of the same word
+            elif word_ids[i] == word_ids[i - 1]:
+                label_mask.append(0)
+                offset_mapping[-1] = (offset_mapping[-1][0],
+                                      encodings['offset_mapping'][i][1])
+            else:
+                label_mask.append(1)
+                offset_mapping.append(encodings['offset_mapping'][i])
+
+        return {
+            'text': text,
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'label_mask': label_mask,
+            'offset_mapping': offset_mapping
+        }
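A sketch of the first-sub-word bookkeeping above with hand-written tokenizer output (the English text and the 'play'/'##ing' split are illustrative stand-ins, not output of the actual Chinese tokenizer):

```python
# Text 'playing chess', assuming a tokenizer splits 'playing' into 'play' + '##ing'
word_ids = [None, 0, 0, 1, None]          # [CLS] play ##ing chess [SEP]
raw_offsets = [(0, 0), (0, 4), (4, 7), (8, 13), (0, 0)]

label_mask, offset_mapping = [], []
for i, wid in enumerate(word_ids):
    if wid is None:                        # special token: no label
        label_mask.append(0)
    elif wid == word_ids[i - 1]:           # trailing sub-word: merge its span
        label_mask.append(0)
        offset_mapping[-1] = (offset_mapping[-1][0], raw_offsets[i][1])
    else:                                  # first sub-word: labelled position
        label_mask.append(1)
        offset_mapping.append(raw_offsets[i])

assert label_mask == [0, 1, 0, 1, 0]
assert offset_mapping == [(0, 7), (8, 13)]
```

One `label_mask` entry per wordpiece tells the model which emission rows to keep, while `offset_mapping` ends up with exactly one character span per labelled position, which is what the pipeline's `postprocess` zips against the predicted labels.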
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index cffe607c..68fdba38 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -34,6 +34,7 @@ class CVTasks(object):
 class NLPTasks(object):
     # nlp tasks
     word_segmentation = 'word-segmentation'
+    named_entity_recognition = 'named-entity-recognition'
     nli = 'nli'
     sentiment_classification = 'sentiment-classification'
     sentiment_analysis = 'sentiment-analysis'
diff --git a/tests/pipelines/test_named_entity_recognition.py b/tests/pipelines/test_named_entity_recognition.py
new file mode 100644
index 00000000..eb670501
--- /dev/null
+++ b/tests/pipelines/test_named_entity_recognition.py
@@ -0,0 +1,57 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import TransformerCRFForNamedEntityRecognition
+from modelscope.pipelines import NamedEntityRecognitionPipeline, pipeline
+from modelscope.preprocessors import NERPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class NamedEntityRecognitionTest(unittest.TestCase):
+    model_id = 'damo/nlp_transformercrf_named-entity-recognition_chinese-base-news'
+    sentence = '这与温岭市新河镇的一个神秘的传说有关。'
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = NERPreprocessor(cache_path)
+        model = TransformerCRFForNamedEntityRecognition(
+            cache_path, tokenizer=tokenizer)
+        pipeline1 = NamedEntityRecognitionPipeline(
+            model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.named_entity_recognition,
+            model=model,
+            preprocessor=tokenizer)
+        print(f'sentence: {self.sentence}\n'
+              f'pipeline1: {pipeline1(input=self.sentence)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.sentence)}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = NERPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition,
+            model=model,
+            preprocessor=tokenizer)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition, model=self.model_id)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.named_entity_recognition)
+        print(pipeline_ins(input=self.sentence))
+
+
+if __name__ == '__main__':
+    unittest.main()
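End-to-end usage mirrors `test_run_with_model_name` above; this downloads the registered default checkpoint from the hub, and the printed structure follows the schema registered in `modelscope/outputs.py` (the exact spans depend on the downloaded model):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ner = pipeline(
    task=Tasks.named_entity_recognition,
    model='damo/nlp_transformercrf_named-entity-recognition_chinese-base-news')
print(ner(input='这与温岭市新河镇的一个神秘的传说有关。'))
# e.g. {'output': [{'type': 'LOC', 'start': 2, 'end': 5, 'span': '温岭市'},
#                  {'type': 'LOC', 'start': 5, 'end': 8, 'span': '新河镇'}]}
```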