@@ -238,7 +238,16 @@ class TokenClassificationTransformersPreprocessor(
         is_split_into_words = self.nlp_tokenizer.get_tokenizer_kwarg(
             'is_split_into_words', False)
         if is_split_into_words:
-            tokens = list(tokens)
+            # To support a prompt separator, split around it before
+            # character-splitting; '[SEP]' is the default separator.
+            sep_idx = tokens.find('[SEP]')
+            if sep_idx == -1 or self.is_lstm_model:
+                tokens = list(tokens)
+            else:
+                tmp_tokens = []
+                tmp_tokens.extend(list(tokens[:sep_idx]))
+                tmp_tokens.append('[SEP]')
+                tmp_tokens.extend(list(tokens[sep_idx + 5:]))
+                tokens = tmp_tokens
 
         if is_split_into_words and self.mode == ModeKeys.INFERENCE:
             encodings, word_ids = self._tokenize_text_by_words(
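The new branch above character-splits the input while keeping a literal '[SEP]' intact as a single token, so a trailing prompt survives tokenization (the plain `list(tokens)` path is kept for inputs without a separator and for LSTM models). A minimal standalone sketch of that behavior, using a hypothetical helper `split_with_sep` that is not part of the patch:

    def split_with_sep(text: str, sep: str = '[SEP]') -> list:
        # Character-split `text`, keeping a single `sep` marker intact.
        sep_idx = text.find(sep)
        if sep_idx == -1:
            return list(text)
        # Characters before the separator, the separator itself, then the prompt.
        return list(text[:sep_idx]) + [sep] + list(text[sep_idx + len(sep):])

    # split_with_sep('今天天气[SEP]找出地名')
    # -> ['今', '天', '天', '气', '[SEP]', '找', '出', '地', '名']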
@@ -250,6 +259,16 @@ class TokenClassificationTransformersPreprocessor(
             encodings, word_ids = self._tokenize_text_with_slow_tokenizer(
                 tokens, **kwargs)
 
+        # Modify the label mask: mask the first sep token and all prompt
+        # tokens after it.
+        sep_idx = -1
+        for idx, token_id in enumerate(encodings['input_ids']):
+            if token_id == self.nlp_tokenizer.tokenizer.sep_token_id:
+                sep_idx = idx
+                break
+        if sep_idx != -1:
+            for i in range(sep_idx, len(encodings['label_mask'])):
+                encodings['label_mask'][i] = False
+
         if self.mode == ModeKeys.INFERENCE:
             for key in encodings.keys():
                 encodings[key] = torch.tensor(encodings[key]).unsqueeze(0)
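The masking pass above finds the first sep token in `input_ids` and clears `label_mask` from that position onward, so the prompt tokens are excluded from labeling. A toy reproduction, assuming a sep token id of 102 (BERT's default; the patch itself reads it from `self.nlp_tokenizer.tokenizer.sep_token_id`):

    encodings = {
        'input_ids': [101, 791, 1921, 102, 2823, 1139, 102],
        'label_mask': [False, True, True, True, True, True, False],
    }
    sep_token_id = 102  # assumed; taken from the tokenizer in the patch

    sep_idx = -1
    for idx, token_id in enumerate(encodings['input_ids']):
        if token_id == sep_token_id:
            sep_idx = idx
            break
    if sep_idx != -1:
        # Everything from the first sep token onward belongs to the prompt.
        for i in range(sep_idx, len(encodings['label_mask'])):
            encodings['label_mask'][i] = False

    print(encodings['label_mask'])
    # [False, True, True, False, False, False, False]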