@@ -0,0 +1,61 @@
+from model import *
+from train import *
+
+
+def evaluate(net, dataset, batch_size=64, use_cuda=False):
+    dataloader = DataLoader(dataset, batch_size=batch_size,
+                            collate_fn=collate, num_workers=0)
+    count = 0
+    if use_cuda:
+        net.cuda()
+    for i, batch_samples in enumerate(dataloader):
+        x, y = batch_samples
+        doc_list = []
+        for sample in x:
+            doc = []
+            for sent_vec in sample:
+                # print(sent_vec.size())
+                if use_cuda:
+                    sent_vec = sent_vec.cuda()
+                doc.append(Variable(sent_vec, volatile=True))
+            doc_list.append(pack_sequence(doc))
+        if use_cuda:
+            y = y.cuda()
+        predicts = net(doc_list)
+        # Alternative kept for reference: sample the prediction from the
+        # softmax distribution instead of taking the argmax.
+        # idx = []
+        # for p in predicts.data:
+        #     idx.append(np.random.choice(5, p=torch.exp(p).numpy()))
+        # idx = torch.LongTensor(idx)
+        p, idx = torch.max(predicts, dim=1)
+        idx = idx.data
+        count += torch.sum(torch.eq(idx, y))
+    return count
+
+
+def visualize_attention(doc, alpha_vec):
+    pass
+
+
+if __name__ == '__main__':
+    from gensim.models import Word2Vec
+
+    embed_model = Word2Vec.load('yelp.word2vec')
+    embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
+    del embed_model
+    net = HAN(input_size=200, output_size=5,
+              word_hidden_size=50, word_num_layers=1, word_context_size=100,
+              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
+    net.load_state_dict(torch.load('model.dict'))
+    test_dataset = YelpDocSet('reviews', 199, 4, embedding)
+    correct = evaluate(net, test_dataset, use_cuda=True)
+    print('accuracy {}'.format(float(correct) / len(test_dataset)))
+    # data_idx = 121
+    # x, y = test_dataset[data_idx]
+    # doc = []
+    # for sent_vec in x:
+    #     doc.append(Variable(sent_vec, volatile=True))
+    # input_vec = [pack_sequence(doc)]
+    # predict = net(input_vec)
+    # p, idx = torch.max(predict, dim=1)
+    # print(net.word_layer.last_alpha.squeeze())
+    # print(net.sent_layer.last_alpha)
+    # print(test_dataset.get_doc(data_idx)[0])
+    # print('predict: {}, true: {}'.format(int(idx), y))
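Note: `visualize_attention` above is still a stub, while the commented-out block at the end of the script only prints the raw `last_alpha` tensors that the hunks below expose. A minimal sketch of what the stub could do with those weights — the per-sentence token/weight pairing is an assumption for illustration, not something this patch defines:

```python
def visualize_attention(doc, alpha_vec):
    # doc: list of sentences, each a list of tokens; alpha_vec: matching
    # per-token attention weights (assumed layout, not fixed by the patch).
    for tokens, alphas in zip(doc, alpha_vec):
        # Print each token annotated with its attention weight.
        print(' '.join('{}({:.2f})'.format(t, float(a))
                       for t, a in zip(tokens, alphas)))
```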
@@ -55,6 +55,7 @@ class AttentionNet(nn.Module):
         self.gru_hidden_size = gru_hidden_size
         self.gru_num_layers = gru_num_layers
         self.context_vec_size = context_vec_size
+        self.last_alpha = None

         # Encoder
         self.gru = nn.GRU(input_size=input_size,
@@ -76,6 +77,7 @@ class AttentionNet(nn.Module):
         u = self.tanh(self.fc(h_t))
         # u's dim (batch_size, seq_len, context_vec_size)
         alpha = self.softmax(torch.matmul(u, self.context_vec))
+        self.last_alpha = alpha.data
         # alpha's dim (batch_size, seq_len, 1)
         output = torch.bmm(torch.transpose(h_t, 1, 2), alpha)
         # output's dim (batch_size, 2*hidden_size, 1)
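The comments in this hunk record the shape flow that `last_alpha` snapshots. A standalone check of those dimensions with dummy tensors — sizes are arbitrary, `context_vec` is assumed to be `(context_vec_size, 1)` as the `(batch_size, seq_len, 1)` comment implies, and the softmax over `seq_len` is omitted since it does not change shapes:

```python
import torch

batch, seq_len, hidden, context = 4, 7, 50, 100
h_t = torch.randn(batch, seq_len, 2 * hidden)     # bidirectional GRU outputs
u = torch.randn(batch, seq_len, context)          # stand-in for tanh(fc(h_t))
context_vec = torch.randn(context, 1)             # assumed shape

alpha = torch.matmul(u, context_vec)              # (batch, seq_len, 1)
output = torch.bmm(h_t.transpose(1, 2), alpha)    # (batch, 2*hidden, 1)
print(alpha.size(), output.size())
```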
@@ -78,6 +78,20 @@ class YelpDocSet(Dataset):
         self.embedding = embedding
         self._cache = [(-1, None) for i in range(5)]

+    def get_doc(self, n):
+        # Each pickle file holds 5000 reviews; keep up to 5 loaded files in
+        # a cache slot chosen by file_id % 5.
+        file_id = n // 5000
+        idx = file_id % 5
+        if self._cache[idx][0] != file_id:
+            print('load {} to {}'.format(file_id, idx))
+            with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
+                self._cache[idx] = (file_id, pickle.load(f))
+        y, x = self._cache[idx][1][n % 5000]
+        # Join tokens into sentences and sentences into one document string;
+        # shift the 1-5 star label into the 0-4 range the model predicts.
+        sents = []
+        for s_list in x:
+            sents.append(' '.join(s_list))
+        x = '\n'.join(sents)
+        return x, y - 1
+
     def __len__(self):
         return len(self._files)*5000
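A short usage sketch for `get_doc`, assuming a `YelpDocSet` built as in the `__main__` blocks (the index 121 is borrowed from the commented-out snippet in the evaluation script):

```python
text, label = test_dataset.get_doc(121)  # raw review text, 0-based star label
print(label)
print(text.split('\n')[0])               # first sentence of the review
```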
@@ -166,7 +180,6 @@ if __name__ == '__main__':
     embed_model = Word2Vec.load('yelp.word2vec')
     embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
     del embed_model
-    # for start_file in range(11, 24):
     start_file = 0
     dataset = YelpDocSet('reviews', start_file, 120-start_file, embedding)
     print('start_file %d' % start_file)
@@ -176,4 +189,4 @@ if __name__ == '__main__':
               sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
     net.load_state_dict(torch.load('model.dict'))
-    train(net, dataset, num_epoch=1, batch_size=64, use_cuda=True)
+    train(net, dataset, num_epoch=5, batch_size=64, use_cuda=True)