
Encoder.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
import torch.nn.init as init

import data
from tools.logger import *
from transformer.Models import get_sinusoid_encoding_table
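

# Encoder: CNN sentence encoder. Word embeddings plus sinusoidal position
# embeddings are fed through Conv2d filters with kernel heights
# min_kernel_size..max_kernel_size, ReLU-activated, max-pooled over time, and
# concatenated into one fixed-size embedding per sentence.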
class Encoder(nn.Module):
    def __init__(self, hps, vocab):
        super(Encoder, self).__init__()
        self._hps = hps
        self._vocab = vocab
        self.sent_max_len = hps.sent_max_len

        vocab_size = len(vocab)
        logger.info("[INFO] Vocabulary size is %d", vocab_size)
        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len
        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # word embedding
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=vocab.word2id('[PAD]'))
        if hps.word_embedding:
            word2vec = data.Word_Embedding(hps.embedding_path, vocab)
            word_vecs = word2vec.load_my_vecs(embed_size)
            # pretrained_weight = word2vec.add_unknown_words_by_zero(word_vecs, embed_size)
            pretrained_weight = word2vec.add_unknown_words_by_avg(word_vecs, embed_size)
            pretrained_weight = np.array(pretrained_weight)
            self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))
            self.embed.weight.requires_grad = hps.embed_train

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width))
                                    for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initializing weights for CNN...")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            # xavier_normal_ above already applies this std; recomputed here only for reference
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))

    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
        return fan_in, fan_out

    def forward(self, input):
        # input: a batch of Example objects [batch_size, N, seq_len]
        vocab = self._vocab
        batch_size, N, _ = input.size()
        input = input.view(-1, input.size(2))  # [batch_size*N, L]
        input_sent_len = ((input != vocab.word2id('[PAD]')).sum(dim=1)).int()  # [batch_size*N]
        enc_embed_input = self.embed(input)  # [batch_size*N, L, D]
        input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen)))
                                  for sentlen in input_sent_len])
        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, L, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # [batch_size*N, Ci, L, D]
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * [batch_size*N, Co, W]
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * [batch_size*N, Co]
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # [batch_size*N, Co * kernel_sizes]
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding
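

# DomainEncoder: Encoder variant that looks up one learned domain embedding
# per document and appends it to every sentence embedding.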
class DomainEncoder(Encoder):
    def __init__(self, hps, vocab, domaindict):
        super(DomainEncoder, self).__init__(hps, vocab)
        # domain embedding
        self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
        self.domain_embedding.weight.requires_grad = True

    def forward(self, input, domain):
        """
        :param input: [batch_size, N, seq_len], N is the sentence number, seq_len the token number
        :param domain: [batch_size]
        :return: sent_embedding: [batch_size, N, Co * kernel_sizes + domain_emb_dim]
        """
        batch_size, N, _ = input.size()
        sent_embedding = super().forward(input)
        enc_domain_input = self.domain_embedding(domain)  # [batch, D]
        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
        return sent_embedding
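

# MultiDomainEncoder: Encoder variant for multi-label domains. A multi-hot
# domain vector selects several domain embeddings, which are averaged and
# appended to every sentence embedding.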
class MultiDomainEncoder(Encoder):
    def __init__(self, hps, vocab, domaindict):
        super(MultiDomainEncoder, self).__init__(hps, vocab)
        self.domain_size = domaindict.size()
        # domain embedding
        self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim)
        self.domain_embedding.weight.requires_grad = True

    def forward(self, input, domain):
        """
        :param input: [batch_size, N, seq_len], N is the sentence number, seq_len the token number
        :param domain: multi-hot vector, [batch_size, domain_size]
        :return: sent_embedding: [batch_size, N, Co * kernel_sizes + domain_emb_dim]
        """
        batch_size, N, _ = input.size()
        # logger.info(domain[:5, :])
        sent_embedding = super().forward(input)
        domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1)
        domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1)  # [batch * domain_size]
        enc_domain_input = self.domain_embedding(domain_padding)  # [batch * domain_size, D]
        enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float()  # [batch, domain_size, D]
        # logger.info(enc_domain_input[:5, :])  # [batch, domain_size, D]
        # average over the active domains (assumes at least one active domain per example)
        enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1)  # [batch, D]
        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
        return sent_embedding
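

# BertEncoder: the same CNN-over-sentence architecture, but word
# representations come from a frozen pretrained BERT model: the last four
# hidden layers (4 x 1024 for BERT-large, hence 4096) are concatenated and
# projected down to word_emb_dim before the convolution.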
class BertEncoder(nn.Module):
    def __init__(self, hps):
        super(BertEncoder, self).__init__()

        from pytorch_pretrained_bert.modeling import BertModel

        self._hps = hps
        self.sent_max_len = hps.sent_max_len
        self._cuda = hps.cuda

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len
        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # word embedding
        self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16")
        self._bert.eval()
        for p in self._bert.parameters():
            p.requires_grad = False
        self.word_embedding_proj = nn.Linear(4096, embed_size)

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width))
                                    for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initializing weights for CNN...")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            # xavier_normal_ above already applies this std; recomputed here only for reference
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))

    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
        return fan_in, fan_out

    def pad_encoder_input(self, input_list):
        """
        :param input_list: list of N tensors [seq_len, hidden_state]
        :return: enc_sent_input_pad: list of N tensors [max_len, hidden_state]
        """
        max_len = self.sent_max_len
        enc_sent_input_pad = []
        _, hidden_size = input_list[0].size()
        for i in range(len(input_list)):
            article_words = input_list[i]  # [seq_len, hidden_size]
            seq_len = article_words.size(0)
            if seq_len > max_len:
                pad_words = article_words[:max_len, :]
            else:
                pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
                pad_words = torch.cat([article_words, pad_tensor], dim=0)
            enc_sent_input_pad.append(pad_words)
        return enc_sent_input_pad

    def forward(self, inputs, input_masks, enc_sent_len):
        """
        :param inputs: a batch of token-id sequences [batch_size, doc_len=512]
        :param input_masks: 0 or 1, [batch, doc_len=512]
        :param enc_sent_len: original sentence lengths [batch, N]
        :return: sent_embedding: [batch_size, N, Co * kernel_sizes]
        """
        # Use BERT to get word embeddings
        batch_size, N = enc_sent_len.size()
        input_pad_list = []
        for i in range(batch_size):
            tokens_id = inputs[i]
            input_mask = input_masks[i]
            sent_len = enc_sent_len[i]
            input_ids = tokens_id.unsqueeze(0)
            input_mask = input_mask.unsqueeze(0)
            out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask)
            out = torch.cat(out[-4:], dim=-1).squeeze(0)  # [doc_len=512, hidden_state=4096]
            _, hidden_size = out.size()

            # restore the sentences (skip position 0, the [CLS] token)
            last_end = 1
            enc_sent_input = []
            for length in sent_len:
                if length != 0 and last_end < 511:
                    enc_sent_input.append(out[last_end: min(511, last_end + length), :])
                    last_end += length
                else:
                    pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
                    enc_sent_input.append(pad_tensor)

            # pad the sentences
            enc_sent_input_pad = self.pad_encoder_input(enc_sent_input)  # N * [max_len, hidden_state=4096]
            input_pad_list.append(torch.stack(enc_sent_input_pad))

        input_pad = torch.stack(input_pad_list)
        input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
        enc_sent_len = enc_sent_len.view(-1)  # [batch_size*N]
        enc_embed_input = self.word_embedding_proj(input_pad)  # [batch_size*N, L, D]

        sent_pos_list = []
        for sentlen in enc_sent_len:
            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
            for k in range(self.sent_max_len - sentlen):
                sent_pos.append(0)
            sent_pos_list.append(sent_pos)
        input_pos = torch.Tensor(sent_pos_list).long()
        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, L, D]

        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # [batch_size*N, Ci, L, D]
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * [batch_size*N, Co, W]
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * [batch_size*N, Co]
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # [batch_size*N, Co * kernel_sizes]
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding
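

# BertTagEncoder: BertEncoder variant that appends a learned domain embedding
# to every sentence embedding, analogous to DomainEncoder.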
class BertTagEncoder(BertEncoder):
    def __init__(self, hps, domaindict):
        super(BertTagEncoder, self).__init__(hps)
        # domain embedding
        self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
        self.domain_embedding.weight.requires_grad = True

    def forward(self, inputs, input_masks, enc_sent_len, domain):
        sent_embedding = super().forward(inputs, input_masks, enc_sent_len)
        batch_size, N = enc_sent_len.size()
        enc_domain_input = self.domain_embedding(domain)  # [batch, D]
        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
        return sent_embedding
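

# ELMoEndoer: the same CNN-over-sentence architecture with ELMo word
# representations; character-level sentence inputs are embedded by ELMo and
# projected to word_emb_dim.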
class ELMoEndoer(nn.Module):
    def __init__(self, hps):
        super(ELMoEndoer, self).__init__()
        self._hps = hps
        self.sent_max_len = hps.sent_max_len

        from allennlp.modules.elmo import Elmo

        elmo_dim = 1024
        options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
        weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
        # elmo_dim = 512
        # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
        # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len
        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # elmo embedding
        self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
        self.embed_proj = nn.Linear(elmo_dim, embed_size)

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width))
                                    for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initializing weights for CNN...")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            # xavier_normal_ above already applies this std; recomputed here only for reference
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))

    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
        return fan_in, fan_out

    def forward(self, input):
        # input: a batch of Example objects [batch_size, N, seq_len, character_len]
        batch_size, N, seq_len, _ = input.size()
        input = input.view(batch_size * N, seq_len, -1)  # [batch_size*N, seq_len, character_len]
        input_sent_len = ((input.sum(-1) != 0).sum(dim=1)).int()  # [batch_size*N]
        logger.debug(input_sent_len.view(batch_size, -1))
        enc_embed_input = self.elmo(input)['elmo_representations'][0]  # [batch_size*N, L, elmo_dim]
        enc_embed_input = self.embed_proj(enc_embed_input)  # [batch_size*N, L, D]
        # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
        sent_pos_list = []
        for sentlen in input_sent_len:
            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
            for k in range(self.sent_max_len - sentlen):
                sent_pos.append(0)
            sent_pos_list.append(sent_pos)
        input_pos = torch.Tensor(sent_pos_list).long()
        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, L, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # [batch_size*N, Ci, L, D]
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * [batch_size*N, Co, W]
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * [batch_size*N, Co]
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # [batch_size*N, Co * kernel_sizes]
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding
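

# ELMoEndoer2: ELMo variant that, like BertEncoder, runs ELMo over the whole
# document (doc_len=512) and then slices the output back into sentences.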
class ELMoEndoer2(nn.Module):
    def __init__(self, hps):
        super(ELMoEndoer2, self).__init__()
        self._hps = hps
        self._cuda = hps.cuda
        self.sent_max_len = hps.sent_max_len

        from allennlp.modules.elmo import Elmo

        elmo_dim = 1024
        options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
        weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
        # elmo_dim = 512
        # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
        # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len
        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # elmo embedding
        self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
        self.embed_proj = nn.Linear(elmo_dim, embed_size)

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width))
                                    for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initializing weights for CNN...")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            # xavier_normal_ above already applies this std; recomputed here only for reference
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))

    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out cannot be computed for a tensor with fewer than 2 dimensions")
        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
        return fan_in, fan_out

    def pad_encoder_input(self, input_list):
        """
        :param input_list: list of N tensors [seq_len, hidden_state]
        :return: enc_sent_input_pad: list of N tensors [max_len, hidden_state]
        """
        max_len = self.sent_max_len
        enc_sent_input_pad = []
        _, hidden_size = input_list[0].size()
        for i in range(len(input_list)):
            article_words = input_list[i]  # [seq_len, hidden_size]
            seq_len = article_words.size(0)
            if seq_len > max_len:
                pad_words = article_words[:max_len, :]
            else:
                pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
                pad_words = torch.cat([article_words, pad_tensor], dim=0)
            enc_sent_input_pad.append(pad_words)
        return enc_sent_input_pad

    def forward(self, inputs, input_masks, enc_sent_len):
        """
        :param inputs: a batch of character-id sequences [batch_size, doc_len=512, character_len=50]
        :param input_masks: 0 or 1, [batch, doc_len=512]
        :param enc_sent_len: original sentence lengths [batch, N]
        :return:
            sent_embedding: [batch, N, D]
        """
        # Use ELMo to get word embeddings
        batch_size, N = enc_sent_len.size()
        input_pad_list = []
        elmo_output = self.elmo(inputs)['elmo_representations'][0]  # [batch_size, 512, elmo_dim]
        elmo_output = elmo_output * input_masks.unsqueeze(-1).float()
        # print("END elmo")
        for i in range(batch_size):
            sent_len = enc_sent_len[i]  # [N]
            out = elmo_output[i]
            _, hidden_size = out.size()

            # restore the sentences
            last_end = 0
            enc_sent_input = []
            for length in sent_len:
                if length != 0 and last_end < 512:
                    enc_sent_input.append(out[last_end: min(512, last_end + length), :])
                    last_end += length
                else:
                    pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
                    enc_sent_input.append(pad_tensor)

            # pad the sentences
            enc_sent_input_pad = self.pad_encoder_input(enc_sent_input)  # N * [max_len, hidden_state]
            input_pad_list.append(torch.stack(enc_sent_input_pad))  # batch * [N, max_len, hidden_state]

        input_pad = torch.stack(input_pad_list)
        input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
        enc_sent_len = enc_sent_len.view(-1)  # [batch_size*N]
        enc_embed_input = self.embed_proj(input_pad)  # [batch_size*N, L, D]

        # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
        sent_pos_list = []
        for sentlen in enc_sent_len:
            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
            for k in range(self.sent_max_len - sentlen):
                sent_pos.append(0)
            sent_pos_list.append(sent_pos)
        input_pos = torch.Tensor(sent_pos_list).long()
        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, L, D]

        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # [batch_size*N, Ci, L, D]
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * [batch_size*N, Co, W]
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * [batch_size*N, Co]
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # [batch_size*N, Co * kernel_sizes]
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding
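

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; assumes an `hps` namespace carrying
# the hyperparameters read above and a `vocab` exposing __len__ and word2id):
#
#   encoder = Encoder(hps, vocab)
#   input = torch.zeros(batch_size, N, hps.sent_max_len, dtype=torch.long)
#   sent_embedding = encoder(input)  # [batch_size, N, output_channel * num_kernels]
# ---------------------------------------------------------------------------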