Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9901220
* add lstm-crf ner model
@@ -44,6 +44,7 @@ class Models(object):
     space_modeling = 'space-modeling'
     star = 'star'
     tcrf = 'transformer-crf'
+    lcrf = 'lstm-crf'
     bart = 'bart'
     gpt3 = 'gpt3'
     bert_for_ds = 'bert-for-document-segmentation'
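For context, the new `lcrf` constant is the string that ties a model's configured type to the class registered further down in this review. A minimal sketch of that register-then-resolve pattern (a toy re-implementation for illustration, not ModelScope's actual MODELS registry):

# Toy registry; names are hypothetical stand-ins for the real MODELS object.
MODEL_REGISTRY = {}

def register_module(task, module_name):
    def decorator(cls):
        MODEL_REGISTRY[(task, module_name)] = cls
        return cls
    return decorator

@register_module('named-entity-recognition', 'lstm-crf')
class DummyLSTMCRF:
    pass

# A builder can later resolve the configured string back to the class:
assert MODEL_REGISTRY[('named-entity-recognition', 'lstm-crf')] is DummyLSTMCRF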
@@ -11,7 +11,9 @@ if TYPE_CHECKING:
     from .csanmt_for_translation import CsanmtForTranslation
     from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM,
                                   BertForMaskedLM)
-    from .nncrf_for_named_entity_recognition import TransformerCRFForNamedEntityRecognition
+    from .nncrf_for_named_entity_recognition import (
+        TransformerCRFForNamedEntityRecognition,
+        LSTMCRFForNamedEntityRecognition)
     from .palm_v2 import PalmForTextGeneration
     from .token_classification import SbertForTokenClassification
     from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification
@@ -34,8 +36,10 @@ else:
         'bert_for_document_segmentation': ['BertForDocumentSegmentation'],
         'masked_language':
         ['StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM'],
-        'nncrf_for_named_entity_recognition':
-        ['TransformerCRFForNamedEntityRecognition'],
+        'nncrf_for_named_entity_recognition': [
+            'TransformerCRFForNamedEntityRecognition',
+            'LSTMCRFForNamedEntityRecognition'
+        ],
         'palm_v2': ['PalmForTextGeneration'],
         'token_classification': ['SbertForTokenClassification'],
         'sequence_classification':
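This hunk follows the package's lazy-import layout: the TYPE_CHECKING branch gives static analyzers real imports, while `_import_structure` defers the heavy module load until a symbol is first accessed. A minimal sketch of the idea, assuming a module-level `__getattr__` in a package's `__init__.py` rather than ModelScope's actual LazyImportModule:

import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static analyzers see the real symbols without importing torch etc.
    from .nncrf_for_named_entity_recognition import (
        TransformerCRFForNamedEntityRecognition,
        LSTMCRFForNamedEntityRecognition)
else:
    _import_structure = {
        'nncrf_for_named_entity_recognition': [
            'TransformerCRFForNamedEntityRecognition',
            'LSTMCRFForNamedEntityRecognition'
        ],
    }

    def __getattr__(name):
        # Import the submodule only when one of its symbols is requested.
        for submodule, symbols in _import_structure.items():
            if name in symbols:
                module = importlib.import_module(f'.{submodule}', __name__)
                return getattr(module, name)
        raise AttributeError(name)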
@@ -10,27 +10,25 @@ from modelscope.models import TorchModel
 from modelscope.models.builder import MODELS
 from modelscope.utils.constant import ModelFile, Tasks

-__all__ = ['TransformerCRFForNamedEntityRecognition']
+__all__ = [
+    'TransformerCRFForNamedEntityRecognition',
+    'LSTMCRFForNamedEntityRecognition'
+]

-@MODELS.register_module(
-    Tasks.named_entity_recognition, module_name=Models.tcrf)
-class TransformerCRFForNamedEntityRecognition(TorchModel):
-    """This model wraps the TransformerCRF model to register into model sets.
-    """
+class SequenceLabelingForNamedEntityRecognition(TorchModel):

     def __init__(self, model_dir, *args, **kwargs):
         super().__init__(model_dir, *args, **kwargs)
-        self.config = AutoConfig.from_pretrained(model_dir)
-        num_labels = self.config.num_labels
-        self.model = TransformerCRF(model_dir, num_labels)
+        self.model = self.init_model(model_dir, *args, **kwargs)
         model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE)
         self.model.load_state_dict(
             torch.load(model_ckpt, map_location=torch.device('cpu')))

+    def init_model(self, model_dir, *args, **kwargs):
+        raise NotImplementedError
+
     def train(self):
         return self.model.train()
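The refactor above is a template method: the new `SequenceLabelingForNamedEntityRecognition` base class owns the shared checkpoint-loading flow, and each concrete model overrides only `init_model`. A runnable toy of the shape (checkpoint loading elided; class names hypothetical):

import torch.nn as nn

class SequenceLabelingBase:

    def __init__(self, model_dir):
        # Shared flow: build the network, then (in the real class) load
        # ModelFile.TORCH_MODEL_BIN_FILE from model_dir into it.
        self.model = self.init_model(model_dir)

    def init_model(self, model_dir):
        raise NotImplementedError

class ToyLSTM(SequenceLabelingBase):

    def init_model(self, model_dir):
        return nn.LSTM(8, 16, batch_first=True, bidirectional=True)

print(type(ToyLSTM('/unused/dir').model).__name__)  # LSTM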
@@ -64,6 +62,39 @@ class TransformerCRFForNamedEntityRecognition(TorchModel):
         return output

+@MODELS.register_module(
+    Tasks.named_entity_recognition, module_name=Models.tcrf)
+class TransformerCRFForNamedEntityRecognition(
+        SequenceLabelingForNamedEntityRecognition):
+    """This model wraps the TransformerCRF model to register into model sets.
+    """
+
+    def init_model(self, model_dir, *args, **kwargs):
+        self.config = AutoConfig.from_pretrained(model_dir)
+        num_labels = self.config.num_labels
+        model = TransformerCRF(model_dir, num_labels)
+        return model
+
+
+@MODELS.register_module(
+    Tasks.named_entity_recognition, module_name=Models.lcrf)
+class LSTMCRFForNamedEntityRecognition(
+        SequenceLabelingForNamedEntityRecognition):
+    """This model wraps the LSTMCRF model to register into model sets.
+    """
+
+    def init_model(self, model_dir, *args, **kwargs):
+        self.config = AutoConfig.from_pretrained(model_dir)
+        vocab_size = self.config.vocab_size
+        embed_width = self.config.embed_width
+        num_labels = self.config.num_labels
+        lstm_hidden_size = self.config.lstm_hidden_size
+        model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size)
+        return model

 class TransformerCRF(nn.Module):
     """A transformer based model to NER tasks.
@@ -105,6 +136,56 @@ class TransformerCRF(nn.Module):
         return outputs

+class LSTMCRF(nn.Module):
+    """
+    A standard bilstm-crf model for fast prediction.
+    """
+
+    def __init__(self,
+                 vocab_size,
+                 embed_width,
+                 num_labels,
+                 lstm_hidden_size=100,
+                 **kwargs):
+        super(LSTMCRF, self).__init__()
+        self.embedding = Embedding(vocab_size, embed_width)
+        self.lstm = nn.LSTM(
+            embed_width,
+            lstm_hidden_size,
+            num_layers=1,
+            bidirectional=True,
+            batch_first=True)
+        self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels)
+        self.crf = CRF(num_labels, batch_first=True)
+
+    def forward(self, inputs):
+        embedding = self.embedding(inputs['input_ids'])
+        lstm_output, _ = self.lstm(embedding)
+        logits = self.ffn(lstm_output)
+
+        if 'label_mask' in inputs:
+            mask = inputs['label_mask']
+            masked_lengths = mask.sum(-1).long()
+            masked_logits = torch.zeros_like(logits)
+            for i in range(len(mask)):
+                masked_logits[
+                    i, :masked_lengths[i], :] = logits[i].masked_select(
+                        mask[i].unsqueeze(-1)).view(masked_lengths[i], -1)
+            logits = masked_logits
+
+        outputs = {'logits': logits}
+        return outputs
+
+    def decode(self, inputs):
+        seq_lens = inputs['label_mask'].sum(-1).long()
+        mask = torch.arange(
+            inputs['label_mask'].shape[1],
+            device=seq_lens.device)[None, :] < seq_lens[:, None]
+        predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0)
+        outputs = {'predicts': predicts}
+        return outputs

 class CRF(nn.Module):
     """Conditional random field.
     This module implements a conditional random field [LMP01]_. The forward computation
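The `label_mask` branch in `LSTMCRF.forward` left-packs the logits: positions the preprocessor masked out (special tokens, padding and the like) are dropped and the kept rows are shifted to the front, so the CRF sees a dense prefix per sequence. A standalone check of that compaction, mirroring the loop above:

import torch

logits = torch.randn(2, 5, 3)                      # (batch, seq_len, num_labels)
mask = torch.tensor([[1, 1, 0, 1, 0],
                     [1, 0, 1, 0, 0]], dtype=torch.bool)
masked_lengths = mask.sum(-1).long()
masked_logits = torch.zeros_like(logits)
for i in range(len(mask)):
    masked_logits[i, :masked_lengths[i], :] = logits[i].masked_select(
        mask[i].unsqueeze(-1)).view(masked_lengths[i], -1)

# Row 0 now holds the logits of positions 0, 1 and 3, left-packed;
# everything past masked_lengths[i] stays zero and is masked in the CRF.
assert masked_logits[0, :3].equal(logits[0, [0, 1, 3]])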
@@ -547,3 +628,14 @@ class CRF(nn.Module):
         return torch.where(mask.unsqueeze(-1), best_tags_arr,
                            oor_tag).permute(2, 1, 0)

+
+class Embedding(nn.Module):
+
+    def __init__(self, vocab_size, embed_width):
+        super(Embedding, self).__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_width)
+
+    def forward(self, input_ids):
+        return self.embedding(input_ids)
@@ -84,6 +84,9 @@ class NamedEntityRecognitionPipeline(Pipeline):
                 entity['span'] = text[entity['start']:entity['end']]
                 entities.append(entity)
                 entity = {}
+        if entity:
+            entity['span'] = text[entity['start']:entity['end']]
+            entities.append(entity)
         outputs = {OutputKeys.OUTPUT: entities}
         return outputs
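The three added lines fix a boundary bug: an entity whose last token is also the last token of the sentence never reaches the branch that closes entities inside the loop, so it was silently dropped. A minimal sketch of the grouping loop with the fix, assuming simple BIO tags (`group_entities` is a hypothetical helper mirroring the loop's shape, not the pipeline's exact code):

def group_entities(text, tags):
    entities, entity = [], {}

    def flush():
        entity['span'] = text[entity['start']:entity['end']]
        entities.append(entity)

    for i, tag in enumerate(tags):
        if tag.startswith('B'):
            if entity:
                flush()
            entity = {'start': i, 'end': i + 1, 'type': tag[2:]}
        elif tag.startswith('I') and entity:
            entity['end'] = i + 1
        else:
            if entity:
                flush()
            entity = {}
    if entity:  # the added fix: close an entity still open at end of text
        flush()
    return entities

tags = ['O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'B-LOC', 'I-LOC', 'I-LOC']
print(group_entities('这与温岭市新河镇', tags))
# [{'start': 2, 'end': 5, 'type': 'LOC', 'span': '温岭市'},
#  {'start': 5, 'end': 8, 'type': 'LOC', 'span': '新河镇'}]
# Without the final `if entity:` block, '新河镇' would be missing.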
@@ -5,8 +5,7 @@ import uuid
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import numpy as np
 import torch
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, BertTokenizerFast

 from modelscope.metainfo import Models, Preprocessors
 from modelscope.outputs import OutputKeys
@@ -539,8 +538,13 @@ class NERPreprocessor(Preprocessor):
         self.model_dir: str = model_dir
         self.sequence_length = kwargs.pop('sequence_length', 512)
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_dir, use_fast=True)
+        self.is_transformer_based_model = 'lstm' not in model_dir
+        if self.is_transformer_based_model:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_dir, use_fast=True)
+        else:
+            self.tokenizer = BertTokenizerFast.from_pretrained(
+                model_dir, use_fast=True)
         self.is_split_into_words = self.tokenizer.init_kwargs.get(
             'is_split_into_words', False)
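The branch forces `BertTokenizerFast` for LSTM model directories, presumably because those directories ship only a BERT-style vocab.txt without the config `AutoTokenizer` needs for dispatch (an assumption; the diff does not say). Either way, a BERT fast tokenizer yields per-character tokens and offsets for Chinese, which is what the char-level LSTM expects:

from transformers import BertTokenizerFast

# 'bert-base-chinese' stands in for the model_dir vocab; illustrative only.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
encoding = tokenizer('温岭市新河镇', return_offsets_mapping=True)
print(encoding.tokens())
# ['[CLS]', '温', '岭', '市', '新', '河', '镇', '[SEP]']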
@@ -604,6 +608,11 @@ class NERPreprocessor(Preprocessor):
             else:
                 label_mask.append(1)
                 offset_mapping.append(encodings['offset_mapping'][i])
+
+        if not self.is_transformer_based_model:
+            input_ids = input_ids[1:-1]
+            attention_mask = attention_mask[1:-1]
+            label_mask = label_mask[1:-1]
         return {
             'text': text,
             'input_ids': input_ids,
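The slice drops the `[CLS]`/`[SEP]` positions the BERT tokenizer always adds, since the plain embedding-plus-LSTM model was not trained with them. A quick check of the effect (101 and 102 are the standard BERT special-token ids; the middle ids are illustrative):

input_ids = [101, 3946, 2279, 2356, 102]  # [CLS] 温 岭 市 [SEP]
attention_mask = [1, 1, 1, 1, 1]
label_mask = [0, 1, 1, 1, 0]

print(input_ids[1:-1], attention_mask[1:-1], label_mask[1:-1])
# [3946, 2279, 2356] [1, 1, 1] [1, 1, 1]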
@@ -3,7 +3,8 @@ import unittest
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
-from modelscope.models.nlp import TransformerCRFForNamedEntityRecognition
+from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition,
+                                   TransformerCRFForNamedEntityRecognition)
 from modelscope.pipelines import pipeline
 from modelscope.pipelines.nlp import NamedEntityRecognitionPipeline
 from modelscope.preprocessors import NERPreprocessor
@@ -12,12 +13,13 @@ from modelscope.utils.test_utils import test_level
 class NamedEntityRecognitionTest(unittest.TestCase):
-    model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
+    tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
+    lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
     sentence = '这与温岭市新河镇的一个神秘的传说有关。'

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run_by_direct_model_download(self):
-        cache_path = snapshot_download(self.model_id)
+    def test_run_tcrf_by_direct_model_download(self):
+        cache_path = snapshot_download(self.tcrf_model_id)
         tokenizer = NERPreprocessor(cache_path)
         model = TransformerCRFForNamedEntityRecognition(
             cache_path, tokenizer=tokenizer)
@@ -32,9 +34,36 @@ class NamedEntityRecognitionTest(unittest.TestCase):
         print()
         print(f'pipeline2: {pipeline2(input=self.sentence)}')

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_by_direct_model_download(self):
+        cache_path = snapshot_download(self.lcrf_model_id)
+        tokenizer = NERPreprocessor(cache_path)
+        model = LSTMCRFForNamedEntityRecognition(
+            cache_path, tokenizer=tokenizer)
+        pipeline1 = NamedEntityRecognitionPipeline(
+            model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.named_entity_recognition,
+            model=model,
+            preprocessor=tokenizer)
+        print(f'sentence: {self.sentence}\n'
+              f'pipeline1:{pipeline1(input=self.sentence)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.sentence)}')
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained(self.model_id)
+    def test_run_tcrf_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.tcrf_model_id)
         tokenizer = NERPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.named_entity_recognition,
             model=model,
             preprocessor=tokenizer)
         print(pipeline_ins(input=self.sentence))

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.lcrf_model_id)
+        tokenizer = NERPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition,
@@ -43,9 +72,15 @@ class NamedEntityRecognitionTest(unittest.TestCase):
         print(pipeline_ins(input=self.sentence))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
-    def test_run_with_model_name(self):
+    def test_run_tcrf_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.named_entity_recognition, model=self.tcrf_model_id)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_lcrf_with_model_name(self):
         pipeline_ins = pipeline(
-            task=Tasks.named_entity_recognition, model=self.model_id)
+            task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
         print(pipeline_ins(input=self.sentence))

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')