From eb3209a79a9dcb0fd6da6bb56b5e29c2db010e14 Mon Sep 17 00:00:00 2001
From: "zhangzhicheng.zzc"
Date: Fri, 17 Jun 2022 14:00:31 +0800
Subject: [PATCH] [to #42322933] Chinese word segmentation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Chinese word segmentation

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9051491

* add word segmentation
* Merge branch 'master' of http://gitlab.alibaba-inc.com/Ali-MaaS/MaaS-lib
* test with model hub
* merge with master
* update some description and test levels
* adding purge logic in test
* merge with master
* update variables definition
* generic word segmentation model as token classification model
* add output check
---
 modelscope/models/nlp/__init__.py             |  1 +
 .../models/nlp/token_classification_model.py  | 57 +++++++++++++++
 modelscope/pipelines/builder.py               |  3 +
 modelscope/pipelines/nlp/__init__.py          |  1 +
 .../nlp/word_segmentation_pipeline.py         | 71 +++++++++++++++++++
 modelscope/pipelines/outputs.py               | 13 ++++
 modelscope/preprocessors/nlp.py               | 50 ++++++++++++-
 modelscope/utils/constant.py                  |  1 +
 tests/pipelines/test_word_segmentation.py     | 62 ++++++++++++++++
 9 files changed, 258 insertions(+), 1 deletion(-)
 create mode 100644 modelscope/models/nlp/token_classification_model.py
 create mode 100644 modelscope/pipelines/nlp/word_segmentation_pipeline.py
 create mode 100644 tests/pipelines/test_word_segmentation.py

diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index be675c1b..aefcef4a 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -1,3 +1,4 @@
 from .sentence_similarity_model import *  # noqa F403
 from .sequence_classification_model import *  # noqa F403
 from .text_generation_model import *  # noqa F403
+from .token_classification_model import *  # noqa F403
diff --git a/modelscope/models/nlp/token_classification_model.py b/modelscope/models/nlp/token_classification_model.py
new file mode 100644
index 00000000..43d4aafb
--- /dev/null
+++ b/modelscope/models/nlp/token_classification_model.py
@@ -0,0 +1,57 @@
+import os
+from typing import Any, Dict, Union
+
+import numpy as np
+import torch
+from sofa import SbertConfig, SbertForTokenClassification
+
+from modelscope.utils.constant import Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+
+__all__ = ['StructBertForTokenClassification']
+
+
+@MODELS.register_module(
+    Tasks.word_segmentation,
+    module_name=r'structbert-chinese-word-segmentation')
+class StructBertForTokenClassification(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the word segmentation model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+            model_cls (Optional[Any], optional): model loader passed via kwargs;
+                if None, the default loader is used to load the model weights.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+        self.model = SbertForTokenClassification.from_pretrained(
+            self.model_dir)
+        self.config = SbertConfig.from_pretrained(self.model_dir)
+
+    def forward(self, input: Dict[str,
+                                  Any]) -> Dict[str, Union[str, np.ndarray]]:
+        """Return the prediction results of the model.
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, Union[str, np.ndarray]]: results
+                Example:
+                    {
+                        'predictions': array([1, 4]),  # label id of each token
+                        'logits': array([[-0.53860897, 1.5029076]], dtype=float32),  # raw scores
+                        'text': '今天',
+                    }
+        """
+        input_ids = torch.tensor(input['input_ids']).unsqueeze(0)
+        output = self.model(input_ids)
+        logits = output.logits
+        pred = torch.argmax(logits[0], dim=-1)
+        pred = pred.numpy()
+
+        rst = {'predictions': pred, 'logits': logits, 'text': input['text']}
+        return rst
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 83d1641e..c24a7c3e 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -13,6 +13,9 @@ PIPELINES = Registry('pipelines')
 
 DEFAULT_MODEL_FOR_PIPELINE = {
     # TaskName: (pipeline_module_name, model_repo)
+    Tasks.word_segmentation:
+    ('structbert-chinese-word-segmentation',
+     'damo/nlp_structbert_word-segmentation_chinese-base'),
     Tasks.sentence_similarity:
     ('sbert-base-chinese-sentence-similarity',
      'damo/nlp_structbert_sentence-similarity_chinese-base'),
diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index 1f15a7b8..f1dad0d6 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -1,3 +1,4 @@
 from .sentence_similarity_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
+from .word_segmentation_pipeline import *  # noqa F403
diff --git a/modelscope/pipelines/nlp/word_segmentation_pipeline.py b/modelscope/pipelines/nlp/word_segmentation_pipeline.py
new file mode 100644
index 00000000..49aa112a
--- /dev/null
+++ b/modelscope/pipelines/nlp/word_segmentation_pipeline.py
@@ -0,0 +1,71 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+
+from modelscope.models import Model
+from modelscope.models.nlp import StructBertForTokenClassification
+from modelscope.preprocessors import TokenClassifcationPreprocessor
+from modelscope.utils.constant import Tasks
+from ..base import Pipeline, Tensor
+from ..builder import PIPELINES
+
+__all__ = ['WordSegmentationPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.word_segmentation,
+    module_name=r'structbert-chinese-word-segmentation')
+class WordSegmentationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[StructBertForTokenClassification, str],
+                 preprocessor: Optional[TokenClassifcationPreprocessor] = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an NLP word segmentation pipeline.
+
+        Args:
+            model (StructBertForTokenClassification or str): a model instance or a model id
+            preprocessor (TokenClassifcationPreprocessor): a preprocessor instance
+        """
+        model = model if isinstance(
+            model,
+            StructBertForTokenClassification) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = TokenClassifcationPreprocessor(model.model_dir)
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.tokenizer = preprocessor.tokenizer
+        self.config = model.config
+        self.id2label = self.config.id2label
+
+    def postprocess(self,
+                    inputs: Dict[str, Any]) -> Dict[str, str]:
+        """Process the prediction results into space-separated words.
+
+        Args:
+            inputs (Dict[str, Any]): the model prediction results
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+
+        pred_list = inputs['predictions']
+        labels = []
+        for pre in pred_list:
+            labels.append(self.id2label[pre])
+        labels = labels[1:-1]  # drop the labels of [CLS] and [SEP]
+        chunks = []
+        chunk = ''
+        assert len(inputs['text']) == len(labels)
+        for token, label in zip(inputs['text'], labels):
+            # B/I labels continue the current word; any other label
+            # (e.g. E or S) ends it
+            chunk += token
+            if label[0] not in ('B', 'I'):
+                chunks.append(chunk)
+                chunk = ''
+        if chunk:
+            chunks.append(chunk)
+        seg_result = ' '.join(chunks)
+        rst = {
+            'output': seg_result,
+        }
+        return rst
diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py
index 1389abd3..c88e358c 100644
--- a/modelscope/pipelines/outputs.py
+++ b/modelscope/pipelines/outputs.py
@@ -69,6 +69,19 @@ TASK_OUTPUTS = {
     # }
     Tasks.text_generation: ['text'],
 
+    # word segmentation result for single sample
+    # {
+    #     "output": "今天 天气 不错 , 适合 出去 游玩"
+    # }
+    Tasks.word_segmentation: ['output'],
+
+    # sentence similarity result for single sample
+    # {
+    #     "labels": "1",
+    #     "scores": 0.9
+    # }
+    Tasks.sentence_similarity: ['scores', 'labels'],
+
     # ============ audio tasks ===================
 
     # ============ multi-modal tasks ===================
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 6773eadf..6a4a25fc 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -12,7 +12,7 @@ from .builder import PREPROCESSORS
 
 __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
-    'TextGenerationPreprocessor'
+    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor'
 ]
 
 
@@ -171,3 +171,51 @@ class TextGenerationPreprocessor(Preprocessor):
         rst['token_type_ids'].append(feature['token_type_ids'])
 
         return {k: torch.tensor(v) for k, v in rst.items()}
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=r'bert-token-classification')
+class TokenClassifcationPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data via the vocab.txt in the `model_dir` path.
+
+        Args:
+            model_dir (str): model path
+        """
+
+        super().__init__(*args, **kwargs)
+
+        from sofa import SbertTokenizer
+        self.model_dir: str = model_dir
+        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    '今天天气不错，适合出去游玩'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        # preprocess the data for the model input: tokenize char by char
+        # so that the predicted labels align with the input characters
+        text = data.replace(' ', '').strip()
+        tokens = []
+        for token in text:
+            token = self.tokenizer.tokenize(token)
+            tokens.extend(token)
+        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+        input_ids = self.tokenizer.build_inputs_with_special_tokens(input_ids)
+        attention_mask = [1] * len(input_ids)
+        token_type_ids = [0] * len(input_ids)
+        return {
+            'text': text,
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids
+        }
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 6ce835c5..61049734 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -30,6 +30,7 @@ class Tasks(object):
     image_matting = 'image-matting'
 
     # nlp tasks
+    word_segmentation = 'word-segmentation'
     sentiment_analysis = 'sentiment-analysis'
     sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
diff --git a/tests/pipelines/test_word_segmentation.py b/tests/pipelines/test_word_segmentation.py
new file mode 100644
index 00000000..4ec2bf29
--- /dev/null
+++ b/tests/pipelines/test_word_segmentation.py
@@ -0,0 +1,62 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import shutil
+import unittest
+
+from maas_hub.snapshot_download import snapshot_download
+
+from modelscope.models import Model
+from modelscope.models.nlp import StructBertForTokenClassification
+from modelscope.pipelines import WordSegmentationPipeline, pipeline
+from modelscope.preprocessors import TokenClassifcationPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.hub import get_model_cache_dir
+from modelscope.utils.test_utils import test_level
+
+
+class WordSegmentationTest(unittest.TestCase):
+    model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
+    sentence = '今天天气不错，适合出去游玩'
+
+    def setUp(self) -> None:
+        # switch to False if downloading every time is not desired
+        purge_cache = True
+        if purge_cache:
+            shutil.rmtree(
+                get_model_cache_dir(self.model_id), ignore_errors=True)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = TokenClassifcationPreprocessor(cache_path)
+        model = StructBertForTokenClassification(
+            cache_path, tokenizer=tokenizer)
+        pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.word_segmentation, model=model, preprocessor=tokenizer)
+        print(f'sentence: {self.sentence}\n'
+              f'pipeline1: {pipeline1(input=self.sentence)}')
+        print()
+        print(f'pipeline2: {pipeline2(input=self.sentence)}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = TokenClassifcationPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.word_segmentation, model=model, preprocessor=tokenizer)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.word_segmentation, model=self.model_id)
+        print(pipeline_ins(input=self.sentence))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.word_segmentation)
+        print(pipeline_ins(input=self.sentence))
+
+
+if __name__ == '__main__':
+    unittest.main()
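
A minimal usage sketch of the pipeline this patch adds, equivalent to
test_run_with_default_model above. The printed result follows the
Tasks.word_segmentation entry documented in modelscope/pipelines/outputs.py
and is illustrative, not a recorded run:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Resolves to 'damo/nlp_structbert_word-segmentation_chinese-base'
    # via DEFAULT_MODEL_FOR_PIPELINE in modelscope/pipelines/builder.py.
    word_segmentation = pipeline(task=Tasks.word_segmentation)

    # Returns a dict with a single 'output' key, e.g.
    # {'output': '今天 天气 不错 , 适合 出去 游玩'}
    print(word_segmentation(input='今天天气不错，适合出去游玩'))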