[to #42322933] Fix the logic of the fast tokenizer

1. Determine whether to use the fast tokenizer from the user argument and tokenizer_config.json instead of the run mode.
   This fixes the problem that RANER must use the fast tokenizer with some special models.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10488982
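
For clarity, the new resolution order is: an explicit `use_fast` argument from the caller wins; otherwise a `use_fast` key in the model's tokenizer_config.json is consulted; otherwise the slow tokenizer is used. A minimal standalone sketch of that order (the helper name and signature are illustrative, not part of this commit):

```python
import json
import os


def resolve_use_fast(model_dir, use_fast=None):
    """Mirror of the new selection logic; name and signature are illustrative."""
    # 1. An explicit user argument always wins.
    if use_fast is None:
        config_path = os.path.join(model_dir, 'tokenizer_config.json')
        # 2. Otherwise consult tokenizer_config.json, if the model ships one.
        if os.path.isfile(config_path):
            with open(config_path, 'r') as f:
                use_fast = json.load(f).get('use_fast')
    # 3. Default to the slow tokenizer.
    return False if use_fast is None else use_fast
```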
master
yuze.zyz · 3 years ago · commit f7f7eb21dc

1 changed file with 22 additions and 25 deletions:
  +22 -25  modelscope/preprocessors/nlp/nlp_base.py

@@ -1,9 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import os.path as osp
 import re
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import json
 import numpy as np
 import sentencepiece as spm
 import torch
@@ -13,8 +14,7 @@ from modelscope.metainfo import Models, Preprocessors
 from modelscope.outputs import OutputKeys
 from modelscope.preprocessors.base import Preprocessor
 from modelscope.preprocessors.builder import PREPROCESSORS
-from modelscope.utils.config import (Config, ConfigFields,
-                                     use_task_specific_params)
+from modelscope.utils.config import Config, ConfigFields
 from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile
 from modelscope.utils.hub import get_model_type, parse_label_mapping
 from modelscope.utils.logger import get_logger
@@ -83,6 +83,15 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
 
         self._mode = mode
         self.label = kwargs.pop('label', OutputKeys.LABEL)
+        self.use_fast = kwargs.pop('use_fast', None)
+        if self.use_fast is None and os.path.isfile(
+                os.path.join(model_dir, 'tokenizer_config.json')):
+            with open(os.path.join(model_dir, 'tokenizer_config.json'),
+                      'r') as f:
+                json_config = json.load(f)
+            self.use_fast = json_config.get('use_fast')
+        self.use_fast = False if self.use_fast is None else self.use_fast
+
         self.label2id = None
         if 'label2id' in kwargs:
             self.label2id = kwargs.pop('label2id')
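
With this resolution in place, a model that must use the fast tokenizer (the RANER case from the commit message) can declare it in its own files rather than relying on the run mode. A sketch of producing such a config entry, assuming a hypothetical local model directory:

```python
import json
import os

model_dir = '/path/to/model'  # hypothetical local model directory

# Shipping {"use_fast": true} alongside the model makes the preprocessor
# pick the fast tokenizer even when the caller passes no argument.
with open(os.path.join(model_dir, 'tokenizer_config.json'), 'w') as f:
    json.dump({'use_fast': True}, f)
```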
@@ -118,32 +127,23 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
         if model_type in (Models.structbert, Models.gpt3, Models.palm,
                           Models.plug):
             from modelscope.models.nlp.structbert import SbertTokenizer, SbertTokenizerFast
-            return SbertTokenizer.from_pretrained(
-                model_dir
-            ) if self._mode == ModeKeys.INFERENCE else SbertTokenizerFast.from_pretrained(
-                model_dir)
+            tokenizer = SbertTokenizerFast if self.use_fast else SbertTokenizer
+            return tokenizer.from_pretrained(model_dir)
         elif model_type == Models.veco:
             from modelscope.models.nlp.veco import VecoTokenizer, VecoTokenizerFast
-            return VecoTokenizer.from_pretrained(
-                model_dir
-            ) if self._mode == ModeKeys.INFERENCE else VecoTokenizerFast.from_pretrained(
-                model_dir)
+            tokenizer = VecoTokenizerFast if self.use_fast else VecoTokenizer
+            return tokenizer.from_pretrained(model_dir)
         elif model_type == Models.deberta_v2:
             from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer, DebertaV2TokenizerFast
-            return DebertaV2Tokenizer.from_pretrained(
-                model_dir
-            ) if self._mode == ModeKeys.INFERENCE else DebertaV2TokenizerFast.from_pretrained(
-                model_dir)
+            tokenizer = DebertaV2TokenizerFast if self.use_fast else DebertaV2Tokenizer
+            return tokenizer.from_pretrained(model_dir)
         elif not self.is_transformer_based_model:
             from transformers import BertTokenizer, BertTokenizerFast
-            return BertTokenizer.from_pretrained(
-                model_dir
-            ) if self._mode == ModeKeys.INFERENCE else BertTokenizerFast.from_pretrained(
-                model_dir)
+            tokenizer = BertTokenizerFast if self.use_fast else BertTokenizer
+            return tokenizer.from_pretrained(model_dir)
         else:
-            return AutoTokenizer.from_pretrained(
-                model_dir,
-                use_fast=False if self._mode == ModeKeys.INFERENCE else True)
+            return AutoTokenizer.from_pretrained(
+                model_dir, use_fast=self.use_fast)
 
     def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]:
         """process the raw input data
@@ -593,9 +593,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
         else:
             self.is_split_into_words = self.tokenizer.init_kwargs.get(
                 'is_split_into_words', False)
-        if 'label2id' in kwargs:
-            kwargs.pop('label2id')
         self.tokenize_kwargs = kwargs
 
     @type_assert(object, str)
     def __call__(self, data: str) -> Dict[str, Any]:

