@@ -1,9 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-
+import os
 import os.path as osp
 import re
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
+import json
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -13,8 +14,7 @@ from modelscope.metainfo import Models, Preprocessors
 from modelscope.outputs import OutputKeys
 from modelscope.preprocessors.base import Preprocessor
 from modelscope.preprocessors.builder import PREPROCESSORS
-from modelscope.utils.config import (Config, ConfigFields,
-                                     use_task_specific_params)
+from modelscope.utils.config import Config, ConfigFields
 from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile
 from modelscope.utils.hub import get_model_type, parse_label_mapping
 from modelscope.utils.logger import get_logger
@@ -83,6 +83,15 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
         self._mode = mode
         self.label = kwargs.pop('label', OutputKeys.LABEL)
+        self.use_fast = kwargs.pop('use_fast', None)
+        if self.use_fast is None and os.path.isfile(
+                os.path.join(model_dir, 'tokenizer_config.json')):
+            with open(os.path.join(model_dir, 'tokenizer_config.json'),
+                      'r') as f:
+                json_config = json.load(f)
+                self.use_fast = json_config.get('use_fast')
+        self.use_fast = False if self.use_fast is None else self.use_fast
+
         self.label2id = None
         if 'label2id' in kwargs:
             self.label2id = kwargs.pop('label2id')
@@ -118,32 +127,23 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
         if model_type in (Models.structbert, Models.gpt3, Models.palm,
                           Models.plug):
             from modelscope.models.nlp.structbert import SbertTokenizer, SbertTokenizerFast
-            return SbertTokenizer.from_pretrained(
-                model_dir
-            ) if self._mode == ModeKeys.INFERENCE else SbertTokenizerFast.from_pretrained(
-                model_dir)
+            tokenizer = SbertTokenizerFast if self.use_fast else SbertTokenizer
+            return tokenizer.from_pretrained(model_dir)
         elif model_type == Models.veco:
             from modelscope.models.nlp.veco import VecoTokenizer, VecoTokenizerFast
-            return VecoTokenizer.from_pretrained(
-                model_dir
-            ) if self._mode == ModeKeys.INFERENCE else VecoTokenizerFast.from_pretrained(
-                model_dir)
+            tokenizer = VecoTokenizerFast if self.use_fast else VecoTokenizer
+            return tokenizer.from_pretrained(model_dir)
         elif model_type == Models.deberta_v2:
             from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer, DebertaV2TokenizerFast
-            return DebertaV2Tokenizer.from_pretrained(
-                model_dir
-            ) if self._mode == ModeKeys.INFERENCE else DebertaV2TokenizerFast.from_pretrained(
-                model_dir)
+            tokenizer = DebertaV2TokenizerFast if self.use_fast else DebertaV2Tokenizer
+            return tokenizer.from_pretrained(model_dir)
         elif not self.is_transformer_based_model:
             from transformers import BertTokenizer, BertTokenizerFast
-            return BertTokenizer.from_pretrained(
-                model_dir
-            ) if self._mode == ModeKeys.INFERENCE else BertTokenizerFast.from_pretrained(
-                model_dir)
+            tokenizer = BertTokenizerFast if self.use_fast else BertTokenizer
+            return tokenizer.from_pretrained(model_dir)
         else:
             return AutoTokenizer.from_pretrained(
-                model_dir,
-                use_fast=False if self._mode == ModeKeys.INFERENCE else True)
+                model_dir, use_fast=self.use_fast)
 
     def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]:
         """process the raw input data
@@ -593,9 +593,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         else:
             self.is_split_into_words = self.tokenizer.init_kwargs.get(
                 'is_split_into_words', False)
-        if 'label2id' in kwargs:
-            kwargs.pop('label2id')
         self.tokenize_kwargs = kwargs
 
     @type_assert(object, str)
     def __call__(self, data: str) -> Dict[str, Any]:
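
For reference, the resolution order this change introduces: an explicit use_fast kwarg wins, then the use_fast field of the model directory's tokenizer_config.json, then a default of False. Previously the fast/slow choice was tied to whether self._mode == ModeKeys.INFERENCE. Below is a minimal standalone sketch of the new resolution logic; the helper name resolve_use_fast is illustrative only, not part of the modelscope API.

import json
import os
from typing import Optional


def resolve_use_fast(model_dir: str, use_fast: Optional[bool] = None) -> bool:
    """Sketch of the use_fast resolution added to NLPTokenizerPreprocessorBase."""
    config_path = os.path.join(model_dir, 'tokenizer_config.json')
    # Only consult tokenizer_config.json when the caller did not decide.
    if use_fast is None and os.path.isfile(config_path):
        with open(config_path, 'r') as f:
            use_fast = json.load(f).get('use_fast')
    # Default to the slow tokenizer, mirroring the diff's
    # `False if self.use_fast is None else self.use_fast` fallback.
    return False if use_fast is None else use_fast

Note that a tokenizer_config.json without a use_fast key leaves the value at None, so the fallback still yields False; the tokenizer-building code then uses the resolved flag to pick the fast or slow class (e.g. SbertTokenizerFast vs SbertTokenizer) or passes it straight to AutoTokenizer.from_pretrained.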