diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index f1179be8..4bb0857b 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -44,6 +44,7 @@ class Models(object):
     space_modeling = 'space-modeling'
     star = 'star'
     tcrf = 'transformer-crf'
+    lcrf = 'lstm-crf'
     bart = 'bart'
     gpt3 = 'gpt3'
     bert_for_ds = 'bert-for-document-segmentation'
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index 8bf06c1d..90a37cea 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -11,7 +11,9 @@ if TYPE_CHECKING:
     from .csanmt_for_translation import CsanmtForTranslation
     from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM,
                                   BertForMaskedLM)
-    from .nncrf_for_named_entity_recognition import TransformerCRFForNamedEntityRecognition
+    from .nncrf_for_named_entity_recognition import (
+        TransformerCRFForNamedEntityRecognition,
+        LSTMCRFForNamedEntityRecognition)
     from .palm_v2 import PalmForTextGeneration
     from .token_classification import SbertForTokenClassification
     from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification
@@ -34,8 +36,10 @@ else:
         'bert_for_document_segmentation': ['BertForDocumentSegmentation'],
         'masked_language':
         ['StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM'],
-        'nncrf_for_named_entity_recognition':
-        ['TransformerCRFForNamedEntityRecognition'],
+        'nncrf_for_named_entity_recognition': [
+            'TransformerCRFForNamedEntityRecognition',
+            'LSTMCRFForNamedEntityRecognition'
+        ],
         'palm_v2': ['PalmForTextGeneration'],
         'token_classification': ['SbertForTokenClassification'],
         'sequence_classification':
diff --git a/modelscope/models/nlp/nncrf_for_named_entity_recognition.py b/modelscope/models/nlp/nncrf_for_named_entity_recognition.py
index 2015997f..37216510 100644
--- a/modelscope/models/nlp/nncrf_for_named_entity_recognition.py
+++ b/modelscope/models/nlp/nncrf_for_named_entity_recognition.py
@@ -10,27 +10,25 @@
 from modelscope.models import TorchModel
 from modelscope.models.builder import MODELS
 from modelscope.utils.constant import ModelFile, Tasks
 
-__all__ = ['TransformerCRFForNamedEntityRecognition']
+__all__ = [
+    'TransformerCRFForNamedEntityRecognition',
+    'LSTMCRFForNamedEntityRecognition'
+]
 
 
-@MODELS.register_module(
-    Tasks.named_entity_recognition, module_name=Models.tcrf)
-class TransformerCRFForNamedEntityRecognition(TorchModel):
-    """This model wraps the TransformerCRF model to register into model sets.
-    """
+class SequenceLabelingForNamedEntityRecognition(TorchModel):
 
     def __init__(self, model_dir, *args, **kwargs):
         super().__init__(model_dir, *args, **kwargs)
-
-        self.config = AutoConfig.from_pretrained(model_dir)
-        num_labels = self.config.num_labels
-
-        self.model = TransformerCRF(model_dir, num_labels)
+        self.model = self.init_model(model_dir, *args, **kwargs)
         model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
         self.model.load_state_dict(
             torch.load(model_ckpt, map_location=torch.device('cpu')))
 
+    def init_model(self, model_dir, *args, **kwargs):
+        raise NotImplementedError
+
     def train(self):
         return self.model.train()
@@ -64,6 +62,39 @@ class TransformerCRFForNamedEntityRecognition(TorchModel):
         return output
 
 
+@MODELS.register_module(
+    Tasks.named_entity_recognition, module_name=Models.tcrf)
+class TransformerCRFForNamedEntityRecognition(
+        SequenceLabelingForNamedEntityRecognition):
+    """This model wraps the TransformerCRF model to register it into the model set.
+ """ + + def init_model(self, model_dir, *args, **kwargs): + self.config = AutoConfig.from_pretrained(model_dir) + num_labels = self.config.num_labels + + model = TransformerCRF(model_dir, num_labels) + return model + + +@MODELS.register_module( + Tasks.named_entity_recognition, module_name=Models.lcrf) +class LSTMCRFForNamedEntityRecognition( + SequenceLabelingForNamedEntityRecognition): + """This model wraps the LSTMCRF model to register into model sets. + """ + + def init_model(self, model_dir, *args, **kwargs): + self.config = AutoConfig.from_pretrained(model_dir) + vocab_size = self.config.vocab_size + embed_width = self.config.embed_width + num_labels = self.config.num_labels + lstm_hidden_size = self.config.lstm_hidden_size + + model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size) + return model + + class TransformerCRF(nn.Module): """A transformer based model to NER tasks. @@ -105,6 +136,56 @@ class TransformerCRF(nn.Module): return outputs +class LSTMCRF(nn.Module): + """ + A standard bilstm-crf model for fast prediction. + """ + + def __init__(self, + vocab_size, + embed_width, + num_labels, + lstm_hidden_size=100, + **kwargs): + super(LSTMCRF, self).__init__() + self.embedding = Embedding(vocab_size, embed_width) + self.lstm = nn.LSTM( + embed_width, + lstm_hidden_size, + num_layers=1, + bidirectional=True, + batch_first=True) + self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels) + self.crf = CRF(num_labels, batch_first=True) + + def forward(self, inputs): + embedding = self.embedding(inputs['input_ids']) + lstm_output, _ = self.lstm(embedding) + logits = self.ffn(lstm_output) + + if 'label_mask' in inputs: + mask = inputs['label_mask'] + masked_lengths = mask.sum(-1).long() + masked_logits = torch.zeros_like(logits) + for i in range(len(mask)): + masked_logits[ + i, :masked_lengths[i], :] = logits[i].masked_select( + mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) + logits = masked_logits + + outputs = {'logits': logits} + return outputs + + def decode(self, inputs): + seq_lens = inputs['label_mask'].sum(-1).long() + mask = torch.arange( + inputs['label_mask'].shape[1], + device=seq_lens.device)[None, :] < seq_lens[:, None] + predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) + outputs = {'predicts': predicts} + return outputs + + class CRF(nn.Module): """Conditional random field. This module implements a conditional random field [LMP01]_. 
@@ -547,3 +628,14 @@ class CRF(nn.Module):
 
         return torch.where(mask.unsqueeze(-1), best_tags_arr,
                            oor_tag).permute(2, 1, 0)
+
+
+class Embedding(nn.Module):
+
+    def __init__(self, vocab_size, embed_width):
+        super(Embedding, self).__init__()
+
+        self.embedding = nn.Embedding(vocab_size, embed_width)
+
+    def forward(self, input_ids):
+        return self.embedding(input_ids)
diff --git a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py
index b0b06c88..8fbdde86 100644
--- a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py
+++ b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py
@@ -84,6 +84,9 @@ class NamedEntityRecognitionPipeline(Pipeline):
                 entity['span'] = text[entity['start']:entity['end']]
                 entities.append(entity)
                 entity = {}
+        if entity:
+            entity['span'] = text[entity['start']:entity['end']]
+            entities.append(entity)
         outputs = {OutputKeys.OUTPUT: entities}
         return outputs
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 345d3711..578bbd49 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -5,8 +5,7 @@ import uuid
 from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import numpy as np
-import torch
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, BertTokenizerFast
 
 from modelscope.metainfo import Models, Preprocessors
 from modelscope.outputs import OutputKeys
@@ -539,8 +538,13 @@ class NERPreprocessor(Preprocessor):
 
         self.model_dir: str = model_dir
         self.sequence_length = kwargs.pop('sequence_length', 512)
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_dir, use_fast=True)
+        self.is_transformer_based_model = 'lstm' not in model_dir
+        if self.is_transformer_based_model:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_dir, use_fast=True)
+        else:
+            self.tokenizer = BertTokenizerFast.from_pretrained(
+                model_dir, use_fast=True)
         self.is_split_into_words = self.tokenizer.init_kwargs.get(
             'is_split_into_words', False)
@@ -604,6 +608,11 @@ class NERPreprocessor(Preprocessor):
             else:
                 label_mask.append(1)
                 offset_mapping.append(encodings['offset_mapping'][i])
+
+        if not self.is_transformer_based_model:
+            input_ids = input_ids[1:-1]
+            attention_mask = attention_mask[1:-1]
+            label_mask = label_mask[1:-1]
         return {
             'text': text,
             'input_ids': input_ids,
diff --git a/tests/pipelines/test_named_entity_recognition.py b/tests/pipelines/test_named_entity_recognition.py
index 5ba93f49..ad0fa228 100644
--- a/tests/pipelines/test_named_entity_recognition.py
+++ b/tests/pipelines/test_named_entity_recognition.py
@@ -3,7 +3,8 @@ import unittest
 
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
-from modelscope.models.nlp import TransformerCRFForNamedEntityRecognition
+from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition,
+                                   TransformerCRFForNamedEntityRecognition)
 from modelscope.pipelines import pipeline
 from modelscope.pipelines.nlp import NamedEntityRecognitionPipeline
 from modelscope.preprocessors import NERPreprocessor
@@ -12,12 +13,13 @@ from modelscope.utils.test_utils import test_level
 
 
 class NamedEntityRecognitionTest(unittest.TestCase):
-    model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
+    tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
+    lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
     sentence = '这与温岭市新河镇的一个神秘的传说有关。'
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run_by_direct_model_download(self):
-        cache_path = snapshot_download(self.model_id)
+    def test_run_tcrf_by_direct_model_download(self):
+        cache_path = snapshot_download(self.tcrf_model_id)
         tokenizer = NERPreprocessor(cache_path)
         model = TransformerCRFForNamedEntityRecognition(
             cache_path, tokenizer=tokenizer)
@@ -32,9 +34,36 @@ class NamedEntityRecognitionTest(unittest.TestCase):
         print()
         print(f'pipeline2: {pipeline2(input=self.sentence)}')
 
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_by_direct_model_download(self):
+        cache_path = snapshot_download(self.lcrf_model_id)
+        tokenizer = NERPreprocessor(cache_path)
+        model = LSTMCRFForNamedEntityRecognition(
+            cache_path, tokenizer=tokenizer)
+        pipeline1 = NamedEntityRecognitionPipeline(
+            model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.named_entity_recognition,
+            model=model,
+            preprocessor=tokenizer)
+        print(f'sentence: {self.sentence}\n'
+              f'pipeline1:{pipeline1(input=self.sentence)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.sentence)}')
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained(self.model_id)
+    def test_run_tcrf_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.tcrf_model_id)
+        tokenizer = NERPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition,
+            model=model,
+            preprocessor=tokenizer)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.lcrf_model_id)
         tokenizer = NERPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.named_entity_recognition,
@@ -43,9 +72,15 @@ class NamedEntityRecognitionTest(unittest.TestCase):
         print(pipeline_ins(input=self.sentence))
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
-    def test_run_with_model_name(self):
+    def test_run_tcrf_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition, model=self.tcrf_model_id)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_with_model_name(self):
         pipeline_ins = pipeline(
-            task=Tasks.named_entity_recognition, model=self.model_id)
+            task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
         print(pipeline_ins(input=self.sentence))
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
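
Usage sketch (not part of the patch): a minimal example of driving the new LSTM-CRF model through the standard pipeline API, mirroring test_run_lcrf_with_model_name above. It assumes the 'damo/nlp_lstm_named-entity-recognition_chinese-news' checkpoint referenced in the tests is published on the ModelScope hub.

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # The 'lstm-crf' key (Models.lcrf, added in metainfo.py) should route this
    # model id to LSTMCRFForNamedEntityRecognition; NERPreprocessor takes the
    # non-transformer path and strips the [CLS]/[SEP] positions before the CRF
    # decodes label sequences.
    ner_pipeline = pipeline(
        task=Tasks.named_entity_recognition,
        model='damo/nlp_lstm_named-entity-recognition_chinese-news')
    print(ner_pipeline(input='这与温岭市新河镇的一个神秘的传说有关。'))
    # Expected result shape, per the pipeline's postprocess:
    # {'output': [{'start': ..., 'end': ..., 'span': ...}, ...]}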