You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

utils_fasttext.py 8.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
# coding: UTF-8
import os
import pickle as pkl
import time
from collections import Counter
from datetime import timedelta

import jieba
import numpy as np
import torch
from tqdm import tqdm
  10. MAX_VOCAB_SIZE = 10000
  11. UNK, PAD = '<UNK>', '<PAD>'
  12. def build_vocab(file_path, tokenizer, max_size, min_freq):
  13. vocab_dic = {}
  14. with open(file_path, 'r', encoding='UTF-8') as f:
  15. for line in tqdm(f):
  16. lin = line.strip()
  17. if not lin:
  18. continue
  19. content = lin.split('\t')[0]
  20. for word in tokenizer(content):
  21. vocab_dic[word] = vocab_dic.get(word, 0) + 1
  22. vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
  23. vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
  24. vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
  25. return vocab_dic
  26. def build_dataset(config, ues_word):
  27. if ues_word:
  28. # tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level
  29. tokenizer = lambda x: list(jieba.cut(x)) # 以空格隔开,word-level
  30. else:
  31. tokenizer = lambda x: [y for y in x] # char-level
  32. if os.path.exists(config.vocab_path):
  33. vocab = pkl.load(open(config.vocab_path, 'rb'))
  34. else:
  35. vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
  36. pkl.dump(vocab, open(config.vocab_path, 'wb'))
  37. print(f"Vocab size: {len(vocab)}")
  38. def biGramHash(sequence, t, buckets):
  39. t1 = sequence[t - 1] if t - 1 >= 0 else 0
  40. return (t1 * 14918087 + sequence[t]) % buckets
  41. def triGramHash(sequence, t, buckets):
  42. t1 = sequence[t - 1] if t - 1 >= 0 else 0
  43. t2 = sequence[t - 2] if t - 2 >= 0 else 0
  44. return (t2 * 14918087 * 18408749 + t1 * 14918087 + sequence[t]) % buckets
  45. def load_dataset(path, pad_size=32):
  46. contents = []
  47. with open(path, 'r', encoding='UTF-8') as f:
  48. for line in tqdm(f):
  49. lin = line.strip()
  50. if not lin:
  51. continue
  52. try:
  53. content, label = lin.split('\t')
  54. except:
  55. print(line)
  56. words_line = []
  57. token = tokenizer(content)
  58. seq_len = len(token)
  59. if pad_size:
  60. if len(token) < pad_size:
  61. token.extend([PAD] * (pad_size - len(token)))
  62. else:
  63. token = token[:pad_size]
  64. seq_len = pad_size
  65. # word to id
  66. for word in token:
  67. words_line.append(vocab.get(word, vocab.get(UNK)))
  68. # fasttext ngram
  69. buckets = config.n_gram_vocab
  70. bigram = []
  71. trigram = []
  72. # ------ngram------
  73. for i in range(pad_size):
  74. bigram.append(biGramHash(words_line, i, buckets))
  75. trigram.append(triGramHash(words_line, i, buckets))
  76. # -----------------
  77. contents.append((words_line, int(label), seq_len, bigram, trigram))
  78. return contents # [([...], 0), ([...], 1), ...]
  79. train = load_dataset(config.train_path, config.pad_size)
  80. dev = load_dataset(config.dev_path, config.pad_size)
  81. test = load_dataset(config.test_path, config.pad_size)
  82. return vocab, train, dev, test
  83. def build_pre_dataset(config, ues_word,vocab):
  84. if ues_word:
  85. tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level
  86. else:
  87. tokenizer = lambda x: [y for y in x] # char-level
  88. def biGramHash(sequence, t, buckets):
  89. t1 = sequence[t - 1] if t - 1 >= 0 else 0
  90. return (t1 * 14918087) % buckets
  91. def triGramHash(sequence, t, buckets):
  92. t1 = sequence[t - 1] if t - 1 >= 0 else 0
  93. t2 = sequence[t - 2] if t - 2 >= 0 else 0
  94. return (t2 * 14918087 * 18408749 + t1 * 14918087) % buckets
  95. def load_pre_dataset(path, pad_size=32):
  96. contents = []
  97. with open(path, 'r', encoding='UTF-8') as f:
  98. all_text=[]
  99. data = f.read()
  100. data = data.split('\n\n')
  101. for sample in data:
  102. content = sample.strip().strip('\n')
  103. if not content:
  104. continue
  105. all_text.append(content)
  106. words_line = []
  107. token = tokenizer(content)
  108. seq_len = len(token)
  109. if pad_size:
  110. if len(token) < pad_size:
  111. token.extend([PAD] * (pad_size - len(token)))
  112. else:
  113. token = token[:pad_size]
  114. seq_len = pad_size
  115. # word to id
  116. for word in token:
  117. words_line.append(vocab.get(word, vocab.get(UNK)))
  118. # fasttext ngram
  119. buckets = config.n_gram_vocab
  120. bigram = []
  121. trigram = []
  122. # ------ngram------
  123. for i in range(pad_size):
  124. bigram.append(biGramHash(words_line, i, buckets))
  125. trigram.append(triGramHash(words_line, i, buckets))
  126. contents.append((words_line, int(-1), seq_len, bigram, trigram))
  127. return contents, all_text # contents:[([...], 0), ([...], 1), ...]
  128. predict, all_text = load_pre_dataset(config.predict_path, config.pad_size)
  129. return predict, all_text
  130. class DatasetIterater(object):
  131. def __init__(self, batches, batch_size, device):
  132. self.batch_size = batch_size
  133. self.batches = batches
  134. self.n_batches = len(batches) // batch_size
  135. self.residue = False # 记录batch数量是否为整数
  136. if len(batches) % self.batch_size != 0:#if len(batches) % self.n_batches != 0:
  137. self.residue = True
  138. self.index = 0
  139. self.device = device
  140. def _to_tensor(self, datas):
  141. # xx = [xxx[2] for xxx in datas]
  142. # indexx = np.argsort(xx)[::-1]
  143. # datas = np.array(datas)[indexx]
  144. x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
  145. y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
  146. bigram = torch.LongTensor([_[3] for _ in datas]).to(self.device)
  147. trigram = torch.LongTensor([_[4] for _ in datas]).to(self.device)
  148. # pad前的长度(超过pad_size的设为pad_size)
  149. seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
  150. return (x, seq_len, bigram, trigram), y
  151. def __next__(self):
  152. if self.residue and self.index == self.n_batches:
  153. batches = self.batches[self.index * self.batch_size: len(self.batches)]
  154. self.index += 1
  155. # print("batches lengith:",len(batches))
  156. batches = self._to_tensor(batches)
  157. return batches
  158. elif self.index >= self.n_batches:
  159. self.index = 0
  160. # print("StopIteration")
  161. raise StopIteration
  162. else:
  163. batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
  164. self.index += 1
  165. # print("batches lengith:",len(batches))
  166. batches = self._to_tensor(batches)
  167. return batches
  168. def __iter__(self):
  169. return self
  170. def __len__(self):
  171. if self.residue:
  172. return self.n_batches + 1
  173. else:
  174. return self.n_batches
  175. def build_iterator(dataset, config):
  176. iter = DatasetIterater(dataset, config.batch_size, config.device)
  177. return iter
  178. def get_time_dif(start_time):
  179. """获取已使用时间"""
  180. end_time = time.time()
  181. time_dif = end_time - start_time
  182. return timedelta(seconds=int(round(time_dif)))
  183. if __name__ == "__main__":
  184. '''提取预训练词向量'''
  185. vocab_dir = "./THUCNews/data/vocab.pkl"
  186. pretrain_dir = "./THUCNews/data/sgns.sogou.char"
  187. emb_dim = 300
  188. filename_trimmed_dir = "./THUCNews/data/vocab.embedding.sougou"
  189. word_to_id = pkl.load(open(vocab_dir, 'rb'))
  190. embeddings = np.random.rand(len(word_to_id), emb_dim)
  191. f = open(pretrain_dir, "r", encoding='UTF-8')
  192. for i, line in enumerate(f.readlines()):
  193. # if i == 0: # 若第一行是标题,则跳过
  194. # continue
  195. lin = line.strip().split(" ")
  196. if lin[0] in word_to_id:
  197. idx = word_to_id[lin[0]]
  198. emb = [float(x) for x in lin[1:301]]
  199. embeddings[idx] = np.asarray(emb, dtype='float32')
  200. f.close()
  201. np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)

No Description

Contributors (1)