You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Encoder.py 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import torch.nn.init as init
  9. from fastNLP.core.vocabulary import Vocabulary
  10. from fastNLP.io.embed_loader import EmbedLoader
  11. # from tools.logger import *
  12. from tools.PositionEmbedding import get_sinusoid_encoding_table
  13. WORD_PAD = "[PAD]"
  14. class Encoder(nn.Module):
  15. def __init__(self, hps, embed):
  16. """
  17. :param hps:
  18. word_emb_dim: word embedding dimension
  19. sent_max_len: max token number in the sentence
  20. output_channel: output channel for cnn
  21. min_kernel_size: min kernel size for cnn
  22. max_kernel_size: max kernel size for cnn
  23. word_embedding: bool, use word embedding or not
  24. embedding_path: word embedding path
  25. embed_train: bool, whether to train word embedding
  26. cuda: bool, use cuda or not
  27. :param vocab: FastNLP.Vocabulary
  28. """
  29. super(Encoder, self).__init__()
  30. self._hps = hps
  31. self.sent_max_len = hps.sent_max_len
  32. embed_size = hps.word_emb_dim
  33. sent_max_len = hps.sent_max_len
  34. input_channels = 1
  35. out_channels = hps.output_channel
  36. min_kernel_size = hps.min_kernel_size
  37. max_kernel_size = hps.max_kernel_size
  38. width = embed_size
  39. # word embedding
  40. self.embed = embed
  41. # position embedding
  42. self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
  43. # cnn
  44. self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)])
  45. print("[INFO] Initing W for CNN.......")
  46. for conv in self.convs:
  47. init_weight_value = 6.0
  48. init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
  49. fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
  50. std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
  51. def calculate_fan_in_and_fan_out(tensor):
  52. dimensions = tensor.ndimension()
  53. if dimensions < 2:
  54. print("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
  55. raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
  56. if dimensions == 2: # Linear
  57. fan_in = tensor.size(1)
  58. fan_out = tensor.size(0)
  59. else:
  60. num_input_fmaps = tensor.size(1)
  61. num_output_fmaps = tensor.size(0)
  62. receptive_field_size = 1
  63. if tensor.dim() > 2:
  64. receptive_field_size = tensor[0][0].numel()
  65. fan_in = num_input_fmaps * receptive_field_size
  66. fan_out = num_output_fmaps * receptive_field_size
  67. return fan_in, fan_out
  68. def forward(self, input):
  69. # input: a batch of Example object [batch_size, N, seq_len]
  70. batch_size, N, _ = input.size()
  71. input = input.view(-1, input.size(2)) # [batch_size*N, L]
  72. input_sent_len = ((input!=0).sum(dim=1)).int() # [batch_size*N, 1]
  73. enc_embed_input = self.embed(input) # [batch_size*N, L, D]
  74. input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
  75. if self._hps.cuda:
  76. input_pos = input_pos.cuda()
  77. enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D]
  78. # print(enc_embed_input.size())
  79. # print(enc_pos_embed_input.size())
  80. enc_conv_input = enc_embed_input + enc_pos_embed_input
  81. enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D)
  82. enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W)
  83. enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co)
  84. sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes)
  85. sent_embedding = sent_embedding.view(batch_size, N, -1)
  86. return sent_embedding
  87. class DomainEncoder(Encoder):
  88. def __init__(self, hps, vocab, domaindict):
  89. super(DomainEncoder, self).__init__(hps, vocab)
  90. # domain embedding
  91. self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
  92. self.domain_embedding.weight.requires_grad = True
  93. def forward(self, input, domain):
  94. """
  95. :param input: [batch_size, N, seq_len], N sentence number, seq_len token number
  96. :param domain: [batch_size]
  97. :return: sent_embedding: [batch_size, N, Co * kernel_sizes]
  98. """
  99. batch_size, N, _ = input.size()
  100. sent_embedding = super().forward(input)
  101. enc_domain_input = self.domain_embedding(domain) # [batch, D]
  102. enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D]
  103. sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
  104. return sent_embedding
  105. class MultiDomainEncoder(Encoder):
  106. def __init__(self, hps, vocab, domaindict):
  107. super(MultiDomainEncoder, self).__init__(hps, vocab)
  108. self.domain_size = domaindict.size()
  109. # domain embedding
  110. self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim)
  111. self.domain_embedding.weight.requires_grad = True
  112. def forward(self, input, domain):
  113. """
  114. :param input: [batch_size, N, seq_len], N sentence number, seq_len token number
  115. :param domain: [batch_size, domain_size]
  116. :return: sent_embedding: [batch_size, N, Co * kernel_sizes]
  117. """
  118. batch_size, N, _ = input.size()
  119. # logger.info(domain[:5, :])
  120. sent_embedding = super().forward(input)
  121. domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1)
  122. domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1) # [batch * domain_size]
  123. enc_domain_input = self.domain_embedding(domain_padding) # [batch * domain_size, D]
  124. enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float() # [batch, domain_size, D]
  125. # logger.info(enc_domain_input[:5,:]) # [batch, domain_size, D]
  126. enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1) # [batch, D]
  127. enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D]
  128. sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
  129. return sent_embedding
  130. class BertEncoder(nn.Module):
  131. def __init__(self, hps):
  132. super(BertEncoder, self).__init__()
  133. from pytorch_pretrained_bert.modeling import BertModel
  134. self._hps = hps
  135. self.sent_max_len = hps.sent_max_len
  136. self._cuda = hps.cuda
  137. embed_size = hps.word_emb_dim
  138. sent_max_len = hps.sent_max_len
  139. input_channels = 1
  140. out_channels = hps.output_channel
  141. min_kernel_size = hps.min_kernel_size
  142. max_kernel_size = hps.max_kernel_size
  143. width = embed_size
  144. # word embedding
  145. self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16")
  146. self._bert.eval()
  147. for p in self._bert.parameters():
  148. p.requires_grad = False
  149. self.word_embedding_proj = nn.Linear(4096, embed_size)
  150. # position embedding
  151. self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
  152. # cnn
  153. self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)])
  154. logger.info("[INFO] Initing W for CNN.......")
  155. for conv in self.convs:
  156. init_weight_value = 6.0
  157. init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
  158. fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
  159. std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
  160. def calculate_fan_in_and_fan_out(tensor):
  161. dimensions = tensor.ndimension()
  162. if dimensions < 2:
  163. logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
  164. raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
  165. if dimensions == 2: # Linear
  166. fan_in = tensor.size(1)
  167. fan_out = tensor.size(0)
  168. else:
  169. num_input_fmaps = tensor.size(1)
  170. num_output_fmaps = tensor.size(0)
  171. receptive_field_size = 1
  172. if tensor.dim() > 2:
  173. receptive_field_size = tensor[0][0].numel()
  174. fan_in = num_input_fmaps * receptive_field_size
  175. fan_out = num_output_fmaps * receptive_field_size
  176. return fan_in, fan_out
  177. def pad_encoder_input(self, input_list):
  178. """
  179. :param input_list: N [seq_len, hidden_state]
  180. :return: enc_sent_input_pad: list, N [max_len, hidden_state]
  181. """
  182. max_len = self.sent_max_len
  183. enc_sent_input_pad = []
  184. _, hidden_size = input_list[0].size()
  185. for i in range(len(input_list)):
  186. article_words = input_list[i] # [seq_len, hidden_size]
  187. seq_len = article_words.size(0)
  188. if seq_len > max_len:
  189. pad_words = article_words[:max_len, :]
  190. else:
  191. pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
  192. pad_words = torch.cat([article_words, pad_tensor], dim=0)
  193. enc_sent_input_pad.append(pad_words)
  194. return enc_sent_input_pad
  195. def forward(self, inputs, input_masks, enc_sent_len):
  196. """
  197. :param inputs: a batch of Example object [batch_size, doc_len=512]
  198. :param input_masks: 0 or 1, [batch, doc_len=512]
  199. :param enc_sent_len: sentence original length [batch, N]
  200. :return:
  201. """
  202. # Use Bert to get word embedding
  203. batch_size, N = enc_sent_len.size()
  204. input_pad_list = []
  205. for i in range(batch_size):
  206. tokens_id = inputs[i]
  207. input_mask = input_masks[i]
  208. sent_len = enc_sent_len[i]
  209. input_ids = tokens_id.unsqueeze(0)
  210. input_mask = input_mask.unsqueeze(0)
  211. out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask)
  212. out = torch.cat(out[-4:], dim=-1).squeeze(0) # [doc_len=512, hidden_state=4096]
  213. _, hidden_size = out.size()
  214. # restore the sentence
  215. last_end = 1
  216. enc_sent_input = []
  217. for length in sent_len:
  218. if length != 0 and last_end < 511:
  219. enc_sent_input.append(out[last_end: min(511, last_end + length), :])
  220. last_end += length
  221. else:
  222. pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
  223. enc_sent_input.append(pad_tensor)
  224. # pad the sentence
  225. enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096]
  226. input_pad_list.append(torch.stack(enc_sent_input_pad))
  227. input_pad = torch.stack(input_pad_list)
  228. input_pad = input_pad.view(batch_size*N, self.sent_max_len, -1)
  229. enc_sent_len = enc_sent_len.view(-1) # [batch_size*N]
  230. enc_embed_input = self.word_embedding_proj(input_pad) # [batch_size * N, L, D]
  231. sent_pos_list = []
  232. for sentlen in enc_sent_len:
  233. sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
  234. for k in range(self.sent_max_len - sentlen):
  235. sent_pos.append(0)
  236. sent_pos_list.append(sent_pos)
  237. input_pos = torch.Tensor(sent_pos_list).long()
  238. if self._hps.cuda:
  239. input_pos = input_pos.cuda()
  240. enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D]
  241. enc_conv_input = enc_embed_input + enc_pos_embed_input
  242. enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D)
  243. enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W)
  244. enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co)
  245. sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes)
  246. sent_embedding = sent_embedding.view(batch_size, N, -1)
  247. return sent_embedding
  248. class BertTagEncoder(BertEncoder):
  249. def __init__(self, hps, domaindict):
  250. super(BertTagEncoder, self).__init__(hps)
  251. # domain embedding
  252. self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
  253. self.domain_embedding.weight.requires_grad = True
  254. def forward(self, inputs, input_masks, enc_sent_len, domain):
  255. sent_embedding = super().forward(inputs, input_masks, enc_sent_len)
  256. batch_size, N = enc_sent_len.size()
  257. enc_domain_input = self.domain_embedding(domain) # [batch, D]
  258. enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1) # [batch, N, D]
  259. sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
  260. return sent_embedding
  261. class ELMoEndoer(nn.Module):
  262. def __init__(self, hps):
  263. super(ELMoEndoer, self).__init__()
  264. self._hps = hps
  265. self.sent_max_len = hps.sent_max_len
  266. from allennlp.modules.elmo import Elmo
  267. elmo_dim = 1024
  268. options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
  269. weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
  270. # elmo_dim = 512
  271. # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
  272. # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
  273. embed_size = hps.word_emb_dim
  274. sent_max_len = hps.sent_max_len
  275. input_channels = 1
  276. out_channels = hps.output_channel
  277. min_kernel_size = hps.min_kernel_size
  278. max_kernel_size = hps.max_kernel_size
  279. width = embed_size
  280. # elmo embedding
  281. self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
  282. self.embed_proj = nn.Linear(elmo_dim, embed_size)
  283. # position embedding
  284. self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
  285. # cnn
  286. self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)])
  287. logger.info("[INFO] Initing W for CNN.......")
  288. for conv in self.convs:
  289. init_weight_value = 6.0
  290. init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
  291. fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
  292. std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
  293. def calculate_fan_in_and_fan_out(tensor):
  294. dimensions = tensor.ndimension()
  295. if dimensions < 2:
  296. logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
  297. raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
  298. if dimensions == 2: # Linear
  299. fan_in = tensor.size(1)
  300. fan_out = tensor.size(0)
  301. else:
  302. num_input_fmaps = tensor.size(1)
  303. num_output_fmaps = tensor.size(0)
  304. receptive_field_size = 1
  305. if tensor.dim() > 2:
  306. receptive_field_size = tensor[0][0].numel()
  307. fan_in = num_input_fmaps * receptive_field_size
  308. fan_out = num_output_fmaps * receptive_field_size
  309. return fan_in, fan_out
  310. def forward(self, input):
  311. # input: a batch of Example object [batch_size, N, seq_len, character_len]
  312. batch_size, N, seq_len, _ = input.size()
  313. input = input.view(batch_size * N, seq_len, -1) # [batch_size*N, seq_len, character_len]
  314. input_sent_len = ((input.sum(-1)!=0).sum(dim=1)).int() # [batch_size*N, 1]
  315. # logger.debug(input_sent_len.view(batch_size, -1))
  316. enc_embed_input = self.elmo(input)['elmo_representations'][0] # [batch_size*N, L, D]
  317. enc_embed_input = self.embed_proj(enc_embed_input)
  318. # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
  319. sent_pos_list = []
  320. for sentlen in input_sent_len:
  321. sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
  322. for k in range(self.sent_max_len - sentlen):
  323. sent_pos.append(0)
  324. sent_pos_list.append(sent_pos)
  325. input_pos = torch.Tensor(sent_pos_list).long()
  326. if self._hps.cuda:
  327. input_pos = input_pos.cuda()
  328. enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D]
  329. enc_conv_input = enc_embed_input + enc_pos_embed_input
  330. enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D)
  331. enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W)
  332. enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co)
  333. sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes)
  334. sent_embedding = sent_embedding.view(batch_size, N, -1)
  335. return sent_embedding
  336. class ELMoEndoer2(nn.Module):
  337. def __init__(self, hps):
  338. super(ELMoEndoer2, self).__init__()
  339. self._hps = hps
  340. self._cuda = hps.cuda
  341. self.sent_max_len = hps.sent_max_len
  342. from allennlp.modules.elmo import Elmo
  343. elmo_dim = 1024
  344. options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
  345. weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
  346. # elmo_dim = 512
  347. # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
  348. # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
  349. embed_size = hps.word_emb_dim
  350. sent_max_len = hps.sent_max_len
  351. input_channels = 1
  352. out_channels = hps.output_channel
  353. min_kernel_size = hps.min_kernel_size
  354. max_kernel_size = hps.max_kernel_size
  355. width = embed_size
  356. # elmo embedding
  357. self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
  358. self.embed_proj = nn.Linear(elmo_dim, embed_size)
  359. # position embedding
  360. self.position_embedding = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)
  361. # cnn
  362. self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size = (height, width)) for height in range(min_kernel_size, max_kernel_size+1)])
  363. logger.info("[INFO] Initing W for CNN.......")
  364. for conv in self.convs:
  365. init_weight_value = 6.0
  366. init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
  367. fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
  368. std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
  369. def calculate_fan_in_and_fan_out(tensor):
  370. dimensions = tensor.ndimension()
  371. if dimensions < 2:
  372. logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
  373. raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
  374. if dimensions == 2: # Linear
  375. fan_in = tensor.size(1)
  376. fan_out = tensor.size(0)
  377. else:
  378. num_input_fmaps = tensor.size(1)
  379. num_output_fmaps = tensor.size(0)
  380. receptive_field_size = 1
  381. if tensor.dim() > 2:
  382. receptive_field_size = tensor[0][0].numel()
  383. fan_in = num_input_fmaps * receptive_field_size
  384. fan_out = num_output_fmaps * receptive_field_size
  385. return fan_in, fan_out
  386. def pad_encoder_input(self, input_list):
  387. """
  388. :param input_list: N [seq_len, hidden_state]
  389. :return: enc_sent_input_pad: list, N [max_len, hidden_state]
  390. """
  391. max_len = self.sent_max_len
  392. enc_sent_input_pad = []
  393. _, hidden_size = input_list[0].size()
  394. for i in range(len(input_list)):
  395. article_words = input_list[i] # [seq_len, hidden_size]
  396. seq_len = article_words.size(0)
  397. if seq_len > max_len:
  398. pad_words = article_words[:max_len, :]
  399. else:
  400. pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
  401. pad_words = torch.cat([article_words, pad_tensor], dim=0)
  402. enc_sent_input_pad.append(pad_words)
  403. return enc_sent_input_pad
  404. def forward(self, inputs, input_masks, enc_sent_len):
  405. """
  406. :param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50]
  407. :param input_masks: 0 or 1, [batch, doc_len=512]
  408. :param enc_sent_len: sentence original length [batch, N]
  409. :return:
  410. sent_embedding: [batch, N, D]
  411. """
  412. # Use Bert to get word embedding
  413. batch_size, N = enc_sent_len.size()
  414. input_pad_list = []
  415. elmo_output = self.elmo(inputs)['elmo_representations'][0] # [batch_size, 512, D]
  416. elmo_output = elmo_output * input_masks.unsqueeze(-1).float()
  417. # print("END elmo")
  418. for i in range(batch_size):
  419. sent_len = enc_sent_len[i] # [1, N]
  420. out = elmo_output[i]
  421. _, hidden_size = out.size()
  422. # restore the sentence
  423. last_end = 0
  424. enc_sent_input = []
  425. for length in sent_len:
  426. if length != 0 and last_end < 512:
  427. enc_sent_input.append(out[last_end : min(512, last_end + length), :])
  428. last_end += length
  429. else:
  430. pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
  431. enc_sent_input.append(pad_tensor)
  432. # pad the sentence
  433. enc_sent_input_pad = self.pad_encoder_input(enc_sent_input) # [N, seq_len, hidden_state=4096]
  434. input_pad_list.append(torch.stack(enc_sent_input_pad)) # batch * [N, max_len, hidden_state]
  435. input_pad = torch.stack(input_pad_list)
  436. input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
  437. enc_sent_len = enc_sent_len.view(-1) # [batch_size*N]
  438. enc_embed_input = self.embed_proj(input_pad) # [batch_size * N, L, D]
  439. # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])
  440. sent_pos_list = []
  441. for sentlen in enc_sent_len:
  442. sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
  443. for k in range(self.sent_max_len - sentlen):
  444. sent_pos.append(0)
  445. sent_pos_list.append(sent_pos)
  446. input_pos = torch.Tensor(sent_pos_list).long()
  447. if self._hps.cuda:
  448. input_pos = input_pos.cuda()
  449. enc_pos_embed_input = self.position_embedding(input_pos.long()) # [batch_size*N, D]
  450. enc_conv_input = enc_embed_input + enc_pos_embed_input
  451. enc_conv_input = enc_conv_input.unsqueeze(1) # (batch * N,Ci,L,D)
  452. enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs] # kernel_sizes * (batch*N, Co, W)
  453. enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output] # kernel_sizes * (batch*N, Co)
  454. sent_embedding = torch.cat(enc_maxpool_output, 1) # (batch*N, Co * kernel_sizes)
  455. sent_embedding = sent_embedding.view(batch_size, N, -1)
  456. return sent_embedding