You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 5.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
# coding: UTF-8
import os
import pickle as pkl
import time
from collections import Counter
from datetime import timedelta

import numpy as np
import torch
from tqdm import tqdm
  9. MAX_VOCAB_SIZE = 10000 # 词表长度限制
  10. UNK, PAD = '<UNK>', '<PAD>' # 未知字,padding符号
  11. def build_vocab(file_path, tokenizer, max_size, min_freq):
  12. vocab_dic = {}
  13. with open(file_path, 'r', encoding='UTF-8') as f:
  14. for line in tqdm(f):
  15. lin = line.strip()
  16. if not lin:
  17. continue
  18. content = lin.split('\t')[0]
  19. for word in tokenizer(content):
  20. vocab_dic[word] = vocab_dic.get(word, 0) + 1
  21. vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
  22. vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
  23. vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
  24. return vocab_dic
  25. def build_dataset(config, ues_word):
  26. if ues_word:
  27. tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level
  28. else:
  29. tokenizer = lambda x: [y for y in x] # char-level
  30. if os.path.exists(config.vocab_path):
  31. vocab = pkl.load(open(config.vocab_path, 'rb'))
  32. else:
  33. vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
  34. pkl.dump(vocab, open(config.vocab_path, 'wb'))
  35. print(f"Vocab size: {len(vocab)}")
  36. def load_dataset(path, pad_size=32):
  37. contents = []
  38. with open(path, 'r', encoding='UTF-8') as f:
  39. for line in tqdm(f):
  40. lin = line.strip()
  41. if not lin:
  42. continue
  43. content, label = lin.split('\t')
  44. words_line = []
  45. token = tokenizer(content)
  46. seq_len = len(token)
  47. if pad_size:
  48. if len(token) < pad_size:
  49. token.extend([PAD] * (pad_size - len(token)))
  50. else:
  51. token = token[:pad_size]
  52. seq_len = pad_size
  53. # word to id
  54. for word in token:
  55. words_line.append(vocab.get(word, vocab.get(UNK)))
  56. contents.append((words_line, int(label), seq_len))
  57. return contents # [([...], 0), ([...], 1), ...]
  58. train = load_dataset(config.train_path, config.pad_size)
  59. dev = load_dataset(config.dev_path, config.pad_size)
  60. test = load_dataset(config.test_path, config.pad_size)
  61. return vocab, train, dev, test
  62. class DatasetIterater(object):
  63. def __init__(self, batches, batch_size, device):
  64. self.batch_size = batch_size
  65. self.batches = batches
  66. self.n_batches = len(batches) // batch_size
  67. self.residue = False # 记录batch数量是否为整数
  68. if len(batches) % self.n_batches != 0:
  69. self.residue = True
  70. self.index = 0
  71. self.device = device
  72. def _to_tensor(self, datas):
  73. x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
  74. y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
  75. # pad前的长度(超过pad_size的设为pad_size)
  76. seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
  77. return (x, seq_len), y
  78. def __next__(self):
  79. if self.residue and self.index == self.n_batches:
  80. batches = self.batches[self.index * self.batch_size: len(self.batches)]
  81. self.index += 1
  82. batches = self._to_tensor(batches)
  83. return batches
  84. elif self.index >= self.n_batches:
  85. self.index = 0
  86. raise StopIteration
  87. else:
  88. batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
  89. self.index += 1
  90. batches = self._to_tensor(batches)
  91. return batches
  92. def __iter__(self):
  93. return self
  94. def __len__(self):
  95. if self.residue:
  96. return self.n_batches + 1
  97. else:
  98. return self.n_batches
  99. def build_iterator(dataset, config):
  100. iter = DatasetIterater(dataset, config.batch_size, config.device)
  101. return iter
  102. def get_time_dif(start_time):
  103. """获取已使用时间"""
  104. end_time = time.time()
  105. time_dif = end_time - start_time
  106. return timedelta(seconds=int(round(time_dif)))
  107. if __name__ == "__main__":
  108. '''提取预训练词向量'''
  109. # 下面的目录、文件名按需更改。
  110. train_dir = "./THUCNews/data/train.txt"
  111. vocab_dir = "./THUCNews/data/vocab.pkl"
  112. pretrain_dir = "./THUCNews/data/sgns.sogou.char"
  113. emb_dim = 300
  114. filename_trimmed_dir = "./THUCNews/data/embedding_SougouNews"
  115. if os.path.exists(vocab_dir):
  116. word_to_id = pkl.load(open(vocab_dir, 'rb'))
  117. else:
  118. # tokenizer = lambda x: x.split(' ') # 以词为单位构建词表(数据集中词之间以空格隔开)
  119. tokenizer = lambda x: [y for y in x] # 以字为单位构建词表
  120. word_to_id = build_vocab(train_dir, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
  121. pkl.dump(word_to_id, open(vocab_dir, 'wb'))
  122. embeddings = np.random.rand(len(word_to_id), emb_dim)
  123. f = open(pretrain_dir, "r", encoding='UTF-8')
  124. for i, line in enumerate(f.readlines()):
  125. # if i == 0: # 若第一行是标题,则跳过
  126. # continue
  127. lin = line.strip().split(" ")
  128. if lin[0] in word_to_id:
  129. idx = word_to_id[lin[0]]
  130. emb = [float(x) for x in lin[1:301]]
  131. embeddings[idx] = np.asarray(emb, dtype='float32')
  132. f.close()
  133. np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)