
bert_tokenizer.py 16 kB

Latest commit: Dev0.4.0 (#149), 6 years ago
  1. """
  2. bert_tokenizer.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.
  3. """
  4. import collections
  5. import os
  6. import unicodedata
  7. from io import open
  8. PRETRAINED_VOCAB_ARCHIVE_MAP = {
  9. 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
  10. 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
  11. 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
  12. 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
  13. 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
  14. 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
  15. 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
  16. }
  17. PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
  18. 'bert-base-uncased': 512,
  19. 'bert-large-uncased': 512,
  20. 'bert-base-cased': 512,
  21. 'bert-large-cased': 512,
  22. 'bert-base-multilingual-uncased': 512,
  23. 'bert-base-multilingual-cased': 512,
  24. 'bert-base-chinese': 512,
  25. }
  26. VOCAB_NAME = 'vocab.txt'
  27. def load_vocab(vocab_file):
  28. """Loads a vocabulary file into a dictionary."""
  29. vocab = collections.OrderedDict()
  30. index = 0
  31. with open(vocab_file, "r", encoding="utf-8") as reader:
  32. while True:
  33. token = reader.readline()
  34. if not token:
  35. break
  36. token = token.strip()
  37. vocab[token] = index
  38. index += 1
  39. return vocab
  40. def whitespace_tokenize(text):
  41. """Runs basic whitespace cleaning and splitting on a piece of text."""
  42. text = text.strip()
  43. if not text:
  44. return []
  45. tokens = text.split()
  46. return tokens


class BertTokenizer(object):
    """Runs end-to-end tokenization: punctuation splitting + wordpiece"""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BertTokenizer.

        Args:
            vocab_file: Path to a one-wordpiece-per-line vocabulary file
            do_lower_case: Whether to lower case the input.
                Only has an effect when do_basic_tokenize=True
            do_basic_tokenize: Whether to do basic tokenization before wordpiece.
            max_len: An artificial maximum length to truncate tokenized sequences to;
                Effective maximum length is always the minimum of this
                value (if specified) and the underlying BERT model's
                sequence length.
            never_split: List of tokens which will never be split during tokenization.
                Only has an effect when do_basic_tokenize=True
        """
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)

    def tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text):
                for sub_token in self.wordpiece_tokenizer.tokenize(token):
                    split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            ids.append(self.vocab[token])
        if len(ids) > self.max_len:
            print(
                "WARNING!\n"
                "Token indices sequence length is longer than the specified maximum "
                "sequence length for this BERT model ({} > {}). Running this"
                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
            )
        return ids

    def convert_ids_to_tokens(self, ids):
        """Converts a sequence of ids into wordpiece tokens using the vocab."""
        tokens = []
        for i in ids:
            tokens.append(self.ids_to_tokens[i])
        return tokens

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    print("Saving vocabulary to {}: vocabulary indices are not consecutive."
                          " Please check that the vocabulary is not corrupted!".format(vocab_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return vocab_file

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
                print("The pre-trained model you are loading is a cased model but you have not set "
                      "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
                      "you may want to check this behavior.")
                kwargs['do_lower_case'] = False
            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
                print("The pre-trained model you are loading is an uncased model but you have set "
                      "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
                      "but you may want to check this behavior.")
                kwargs['do_lower_case'] = True
        else:
            vocab_file = pretrained_model_name_or_path
        if os.path.isdir(vocab_file):
            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
        # redirect to the cache, if necessary
        resolved_vocab_file = vocab_file
        print("loading vocabulary file {}".format(vocab_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
        return tokenizer
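

# Illustrative end-to-end sketch (not called anywhere in this module): writes a toy
# vocabulary to a temporary file and runs the full basic + wordpiece pipeline. The
# vocabulary below is hypothetical and far smaller than a real BERT vocab file.
def _example_bert_tokenizer_usage():
    import tempfile
    toy_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]",
                 "the", "quick", "brown", "fox", "jump", "##ed"]
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, encoding="utf-8") as f:
        f.write("\n".join(toy_vocab))
        vocab_path = f.name
    tokenizer = BertTokenizer(vocab_path, do_lower_case=True)
    tokens = tokenizer.tokenize("The quick brown fox jumped")
    # -> ['the', 'quick', 'brown', 'fox', 'jump', '##ed']
    ids = tokenizer.convert_tokens_to_ids(tokens)
    assert tokenizer.convert_ids_to_tokens(ids) == tokens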


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self,
                 do_lower_case=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case
        self.never_split = never_split

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        if text in self.never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1
        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True
        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
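

# Illustrative sketch (not called anywhere in this module) of BasicTokenizer behaviour:
# lower-casing, accent stripping, punctuation splitting, and whitespace padding around
# CJK characters. The sample strings are hypothetical.
def _example_basic_tokenizer_usage():
    tokenizer = BasicTokenizer(do_lower_case=True)
    assert tokenizer.tokenize(u"Héllo, world!") == ["hello", ",", "world", "!"]
    assert tokenizer.tokenize(u"BERT处理中文") == ["bert", u"处", u"理", u"中", u"文"]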


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """
        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue
            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end
            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
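

# Illustrative sketch (not called anywhere in this module) of the greedy
# longest-match-first algorithm above, using a hypothetical four-token vocabulary.
def _example_wordpiece_usage():
    toy_vocab = collections.OrderedDict(
        (tok, i) for i, tok in enumerate(["[UNK]", "un", "##aff", "##able"]))
    wp = WordpieceTokenizer(vocab=toy_vocab)
    assert wp.tokenize("unaffable") == ["un", "##aff", "##able"]
    # A word that cannot be fully covered by the vocabulary collapses to [UNK].
    assert wp.tokenize("impossible") == ["[UNK]"]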


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
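

# Illustrative sketch (not called anywhere in this module) of the character-class
# helpers above; the sample characters are hypothetical.
def _example_char_class_helpers():
    assert _is_whitespace(" ") and _is_whitespace("\t")
    assert not _is_whitespace("a")
    assert _is_control("\x00")    # NUL has Unicode category Cc
    assert not _is_control("\n")  # newline is treated as whitespace above
    assert _is_punctuation("$")   # non-letter/number ASCII counts as punctuation
    assert not _is_punctuation("3")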