Chinese word segmentation
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9051491
* add word segmentation
* Merge branch 'master' of http://gitlab.alibaba-inc.com/Ali-MaaS/MaaS-lib
* test with model hub
* merge with master
* update some description and test levels
* adding purge logic in test
* merge with master
* update variables definition
* generic word segmentation model as token classification model
* add output check
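For context, a minimal usage sketch of the pipeline added in this change, based on the registrations and tests in the diff below (the printed output is illustrative):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# build a word segmentation pipeline from the default model repo registered below
word_segmentation = pipeline(task=Tasks.word_segmentation)
result = word_segmentation(input='今天天气不错，适合出去游玩')
print(result['output'])  # e.g. '今天 天气 不错 ， 适合 出去 游玩'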
@@ -1,3 +1,4 @@
from .sentence_similarity_model import *  # noqa F403
from .sequence_classification_model import *  # noqa F403
from .text_generation_model import *  # noqa F403
from .token_classification_model import *  # noqa F403
@@ -0,0 +1,57 @@
import os
from typing import Any, Dict, Union

import numpy as np
import torch
from sofa import SbertConfig, SbertForTokenClassification

from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['StructBertForTokenClassification']


@MODELS.register_module(
    Tasks.word_segmentation,
    module_name=r'structbert-chinese-word-segmentation')
class StructBertForTokenClassification(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the word segmentation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
            model_cls (Optional[Any], optional): model loader; if None, the
                default loader is used to load the model weights. Defaults to None.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir
        self.model = SbertForTokenClassification.from_pretrained(
            self.model_dir)
        self.config = SbertConfig.from_pretrained(self.model_dir)

    def forward(self, input: Dict[str,
                                  Any]) -> Dict[str, Union[str, np.ndarray]]:
        """Return the prediction result of the model.

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, Union[str, np.ndarray]]: results
                Example:
                    {
                        'predictions': array([1, 4]),  # label ids for each token
                        'logits': array([[-0.53860897, 1.5029076]], dtype=float32),  # raw logits
                        'text': '今天',
                    }
        """
        input_ids = torch.tensor(input['input_ids']).unsqueeze(0)
        output = self.model(input_ids)
        logits = output.logits
        pred = torch.argmax(logits[0], dim=-1)
        pred = pred.numpy()
        rst = {'predictions': pred, 'logits': logits, 'text': input['text']}
        return rst
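A rough sketch of how this model is driven directly, mirroring test_run_by_direct_model_download in the tests below; treat it as illustrative rather than the canonical API:

from maas_hub.snapshot_download import snapshot_download

from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.preprocessors import TokenClassifcationPreprocessor

cache_path = snapshot_download(
    'damo/nlp_structbert_word-segmentation_chinese-base')
preprocessor = TokenClassifcationPreprocessor(cache_path)
model = StructBertForTokenClassification(cache_path)
inputs = preprocessor('今天天气不错，适合出去游玩')
outputs = model.forward(inputs)
# outputs holds 'predictions' (per-token label ids), 'logits' and 'text'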
@@ -13,6 +13,9 @@ PIPELINES = Registry('pipelines')

DEFAULT_MODEL_FOR_PIPELINE = {
    # TaskName: (pipeline_module_name, model_repo)
    Tasks.word_segmentation:
    ('structbert-chinese-word-segmentation',
     'damo/nlp_structbert_word-segmentation_chinese-base'),
    Tasks.sentence_similarity:
    ('sbert-base-chinese-sentence-similarity',
     'damo/nlp_structbert_sentence-similarity_chinese-base'),
@@ -1,3 +1,4 @@
from .sentence_similarity_pipeline import *  # noqa F403
from .sequence_classification_pipeline import *  # noqa F403
from .text_generation_pipeline import *  # noqa F403
from .word_segmentation_pipeline import *  # noqa F403
@@ -0,0 +1,71 @@
from typing import Any, Dict, Optional, Union

import numpy as np

from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES

__all__ = ['WordSegmentationPipeline']


@PIPELINES.register_module(
    Tasks.word_segmentation,
    module_name=r'structbert-chinese-word-segmentation')
class WordSegmentationPipeline(Pipeline):

    def __init__(self,
                 model: Union[StructBertForTokenClassification, str],
                 preprocessor: Optional[TokenClassifcationPreprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create an NLP word segmentation pipeline for prediction.

        Args:
            model (StructBertForTokenClassification): a model instance
            preprocessor (TokenClassifcationPreprocessor): a preprocessor instance
        """
        model = model if isinstance(
            model,
            StructBertForTokenClassification) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = TokenClassifcationPreprocessor(model.model_dir)
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.tokenizer = preprocessor.tokenizer
        self.config = model.config
        self.id2label = self.config.id2label

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Process the prediction results.

        Args:
            inputs (Dict[str, Any]): the model output, containing
                'predictions', 'logits' and 'text'

        Returns:
            Dict[str, str]: the prediction results
        """
        pred_list = inputs['predictions']
        labels = []
        for pre in pred_list:
            labels.append(self.id2label[pre])
        # drop the labels of the special tokens added at both ends
        labels = labels[1:-1]
        chunks = []
        chunk = ''
        assert len(inputs['text']) == len(labels)
        for token, label in zip(inputs['text'], labels):
            if label[0] == 'B' or label[0] == 'I':
                chunk += token
            else:
                chunk += token
                chunks.append(chunk)
                chunk = ''
        if chunk:
            chunks.append(chunk)
        seg_result = ' '.join(chunks)
        rst = {
            'output': seg_result,
        }
        return rst
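The label-to-chunk merging in postprocess can be shown in isolation; a small standalone sketch with hypothetical BIES-style labels (real label names come from the model config's id2label map):

# hypothetical character-level labels for '今天天气不错';
# B/I continue the current word, any other prefix (e.g. E or S) closes it
text = '今天天气不错'
labels = ['B-CWS', 'E-CWS', 'B-CWS', 'E-CWS', 'B-CWS', 'E-CWS']

chunks, chunk = [], ''
for token, label in zip(text, labels):
    chunk += token
    if label[0] not in ('B', 'I'):  # a non-B/I label ends the current word
        chunks.append(chunk)
        chunk = ''
if chunk:
    chunks.append(chunk)
print(' '.join(chunks))  # -> 今天 天气 不错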
@@ -69,6 +69,19 @@ TASK_OUTPUTS = {
    # }
    Tasks.text_generation: ['text'],

    # word segmentation result for single sample
    # {
    #     "output": "今天 天气 不错 , 适合 出去 游玩"
    # }
    Tasks.word_segmentation: ['output'],

    # sentence similarity result for single sample
    # {
    #     "labels": "1",
    #     "scores": 0.9
    # }
    Tasks.sentence_similarity: ['scores', 'labels'],

    # ============ audio tasks ===================

    # ============ multi-modal tasks ===================
@@ -12,7 +12,7 @@ from .builder import PREPROCESSORS

__all__ = [
    'Tokenize', 'SequenceClassificationPreprocessor',
    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor'
]
@@ -171,3 +171,51 @@ class TextGenerationPreprocessor(Preprocessor):
            rst['token_type_ids'].append(feature['token_type_ids'])

        return {k: torch.tensor(v) for k, v in rst.items()}
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-token-classification')
class TokenClassifcationPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess the data via the vocab.txt in the `model_dir` path.

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)

        from sofa import SbertTokenizer
        self.model_dir: str = model_dir
        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence
                Example:
                    'you are so handsome.'

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        # preprocess the data for the model input
        text = data.replace(' ', '').strip()
        tokens = []
        for token in text:
            token = self.tokenizer.tokenize(token)
            tokens.extend(token)
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_ids = self.tokenizer.build_inputs_with_special_tokens(input_ids)
        attention_mask = [1] * len(input_ids)
        token_type_ids = [0] * len(input_ids)

        return {
            'text': text,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids
        }
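Roughly, the preprocessor output for a short sentence looks like the sketch below; the token ids are placeholders rather than real vocab ids, and the special tokens added by build_inputs_with_special_tokens are typically [CLS]/[SEP]:

from modelscope.preprocessors import TokenClassifcationPreprocessor

# `model_dir` is a placeholder for a local model directory containing vocab.txt
preprocessor = TokenClassifcationPreprocessor(model_dir)
features = preprocessor('今天天气不错')
# features roughly looks like (ids below are placeholders only):
# {
#     'text': '今天天气不错',
#     'input_ids': [101, 791, 1921, 1921, 3698, 679, 7231, 102],  # [CLS] ... [SEP]
#     'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1],
#     'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0]
# }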
@@ -30,6 +30,7 @@ class Tasks(object):
    image_matting = 'image-matting'

    # nlp tasks
    word_segmentation = 'word-segmentation'
    sentiment_analysis = 'sentiment-analysis'
    sentence_similarity = 'sentence-similarity'
    text_classification = 'text-classification'
@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.pipelines import WordSegmentationPipeline, pipeline
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


class WordSegmentationTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
    sentence = '今天天气不错，适合出去游玩'

    def setUp(self) -> None:
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        tokenizer = TokenClassifcationPreprocessor(cache_path)
        model = StructBertForTokenClassification(
            cache_path, tokenizer=tokenizer)
        pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.word_segmentation, model=model, preprocessor=tokenizer)
        print(f'sentence: {self.sentence}\n'
              f'pipeline1: {pipeline1(input=self.sentence)}')
        print()
        print(f'pipeline2: {pipeline2(input=self.sentence)}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        tokenizer = TokenClassifcationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.word_segmentation, model=model, preprocessor=tokenizer)
        print(pipeline_ins(input=self.sentence))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.word_segmentation, model=self.model_id)
        print(pipeline_ins(input=self.sentence))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.word_segmentation)
        print(pipeline_ins(input=self.sentence))


if __name__ == '__main__':
    unittest.main()