
tokenization.py 6.3 kB

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Modified by Huawei Technologies Co., Ltd, May, 2020, with following changes:
#     - Remove some unused classes and functions
#     - Modify load_vocab, convert_to_unicode, printable_text function
#     - Modify BasicTokenizer class
#     - Add WhiteSpaceTokenizer class
###############################################################################
"""Tokenization utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import unicodedata

import six
def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Unsupported string type: %s" % (type(text)))
    if six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        if isinstance(text, unicode):
            return text
        raise ValueError("Unsupported string type: %s" % (type(text)))
    raise ValueError("Not running on Python2 or Python 3?")


def printable_text(text):
    """Returns text encoded in a way suitable for print or `logging`."""
    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        raise ValueError("Unsupported string type: %s" % (type(text)))
    if six.PY2:
        if isinstance(text, str):
            return text
        if isinstance(text, unicode):
            return text.encode("utf-8")
        raise ValueError("Unsupported string type: %s" % (type(text)))
    raise ValueError("Not running on Python2 or Python 3?")


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r") as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def convert_by_vocab(vocab, items):
    """Converts a sequence of [tokens|ids] using the vocab."""
    output = []
    for item in items:
        if item in vocab:
            output.append(vocab[item])
        else:
            output.append(vocab["<unk>"])
    return output


def convert_tokens_to_ids(vocab, tokens):
    return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
    return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens
class WhiteSpaceTokenizer():
    """Runs end-to-end tokenization."""

    def __init__(self, vocab_file):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer()

    def tokenize(self, text):
        return self.basic_tokenizer.tokenize(text)

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer():
    """Runs basic tokenization (invalid-character removal and whitespace splitting)."""

    def __init__(self):
        """Constructs a BasicTokenizer."""

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        return whitespace_tokenize(text)

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char in (" ", "\t", "\n", "\r"):
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char in ("\t", "\n", "\r"):
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False
def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
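
A minimal usage sketch of the module above, not part of the original file. The vocabulary path and contents are hypothetical: `load_vocab` expects one token per line, and the file should contain an "<unk>" entry because `convert_by_vocab` falls back to it for out-of-vocabulary tokens.

    from tokenization import WhiteSpaceTokenizer

    # "vocab.txt" is a placeholder path; real file lists one token per line, including "<unk>".
    tokenizer = WhiteSpaceTokenizer("vocab.txt")

    # BasicTokenizer here only strips invalid characters and splits on whitespace,
    # so punctuation stays attached to words and case is preserved.
    tokens = tokenizer.tokenize("Hello,   world\tfoo")   # -> ["Hello,", "world", "foo"]
    ids = tokenizer.convert_tokens_to_ids(tokens)        # unknown tokens map to vocab["<unk>"]
    print(tokens, ids, tokenizer.convert_ids_to_tokens(ids))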