train.py 5.4 kB

import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_sequence
from torch.utils.data import DataLoader, Dataset
from gensim.models import Word2Vec

from model import *  # provides the HAN network definition
class SentIter:
    """Iterates over the sentences of the first `count` pickled review files."""
    def __init__(self, dirname, count):
        self.dirname = dirname
        self.count = int(count)

    def __iter__(self):
        for fname in os.listdir(self.dirname)[:self.count]:
            with open(os.path.join(self.dirname, fname), 'rb') as f:
                # each file holds (label, document) pairs; a document is a list of sentences
                for y, x in pickle.load(f):
                    for sent in x:
                        yield sent
def train_word_vec():
    # load data
    dirname = 'reviews'
    sents = SentIter(dirname, 238)
    # define the model and train: 200-dim CBOW vectors
    # (gensim >= 4.0 uses `vector_size`; older versions call it `size`)
    model = Word2Vec(vector_size=200, sg=0, workers=4, min_count=5)
    model.build_vocab(sents)
    model.train(sents, total_examples=model.corpus_count, epochs=10)
    model.save('yelp.word2vec')
    # quick sanity checks on the learned vectors
    print(model.wv.similarity('woman', 'man'))
    print(model.wv.similarity('nice', 'awful'))
class Embedding_layer:
    """Word-vector lookup with a random fallback for out-of-vocabulary words."""
    def __init__(self, wv, vector_size):
        self.wv = wv
        self.vector_size = vector_size

    def get_vec(self, w):
        try:
            v = self.wv[w]
        except KeyError:
            # unknown word: substitute a random vector of the right size
            v = np.random.randn(self.vector_size)
        return v
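
# A minimal sketch of how Embedding_layer behaves (illustrative only, not
# called anywhere in the pipeline; assumes a trained 'yelp.word2vec' model
# on disk, as produced by train_word_vec above):
def _demo_embedding():
    wv = Word2Vec.load('yelp.word2vec').wv
    emb = Embedding_layer(wv, wv.vector_size)
    print(emb.get_vec('good').shape)        # (200,) for an in-vocabulary word
    print(emb.get_vec('zzz_unseen').shape)  # (200,) random fallback for OOV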
class YelpDocSet(Dataset):
    """Reviews from `num_files` pickled files (5000 reviews per file),
    with a small rotating cache of recently loaded files."""
    def __init__(self, dirname, start_file, num_files, embedding):
        self.dirname = dirname
        self.num_files = num_files
        self._files = os.listdir(dirname)[start_file:start_file + num_files]
        self.embedding = embedding
        # cache slots hold (file_id, contents); -1 marks an empty slot
        self._cache = [(-1, None) for _ in range(5)]

    def get_doc(self, n):
        file_id = n // 5000
        idx = file_id % 5
        if self._cache[idx][0] != file_id:
            with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
                self._cache[idx] = (file_id, pickle.load(f))
        y, x = self._cache[idx][1][n % 5000]
        # flatten the tokenized document back into plain text
        sents = []
        for s_list in x:
            sents.append(' '.join(s_list))
        x = '\n'.join(sents)
        return x, y - 1

    def __len__(self):
        return len(self._files) * 5000
    def __getitem__(self, n):
        file_id = n // 5000
        idx = file_id % 5
        if self._cache[idx][0] != file_id:
            print('load {} to {}'.format(file_id, idx))
            with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
                self._cache[idx] = (file_id, pickle.load(f))
        y, x = self._cache[idx][1][n % 5000]
        # turn each non-empty sentence into a (num_words, vector_size) tensor
        doc = []
        for sent in x:
            if len(sent) == 0:
                continue
            sent_vec = []
            for word in sent:
                vec = self.embedding.get_vec(word)
                sent_vec.append(vec.tolist())
            sent_vec = torch.Tensor(sent_vec)
            doc.append(sent_vec)
        if len(doc) == 0:
            # empty review: fall back to a single zero sentence
            # (200 must match the word-vector size)
            doc = [torch.zeros(1, 200)]
        return doc, y - 1
def collate(iterable):
    # documents have different numbers of sentences, so the batch stays a
    # plain list of docs; only the labels are stacked into a tensor
    y_list = []
    x_list = []
    for x, y in iterable:
        y_list.append(y)
        x_list.append(x)
    return x_list, torch.LongTensor(y_list)
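
# A minimal sketch of how the dataset and collate function fit together
# (illustrative only; assumes the 'reviews' directory layout used by the
# training setup below):
def _demo_batch(embedding):
    loader = DataLoader(YelpDocSet('reviews', 0, 1, embedding),
                        batch_size=2, collate_fn=collate)
    docs, labels = next(iter(loader))
    # docs is a list of 2 documents; each document is a list of
    # (num_words, 200) sentence tensors; labels is a LongTensor of shape (2,)
    print(len(docs), docs[0][0].shape, labels.shape)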
def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False):
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.NLLLoss()
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            collate_fn=collate,
                            num_workers=0)
    running_loss = 0.0
    if use_cuda:
        net.cuda()
    print('start training')
    for epoch in range(num_epoch):
        for i, batch_samples in enumerate(dataloader):
            x, y = batch_samples
            doc_list = []
            for sample in x:
                doc = []
                for sent_vec in sample:
                    if use_cuda:
                        sent_vec = sent_vec.cuda()
                    doc.append(Variable(sent_vec))
                # pack the sentences of one document for the word-level RNN;
                # enforce_sorted=False (PyTorch >= 1.1) allows sentences in
                # their natural, unsorted order
                doc_list.append(pack_sequence(doc, enforce_sorted=False))
            if use_cuda:
                y = y.cuda()
            y = Variable(y)
            predict = net(doc_list)
            loss = criterion(predict, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()  # loss.data[0] was the pre-0.4 idiom
            if i % print_size == print_size - 1:
                print('{}, {}'.format(i + 1, running_loss / print_size))
                running_loss = 0.0
                # periodic checkpoint
                torch.save(net.state_dict(), 'models.dict')
    # final checkpoint after all epochs
    torch.save(net.state_dict(), 'models.dict')
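
# A hedged evaluation sketch, not part of the original pipeline: it assumes
# `net` outputs log-probabilities over the 5 classes, matching the NLLLoss
# used in train above, and reuses the same document-packing scheme.
def evaluate(net, dataset, batch_size=64, use_cuda=False):
    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate)
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            doc_list = []
            for sample in x:
                doc = [s.cuda() if use_cuda else s for s in sample]
                doc_list.append(pack_sequence(doc, enforce_sorted=False))
            if use_cuda:
                y = y.cuda()
            predict = net(doc_list)  # (batch, 5) log-probabilities
            correct += (predict.argmax(dim=1) == y).sum().item()
            total += y.size(0)
    return correct / total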
if __name__ == '__main__':
    '''
    Train process
    '''
    train_word_vec()
    embed_model = Word2Vec.load('yelp.word2vec')
    embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
    del embed_model  # only the keyed vectors are needed from here on
    start_file = 0
    dataset = YelpDocSet('reviews', start_file, 120 - start_file, embedding)
    print('training data size {}'.format(len(dataset)))
    net = HAN(input_size=200, output_size=5,
              word_hidden_size=50, word_num_layers=1, word_context_size=100,
              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
    try:
        net.load_state_dict(torch.load('models.dict'))
        print('loaded the previously trained model')
    except Exception:
        print('cannot load a saved model, training from scratch')
    train(net, dataset, num_epoch=5, batch_size=64, use_cuda=True)