
[to #42322933] add lstm-crf ner model code

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9901220

    * add lstm-crf ner model code
master
xuanjie.wxb, yingda.chen · 3 years ago
parent commit e5c9ded870
6 changed files with 170 additions and 26 deletions
  1. modelscope/metainfo.py (+1, -0)
  2. modelscope/models/nlp/__init__.py (+7, -3)
  3. modelscope/models/nlp/nncrf_for_named_entity_recognition.py (+103, -11)
  4. modelscope/pipelines/nlp/named_entity_recognition_pipeline.py (+3, -0)
  5. modelscope/preprocessors/nlp.py (+13, -4)
  6. tests/pipelines/test_named_entity_recognition.py (+43, -8)
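For quick orientation, a minimal usage sketch of what this commit enables (assuming the lcrf model id used in the tests below is published on the ModelScope hub):

    # Minimal usage sketch (not part of the diff): run the new LSTM-CRF NER
    # model end to end. The model id is the one referenced in the tests.
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ner = pipeline(
        task=Tasks.named_entity_recognition,
        model='damo/nlp_lstm_named-entity-recognition_chinese-news')
    print(ner(input='这与温岭市新河镇的一个神秘的传说有关。'))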

modelscope/metainfo.py (+1, -0)

@@ -44,6 +44,7 @@ class Models(object):
     space_modeling = 'space-modeling'
     star = 'star'
     tcrf = 'transformer-crf'
+    lcrf = 'lstm-crf'
     bart = 'bart'
     gpt3 = 'gpt3'
     bert_for_ds = 'bert-for-document-segmentation'


modelscope/models/nlp/__init__.py (+7, -3)

@@ -11,7 +11,9 @@ if TYPE_CHECKING:
     from .csanmt_for_translation import CsanmtForTranslation
     from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM,
                                   BertForMaskedLM)
-    from .nncrf_for_named_entity_recognition import TransformerCRFForNamedEntityRecognition
+    from .nncrf_for_named_entity_recognition import (
+        TransformerCRFForNamedEntityRecognition,
+        LSTMCRFForNamedEntityRecognition)
     from .palm_v2 import PalmForTextGeneration
     from .token_classification import SbertForTokenClassification
     from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification
@@ -34,8 +36,10 @@ else:
         'bert_for_document_segmentation': ['BertForDocumentSegmentation'],
         'masked_language':
         ['StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM'],
-        'nncrf_for_named_entity_recognition':
-        ['TransformerCRFForNamedEntityRecognition'],
+        'nncrf_for_named_entity_recognition': [
+            'TransformerCRFForNamedEntityRecognition',
+            'LSTMCRFForNamedEntityRecognition'
+        ],
         'palm_v2': ['PalmForTextGeneration'],
         'token_classification': ['SbertForTokenClassification'],
         'sequence_classification':

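Both branches must stay in sync: the TYPE_CHECKING imports feed static analysis, while the `_import_structure` dict drives lazy loading at runtime. A simplified sketch of that pattern (a generic stand-in, not modelscope's actual LazyImportModule):

    # Simplified sketch of the lazy-import pattern behind `_import_structure`
    # (a generic stand-in, not modelscope's actual LazyImportModule).
    import importlib
    import types


    class LazyModule(types.ModuleType):

        def __init__(self, name, import_structure):
            super().__init__(name)
            # Map each exported symbol to the submodule that defines it.
            self._symbol_to_module = {
                symbol: submodule
                for submodule, symbols in import_structure.items()
                for symbol in symbols
            }

        def __getattr__(self, symbol):
            # Import the defining submodule only on first attribute access.
            if symbol not in self._symbol_to_module:
                raise AttributeError(symbol)
            submodule = importlib.import_module(
                f'{self.__name__}.{self._symbol_to_module[symbol]}')
            return getattr(submodule, symbol)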

modelscope/models/nlp/nncrf_for_named_entity_recognition.py (+103, -11)

@@ -10,27 +10,25 @@ from modelscope.models import TorchModel
 from modelscope.models.builder import MODELS
 from modelscope.utils.constant import ModelFile, Tasks
 
-__all__ = ['TransformerCRFForNamedEntityRecognition']
+__all__ = [
+    'TransformerCRFForNamedEntityRecognition',
+    'LSTMCRFForNamedEntityRecognition'
+]
 
 
-@MODELS.register_module(
-    Tasks.named_entity_recognition, module_name=Models.tcrf)
-class TransformerCRFForNamedEntityRecognition(TorchModel):
-    """This model wraps the TransformerCRF model to register into model sets.
-    """
+class SequenceLabelingForNamedEntityRecognition(TorchModel):
 
     def __init__(self, model_dir, *args, **kwargs):
         super().__init__(model_dir, *args, **kwargs)
 
-        self.config = AutoConfig.from_pretrained(model_dir)
-        num_labels = self.config.num_labels
-
-        self.model = TransformerCRF(model_dir, num_labels)
+        self.model = self.init_model(model_dir, *args, **kwargs)
 
         model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
         self.model.load_state_dict(
             torch.load(model_ckpt, map_location=torch.device('cpu')))
 
+    def init_model(self, model_dir, *args, **kwargs):
+        raise NotImplementedError
+
     def train(self):
         return self.model.train()
@@ -64,6 +62,39 @@ class TransformerCRFForNamedEntityRecognition(TorchModel):
         return output
 
 
+@MODELS.register_module(
+    Tasks.named_entity_recognition, module_name=Models.tcrf)
+class TransformerCRFForNamedEntityRecognition(
+        SequenceLabelingForNamedEntityRecognition):
+    """Wraps the TransformerCRF model so it can be registered into the
+    model set.
+    """
+
+    def init_model(self, model_dir, *args, **kwargs):
+        self.config = AutoConfig.from_pretrained(model_dir)
+        num_labels = self.config.num_labels
+
+        model = TransformerCRF(model_dir, num_labels)
+        return model
+
+
+@MODELS.register_module(
+    Tasks.named_entity_recognition, module_name=Models.lcrf)
+class LSTMCRFForNamedEntityRecognition(
+        SequenceLabelingForNamedEntityRecognition):
+    """Wraps the LSTMCRF model so it can be registered into the model set."""
+
+    def init_model(self, model_dir, *args, **kwargs):
+        self.config = AutoConfig.from_pretrained(model_dir)
+        vocab_size = self.config.vocab_size
+        embed_width = self.config.embed_width
+        num_labels = self.config.num_labels
+        lstm_hidden_size = self.config.lstm_hidden_size
+
+        model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size)
+        return model
+
+
 class TransformerCRF(nn.Module):
     """A transformer-based model for NER tasks.
@@ -105,6 +136,56 @@
         return outputs
 
 
+class LSTMCRF(nn.Module):
+    """A standard BiLSTM-CRF model for fast prediction."""
+
+    def __init__(self,
+                 vocab_size,
+                 embed_width,
+                 num_labels,
+                 lstm_hidden_size=100,
+                 **kwargs):
+        super(LSTMCRF, self).__init__()
+        self.embedding = Embedding(vocab_size, embed_width)
+        self.lstm = nn.LSTM(
+            embed_width,
+            lstm_hidden_size,
+            num_layers=1,
+            bidirectional=True,
+            batch_first=True)
+        self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels)
+        self.crf = CRF(num_labels, batch_first=True)
+
+    def forward(self, inputs):
+        embedding = self.embedding(inputs['input_ids'])
+        lstm_output, _ = self.lstm(embedding)
+        logits = self.ffn(lstm_output)
+
+        if 'label_mask' in inputs:
+            mask = inputs['label_mask']
+            masked_lengths = mask.sum(-1).long()
+            masked_logits = torch.zeros_like(logits)
+            for i in range(len(mask)):
+                masked_logits[
+                    i, :masked_lengths[i], :] = logits[i].masked_select(
+                        mask[i].unsqueeze(-1)).view(masked_lengths[i], -1)
+            logits = masked_logits
+
+        outputs = {'logits': logits}
+        return outputs
+
+    def decode(self, inputs):
+        seq_lens = inputs['label_mask'].sum(-1).long()
+        mask = torch.arange(
+            inputs['label_mask'].shape[1],
+            device=seq_lens.device)[None, :] < seq_lens[:, None]
+        predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0)
+        outputs = {'predicts': predicts}
+        return outputs
+
+
 class CRF(nn.Module):
     """Conditional random field.
     This module implements a conditional random field [LMP01]_. The forward computation
@@ -547,3 +628,14 @@
 
         return torch.where(mask.unsqueeze(-1), best_tags_arr,
                            oor_tag).permute(2, 1, 0)
+
+
+class Embedding(nn.Module):
+
+    def __init__(self, vocab_size, embed_width):
+        super(Embedding, self).__init__()
+
+        self.embedding = nn.Embedding(vocab_size, embed_width)
+
+    def forward(self, input_ids):
+        return self.embedding(input_ids)
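A toy shape check of the new module (illustrative only; sizes are made up). Note how `forward` compacts the logits of unmasked positions to the front of each sequence, so `decode` can build a contiguous length mask for the CRF:

    # Toy shape check for LSTMCRF (illustrative sizes, not from the commit).
    import torch
    from modelscope.models.nlp.nncrf_for_named_entity_recognition import LSTMCRF

    model = LSTMCRF(vocab_size=100, embed_width=8, num_labels=5,
                    lstm_hidden_size=16)
    inputs = {
        'input_ids': torch.randint(0, 100, (2, 6)),        # batch 2, length 6
        'label_mask': torch.ones(2, 6, dtype=torch.bool),  # keep every token
    }
    inputs.update(model(inputs))        # adds 'logits' of shape (2, 6, 5)
    print(inputs['logits'].shape)       # torch.Size([2, 6, 5])
    print(model.decode(inputs)['predicts'])  # per-token label ids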

modelscope/pipelines/nlp/named_entity_recognition_pipeline.py (+3, -0)

@@ -84,6 +84,9 @@ class NamedEntityRecognitionPipeline(Pipeline):
                 entity['span'] = text[entity['start']:entity['end']]
                 entities.append(entity)
                 entity = {}
+        if entity:
+            entity['span'] = text[entity['start']:entity['end']]
+            entities.append(entity)
         outputs = {OutputKeys.OUTPUT: entities}
 
         return outputs
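The three added lines fix a boundary bug: an entity still open when the token loop ends (i.e. the sentence ends mid-entity) was previously dropped, because the flush only ran when a following non-entity token closed it. A condensed illustration (simplified loop with a hypothetical chunk, not the pipeline's exact internals):

    # Condensed illustration of the fix (hypothetical chunk, simplified loop).
    text = '这与温岭市新河镇的一个神秘的传说有关。'
    entities, entity = [], {}
    for chunk in [{'type': 'LOC', 'start': 2, 'end': 5}]:
        entity = chunk
        # No trailing non-entity token follows, so the in-loop flush
        # never fires for this final chunk.
    if entity:  # the new post-loop flush
        entity['span'] = text[entity['start']:entity['end']]
        entities.append(entity)
    print(entities)  # [{'type': 'LOC', 'start': 2, 'end': 5, 'span': '温岭市'}]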

modelscope/preprocessors/nlp.py (+13, -4)

@@ -5,8 +5,7 @@ import uuid
 from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import numpy as np
-import torch
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, BertTokenizerFast
 
 from modelscope.metainfo import Models, Preprocessors
 from modelscope.outputs import OutputKeys
@@ -539,8 +538,13 @@ class NERPreprocessor(Preprocessor):
 
         self.model_dir: str = model_dir
         self.sequence_length = kwargs.pop('sequence_length', 512)
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_dir, use_fast=True)
+        self.is_transformer_based_model = 'lstm' not in model_dir
+        if self.is_transformer_based_model:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_dir, use_fast=True)
+        else:
+            self.tokenizer = BertTokenizerFast.from_pretrained(
+                model_dir, use_fast=True)
         self.is_split_into_words = self.tokenizer.init_kwargs.get(
             'is_split_into_words', False)

@@ -604,6 +608,11 @@
             else:
                 label_mask.append(1)
                 offset_mapping.append(encodings['offset_mapping'][i])
+
+        if not self.is_transformer_based_model:
+            input_ids = input_ids[1:-1]
+            attention_mask = attention_mask[1:-1]
+            label_mask = label_mask[1:-1]
         return {
             'text': text,
             'input_ids': input_ids,

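The new `[1:-1]` slice exists because `BertTokenizerFast` wraps every sequence in `[CLS]`/`[SEP]` special tokens, which the embedding-based LSTM model does not expect. Illustration (the vocab name is only an example of a BERT-style Chinese vocab):

    # Why the [1:-1] slice: fast BERT tokenizers add special tokens at both
    # ends ('bert-base-chinese' is illustrative).
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
    ids = tokenizer('温岭市')['input_ids']
    print(tokenizer.convert_ids_to_tokens(ids))
    # ['[CLS]', '温', '岭', '市', '[SEP]'] -> ids[1:-1] drops the specials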

tests/pipelines/test_named_entity_recognition.py (+43, -8)

@@ -3,7 +3,8 @@ import unittest
 
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
-from modelscope.models.nlp import TransformerCRFForNamedEntityRecognition
+from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition,
+                                   TransformerCRFForNamedEntityRecognition)
 from modelscope.pipelines import pipeline
 from modelscope.pipelines.nlp import NamedEntityRecognitionPipeline
 from modelscope.preprocessors import NERPreprocessor
@@ -12,12 +13,13 @@ from modelscope.utils.test_utils import test_level
 
 
 class NamedEntityRecognitionTest(unittest.TestCase):
-    model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
+    tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
+    lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
     sentence = '这与温岭市新河镇的一个神秘的传说有关。'
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run_by_direct_model_download(self):
-        cache_path = snapshot_download(self.model_id)
+    def test_run_tcrf_by_direct_model_download(self):
+        cache_path = snapshot_download(self.tcrf_model_id)
         tokenizer = NERPreprocessor(cache_path)
         model = TransformerCRFForNamedEntityRecognition(
             cache_path, tokenizer=tokenizer)
@@ -32,9 +34,36 @@ class NamedEntityRecognitionTest(unittest.TestCase):
         print()
         print(f'pipeline2: {pipeline2(input=self.sentence)}')
 
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_by_direct_model_download(self):
+        cache_path = snapshot_download(self.lcrf_model_id)
+        tokenizer = NERPreprocessor(cache_path)
+        model = LSTMCRFForNamedEntityRecognition(
+            cache_path, tokenizer=tokenizer)
+        pipeline1 = NamedEntityRecognitionPipeline(
+            model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.named_entity_recognition,
+            model=model,
+            preprocessor=tokenizer)
+        print(f'sentence: {self.sentence}\n'
+              f'pipeline1:{pipeline1(input=self.sentence)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.sentence)}')
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained(self.model_id)
+    def test_run_tcrf_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.tcrf_model_id)
         tokenizer = NERPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.named_entity_recognition,
             model=model,
             preprocessor=tokenizer)
         print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.lcrf_model_id)
+        tokenizer = NERPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition,
@@ -43,9 +72,15 @@ class NamedEntityRecognitionTest(unittest.TestCase):
         print(pipeline_ins(input=self.sentence))
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
-    def test_run_with_model_name(self):
+    def test_run_tcrf_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition, model=self.tcrf_model_id)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_with_model_name(self):
         pipeline_ins = pipeline(
-            task=Tasks.named_entity_recognition, model=self.model_id)
+            task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
         print(pipeline_ins(input=self.sentence))
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')

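To exercise only the new cases locally, one option (assuming, per modelscope's test utils, that test_level() reads the TEST_LEVEL environment variable; level 2 enables the lcrf tests):

    # Run only the new LSTM-CRF test cases. Assumes test_level() reads the
    # TEST_LEVEL environment variable; set it before importing the module,
    # since the skip decorators evaluate at import time.
    import os
    os.environ['TEST_LEVEL'] = '2'

    import unittest
    from tests.pipelines.test_named_entity_recognition import (
        NamedEntityRecognitionTest)

    suite = unittest.TestSuite(
        NamedEntityRecognitionTest(name) for name in (
            'test_run_lcrf_with_model_from_modelhub',
            'test_run_lcrf_with_model_name'))
    unittest.TextTestRunner().run(suite)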
