# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Modified by Huawei Technologies Co., Ltd, May, 2020, with following changes:
# - Remove some unused classes and functions
# - Modify load_vocab, convert_to_unicode, printable_text function
# - Modify BasicTokenizer class
# - Add WhiteSpaceTokenizer class
###############################################################################
# ============================================================================
"""Tokenization utilities."""

import sys
import collections
import unicodedata


def convert_to_printable(text):
    """
    Converts `text` to a printable coding format.
    """
    if sys.version_info[0] == 3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
    if sys.version_info[0] == 2:
        if isinstance(text, str):
            return text
        if isinstance(text, unicode):
            return text.encode("utf-8")
        raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
    raise ValueError("Only supported when running on Python2 or Python3.")


def convert_to_unicode(text):
    """
    Converts `text` to Unicode format.
    """
    if sys.version_info[0] == 3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Only support type `str` or `bytes`, while text type is `%s`" % (type(text)))
    if sys.version_info[0] == 2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        if isinstance(text, unicode):
            return text
        raise ValueError("Only support type `str` or `unicode`, while text type is `%s`" % (type(text)))
    raise ValueError("Only supported when running on Python2 or Python3.")


def load_vocab_file(vocab_file):
    """
    Loads a vocabulary file and turns it into a {token: id} dictionary.
    """
    vocab_dict = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r") as vocab:
        while True:
            token = convert_to_unicode(vocab.readline())
            if not token:
                break
            token = token.strip()
            vocab_dict[token] = index
            index += 1
    return vocab_dict


def convert_by_vocab_dict(vocab_dict, items):
    """
    Converts a sequence of [tokens|ids] according to the vocab dict.
    """
    output = []
    for item in items:
        if item in vocab_dict:
            output.append(vocab_dict[item])
        else:
            output.append(vocab_dict["<unk>"])
    return output
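

# Vocabulary format assumed by the functions above: load_vocab_file() reads a
# plain text file with one token per line, and a token's id is its zero-based
# line index. Because convert_by_vocab_dict() falls back to vocab_dict["<unk>"]
# for out-of-vocabulary items, the file is expected to contain an "<unk>"
# entry. For example, a hypothetical three-line file
#     <unk>
#     hello
#     world
# maps "hello" -> 1 and any unknown token -> 0.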


class WhiteSpaceTokenizer():
    """
    Whitespace tokenizer.
    """
    def __init__(self, vocab_file):
        self.vocab_dict = load_vocab_file(vocab_file)
        self.inv_vocab_dict = {index: token for token, index in self.vocab_dict.items()}

    def _is_whitespace_char(self, char):
        """
        Checks if it is a whitespace character (regard "\t", "\n", "\r" as whitespace here).
        """
        if char in (" ", "\t", "\n", "\r"):
            return True
        uni = unicodedata.category(char)
        if uni == "Zs":
            return True
        return False

    def _is_control_char(self, char):
        """
        Checks if it is a control character ("\t", "\n", "\r" count as whitespace, not control).
        """
        if char in ("\t", "\n", "\r"):
            return False
        uni = unicodedata.category(char)
        if uni in ("Cc", "Cf"):
            return True
        return False

    def _clean_text(self, text):
        """
        Removes invalid characters and cleans up whitespace.
        """
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or self._is_control_char(char):
                continue
            if self._is_whitespace_char(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _whitespace_tokenize(self, text):
        """
        Cleans whitespace and splits text into tokens.
        """
        text = text.strip()
        if not text:
            tokens = []
        else:
            tokens = text.split()
        return tokens

    def tokenize(self, text):
        """
        Tokenizes text.
        """
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        tokens = self._whitespace_tokenize(text)
        return tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab_dict(self.vocab_dict, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab_dict(self.inv_vocab_dict, ids)
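

# A minimal usage sketch: "vocab.txt" is a hypothetical path to a vocabulary
# file with one token per line (including an "<unk>" entry).
if __name__ == "__main__":
    tokenizer = WhiteSpaceTokenizer("vocab.txt")
    example_tokens = tokenizer.tokenize("Hello world !")
    example_ids = tokenizer.convert_tokens_to_ids(example_tokens)
    print(example_tokens)
    print(example_ids)
    print(tokenizer.convert_ids_to_tokens(example_ids))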