|
|
|
@@ -18,11 +18,19 @@ class FaqQuestionAnsweringPreprocessor(NLPBasePreprocessor): |
|
|
|
def __init__(self, model_dir: str, *args, **kwargs):
    """Build the FAQ question-answering preprocessor.

    Reads the ``preprocessor`` section of the model's configuration file
    and instantiates the tokenizer it names (``XLMRoberta`` or the
    default ``BertTokenizer``) from *model_dir*.

    Args:
        model_dir: Path to the model directory containing both the
            tokenizer files and the configuration file.
        *args: Unused positional passthrough (kept for interface
            compatibility with other preprocessors).
        **kwargs: Forwarded to the base preprocessor.
    """
    super(FaqQuestionAnsweringPreprocessor, self).__init__(
        model_dir, mode=ModeKeys.INFERENCE, **kwargs)

    preprocessor_config = Config.from_file(
        os.path.join(model_dir, ModelFile.CONFIGURATION)).get(
            ConfigFields.preprocessor, {})
    # Select the tokenizer named in the config exactly once; the original
    # code unconditionally built a BertTokenizer first and then rebuilt
    # the tokenizer from the config, loading vocab files from disk twice.
    if preprocessor_config.get('tokenizer',
                               'BertTokenizer') == 'XLMRoberta':
        from transformers import XLMRobertaTokenizer
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_dir)
    else:
        from transformers import BertTokenizer
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)

    # Maximum sequence length for tokenization; config key
    # 'max_seq_length', defaulting to 50.
    self.MAX_LEN = preprocessor_config.get('max_seq_length', 50)
    # Populated later (e.g. when labels become known); None until then.
    self.label_dict = None
|
|
|
|
|
|
|
|