- # coding=utf-8
- # Copyright 2018 The Google AI Language Team Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- ###############################################################################
- # Modified by Huawei Technologies Co., Ltd, May, 2020, with following changes:
- # - Remove some unused classes and functions
- # - Modify load_vocab, convert_to_unicode, printable_text function
- # - Modify BasicTokenizer class
- # - Add WhiteSpaceTokenizer class
- ###############################################################################
-
- """Tokenization utilities."""
-
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import collections
- import unicodedata
- import six
-
- def convert_to_unicode(text):
-     """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
-     if six.PY3:
-         if isinstance(text, str):
-             return text
-         if isinstance(text, bytes):
-             return text.decode("utf-8", "ignore")
-         raise ValueError("Unsupported string type: %s" % (type(text)))
-     if six.PY2:
-         if isinstance(text, str):
-             return text.decode("utf-8", "ignore")
-         if isinstance(text, unicode):
-             return text
-         raise ValueError("Unsupported string type: %s" % (type(text)))
-     raise ValueError("Not running on Python 2 or Python 3?")
-
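- # For example (illustrative): in Python 3, bytes are decoded as UTF-8 and str
- # instances pass through unchanged.
- #   convert_to_unicode(b"caf\xc3\xa9")  ->  "café"
- #   convert_to_unicode("café")          ->  "café"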
-
- def printable_text(text):
-     """Returns text encoded in a way suitable for print or `logging`."""
- 
-     # These functions want `str` for both Python 2 and Python 3, but in one case
-     # it's a Unicode string and in the other it's a byte string.
-     if six.PY3:
-         if isinstance(text, str):
-             return text
-         if isinstance(text, bytes):
-             return text.decode("utf-8", "ignore")
-         raise ValueError("Unsupported string type: %s" % (type(text)))
-     if six.PY2:
-         if isinstance(text, str):
-             return text
-         if isinstance(text, unicode):
-             return text.encode("utf-8")
-         raise ValueError("Unsupported string type: %s" % (type(text)))
-     raise ValueError("Not running on Python 2 or Python 3?")
-
-
- def load_vocab(vocab_file):
-     """Loads a vocabulary file into a dictionary."""
-     vocab = collections.OrderedDict()
-     index = 0
-     with open(vocab_file, "r") as reader:
-         while True:
-             token = convert_to_unicode(reader.readline())
-             if not token:
-                 break
-             token = token.strip()
-             vocab[token] = index
-             index += 1
-     return vocab
-
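- # For example (illustrative): given a vocab file containing one token per line,
- #
- #   <unk>
- #   hello
- #   world
- #
- # load_vocab returns OrderedDict([("<unk>", 0), ("hello", 1), ("world", 2)]),
- # i.e. every token is mapped to its zero-based line index.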
-
- def convert_by_vocab(vocab, items):
-     """Converts a sequence of [tokens|ids] using the vocab."""
-     output = []
-     for item in items:
-         if item in vocab:
-             output.append(vocab[item])
-         else:
-             output.append(vocab["<unk>"])
-     return output
-
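- # For example (illustrative): items missing from the vocab fall back to the
- # "<unk>" entry, so the vocabulary is expected to contain an "<unk>" token.
- #   convert_by_vocab({"<unk>": 0, "hello": 1}, ["hello", "banana"])  ->  [1, 0]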
-
- def convert_tokens_to_ids(vocab, tokens):
-     """Converts a sequence of tokens to ids using the vocab."""
-     return convert_by_vocab(vocab, tokens)
- 
- 
- def convert_ids_to_tokens(inv_vocab, ids):
-     """Converts a sequence of ids to tokens using the inverse vocab."""
-     return convert_by_vocab(inv_vocab, ids)
-
-
- def whitespace_tokenize(text):
-     """Runs basic whitespace cleaning and splitting on a piece of text."""
-     text = text.strip()
-     if not text:
-         return []
-     tokens = text.split()
-     return tokens
-
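- # For example (illustrative): surrounding whitespace is stripped and the text is
- # split on any run of whitespace.
- #   whitespace_tokenize("  hello \t world\n")  ->  ["hello", "world"]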
-
- class WhiteSpaceTokenizer():
-     """Runs end-to-end whitespace tokenization."""
-     def __init__(self, vocab_file):
-         self.vocab = load_vocab(vocab_file)
-         self.inv_vocab = {v: k for k, v in self.vocab.items()}
-         self.basic_tokenizer = BasicTokenizer()
- 
-     def tokenize(self, text):
-         return self.basic_tokenizer.tokenize(text)
- 
-     def convert_tokens_to_ids(self, tokens):
-         return convert_by_vocab(self.vocab, tokens)
- 
-     def convert_ids_to_tokens(self, ids):
-         return convert_by_vocab(self.inv_vocab, ids)
-
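- # Illustrative usage (the vocab path below is a placeholder):
- #   tokenizer = WhiteSpaceTokenizer("vocab.txt")
- #   tokens = tokenizer.tokenize("Hello,  world!")     # -> ["Hello,", "world!"]
- #   ids = tokenizer.convert_tokens_to_ids(tokens)     # unknown tokens map to the "<unk>" id
- #   back = tokenizer.convert_ids_to_tokens(ids)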
-
- class BasicTokenizer():
-     """Runs basic tokenization (invalid character removal, whitespace cleanup and splitting)."""
- 
-     def __init__(self):
-         """Constructs a BasicTokenizer."""
- 
-     def tokenize(self, text):
-         """Tokenizes a piece of text."""
-         text = convert_to_unicode(text)
-         text = self._clean_text(text)
-         return whitespace_tokenize(text)
- 
-     def _clean_text(self, text):
-         """Performs invalid character removal and whitespace cleanup on text."""
-         output = []
-         for char in text:
-             cp = ord(char)
-             if cp == 0 or cp == 0xfffd or _is_control(char):
-                 continue
-             if _is_whitespace(char):
-                 output.append(" ")
-             else:
-                 output.append(char)
-         return "".join(output)
-
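- # For example (illustrative): _clean_text drops NUL and other control characters
- # and maps each whitespace character to a plain space.
- #   BasicTokenizer()._clean_text("foo\tbar\nbaz")  ->  "foo bar baz"
- #   BasicTokenizer()._clean_text("foo\x00bar")     ->  "foobar"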
-
- def _is_whitespace(char):
-     """Checks whether `char` is a whitespace character."""
-     # \t, \n, and \r are technically control characters but we treat them
-     # as whitespace since they are generally considered as such.
-     if char in (" ", "\t", "\n", "\r"):
-         return True
-     cat = unicodedata.category(char)
-     if cat == "Zs":
-         return True
-     return False
-
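- # For example (illustrative): the "Zs" category also covers characters such as
- # the no-break space.
- #   _is_whitespace(" ")        ->  True
- #   _is_whitespace(u"\u00a0")  ->  True
- #   _is_whitespace("a")        ->  False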
-
- def _is_control(char):
-     """Checks whether `char` is a control character."""
-     # These are technically control characters but we count them as whitespace
-     # characters.
-     if char in ("\t", "\n", "\r"):
-         return False
-     cat = unicodedata.category(char)
-     if cat in ("Cc", "Cf"):
-         return True
-     return False
-
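- # For example (illustrative): tab and newline are reported as non-control here,
- # while "Cc"/"Cf" characters such as BEL are treated as control characters.
- #   _is_control("\t")    ->  False
- #   _is_control("\x07")  ->  True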
-
- def _is_punctuation(char):
-     """Checks whether `char` is a punctuation character."""
-     cp = ord(char)
-     # We treat all non-letter/number ASCII as punctuation.
-     # Characters such as "^", "$", and "`" are not in the Unicode
-     # Punctuation class but we treat them as punctuation anyway, for
-     # consistency.
-     if ((33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126)):
-         return True
-     cat = unicodedata.category(char)
-     if cat.startswith("P"):
-         return True
-     return False
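- 
- # For example (illustrative): ASCII symbols such as "$" count as punctuation even
- # though their Unicode category is not "P*", while non-ASCII punctuation such as
- # the ideographic full stop is caught by the category check.
- #   _is_punctuation("$")        ->  True
- #   _is_punctuation(u"\u3002")  ->  True
- #   _is_punctuation("a")        ->  False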