
Encoder.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader
from tools.logger import *  # provides `logger`, used throughout this module
from tools.PositionEmbedding import get_sinusoid_encoding_table

WORD_PAD = "[PAD]"

class Encoder(nn.Module):
    def __init__(self, hps, embed):
        """
        :param hps:
                word_emb_dim: word embedding dimension
                sent_max_len: max token number in the sentence
                output_channel: output channel for cnn
                min_kernel_size: min kernel size for cnn
                max_kernel_size: max kernel size for cnn
                word_embedding: bool, use word embedding or not
                embedding_path: word embedding path
                embed_train: bool, whether to train word embedding
                cuda: bool, use cuda or not
        :param embed: word embedding module (e.g. an nn.Embedding built from a fastNLP Vocabulary)
        """
        super(Encoder, self).__init__()

        self._hps = hps
        self.sent_max_len = hps.sent_max_len

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len

        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # word embedding
        self.embed = embed

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width))
                                    for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initing W for CNN.......")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")

        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
        return fan_in, fan_out
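
    # A quick sanity check of the fan computation above (illustrative, not part
    # of the original file): for a Conv2d weight of shape [50, 1, 3, 300], the
    # receptive field is 3 * 300 = 900, so
    #   Encoder.calculate_fan_in_and_fan_out(torch.empty(50, 1, 3, 300))
    #   -> fan_in = 1 * 900 = 900, fan_out = 50 * 900 = 45000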

    def forward(self, input):
        # input: a batch of Example object [batch_size, N, seq_len]
        batch_size, N, _ = input.size()
        input = input.view(-1, input.size(2))  # [batch_size*N, L]
        input_sent_len = ((input != 0).sum(dim=1)).int()  # [batch_size*N]
        enc_embed_input = self.embed(input)  # [batch_size*N, L, D]
        input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen)))
                                  for sentlen in input_sent_len])
        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, L, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # [batch_size*N, Ci, L, D]
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * [batch_size*N, Co, W]
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * [batch_size*N, Co]
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # [batch_size*N, Co * kernel_sizes]
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding
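
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal smoke test for Encoder, assuming an `hps`-like namespace that carries
# the fields read in __init__. The hps values, vocab size, and the name
# `_demo_encoder` are all hypothetical; the sketch only runs inside this repo
# (it needs tools.PositionEmbedding).
def _demo_encoder():
    from types import SimpleNamespace
    hps = SimpleNamespace(word_emb_dim=64, sent_max_len=20, output_channel=16,
                          min_kernel_size=1, max_kernel_size=3, cuda=False)
    embed = nn.Embedding(1000, hps.word_emb_dim, padding_idx=0)
    encoder = Encoder(hps, embed)
    tokens = torch.randint(1, 1000, (2, 3, hps.sent_max_len))  # [batch, N, L]; id 0 is PAD
    sent_embedding = encoder(tokens)
    n_kernels = hps.max_kernel_size - hps.min_kernel_size + 1
    # each sentence vector concatenates the max-pooled features of every kernel size
    assert sent_embedding.shape == (2, 3, hps.output_channel * n_kernels)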

class DomainEncoder(Encoder):
    def __init__(self, hps, embed, domaindict):
        super(DomainEncoder, self).__init__(hps, embed)

        # domain embedding
        self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
        self.domain_embedding.weight.requires_grad = True

    def forward(self, input, domain):
        """
        :param input: [batch_size, N, seq_len], N sentence number, seq_len token number
        :param domain: [batch_size]
        :return: sent_embedding: [batch_size, N, Co * kernel_sizes + domain_emb_dim]
        """
        batch_size, N, _ = input.size()
        sent_embedding = super().forward(input)
        enc_domain_input = self.domain_embedding(domain)  # [batch, D]
        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
        return sent_embedding
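
# --- Usage sketch (illustrative, not part of the original file) ---
# A hypothetical check that DomainEncoder appends the same domain vector to every
# sentence representation. `_DemoDict` stands in for the real domain dict, which
# is only assumed to expose `.size()`.
def _demo_domain_encoder():
    from types import SimpleNamespace

    class _DemoDict:
        def size(self):
            return 5

    hps = SimpleNamespace(word_emb_dim=64, sent_max_len=20, output_channel=16,
                          min_kernel_size=1, max_kernel_size=3, domain_emb_dim=8, cuda=False)
    embed = nn.Embedding(1000, hps.word_emb_dim, padding_idx=0)
    enc = DomainEncoder(hps, embed, _DemoDict())
    tokens = torch.randint(1, 1000, (2, 3, hps.sent_max_len))  # [batch, N, L]
    domain = torch.tensor([0, 4])                              # one domain id per document
    out = enc(tokens, domain)
    assert out.shape == (2, 3, hps.output_channel * 3 + hps.domain_emb_dim)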

class MultiDomainEncoder(Encoder):
    def __init__(self, hps, embed, domaindict):
        super(MultiDomainEncoder, self).__init__(hps, embed)

        self.domain_size = domaindict.size()

        # domain embedding
        self.domain_embedding = nn.Embedding(self.domain_size, hps.domain_emb_dim)
        self.domain_embedding.weight.requires_grad = True

    def forward(self, input, domain):
        """
        :param input: [batch_size, N, seq_len], N sentence number, seq_len token number
        :param domain: [batch_size, domain_size], multi-hot domain indicator
        :return: sent_embedding: [batch_size, N, Co * kernel_sizes + domain_emb_dim]
        """
        batch_size, N, _ = input.size()

        # logger.info(domain[:5, :])

        sent_embedding = super().forward(input)
        domain_padding = torch.arange(self.domain_size).unsqueeze(0).expand(batch_size, -1)
        domain_padding = domain_padding.cuda().view(-1) if self._hps.cuda else domain_padding.view(-1)  # [batch * domain_size]

        enc_domain_input = self.domain_embedding(domain_padding)  # [batch * domain_size, D]
        enc_domain_input = enc_domain_input.view(batch_size, self.domain_size, -1) * domain.unsqueeze(-1).float()  # [batch, domain_size, D]

        # logger.info(enc_domain_input[:5, :])  # [batch, domain_size, D]

        # average the embeddings of the active domains (multi-hot mean)
        enc_domain_input = enc_domain_input.sum(1) / domain.sum(1).float().unsqueeze(-1)  # [batch, D]
        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
        return sent_embedding
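
# --- Usage sketch (illustrative, not part of the original file) ---
# MultiDomainEncoder averages the embeddings of all active domains in a multi-hot
# indicator. This miniature isolates just that averaging step so the broadcasting
# is visible; all sizes are made up for illustration.
def _demo_multi_hot_domain_average():
    domain_size, D = 4, 8
    domain_embedding = nn.Embedding(domain_size, D)
    domain = torch.tensor([[1, 0, 1, 0]])                              # [batch=1, domain_size], two active domains
    all_domains = domain_embedding(torch.arange(domain_size))          # [domain_size, D]
    masked = all_domains.unsqueeze(0) * domain.unsqueeze(-1).float()   # zero out the inactive rows
    avg = masked.sum(1) / domain.sum(1).float().unsqueeze(-1)          # [batch, D]
    expected = (all_domains[0] + all_domains[2]) / 2
    assert torch.allclose(avg[0], expected)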

class BertEncoder(nn.Module):
    def __init__(self, hps):
        super(BertEncoder, self).__init__()

        from pytorch_pretrained_bert.modeling import BertModel

        self._hps = hps
        self.sent_max_len = hps.sent_max_len
        self._cuda = hps.cuda

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len

        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # word embedding
        self._bert = BertModel.from_pretrained("/remote-home/dqwang/BERT/pre-train/uncased_L-24_H-1024_A-16")
        self._bert.eval()
        for p in self._bert.parameters():
            p.requires_grad = False
        # the last four 1024-dim BERT-large layers are concatenated in forward, hence 4096
        self.word_embedding_proj = nn.Linear(4096, embed_size)

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width))
                                    for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initing W for CNN.......")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")

        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
        return fan_in, fan_out

    def pad_encoder_input(self, input_list):
        """
        :param input_list: N * [seq_len, hidden_state]
        :return: enc_sent_input_pad: list, N * [max_len, hidden_state]
        """
        max_len = self.sent_max_len
        enc_sent_input_pad = []
        _, hidden_size = input_list[0].size()
        for i in range(len(input_list)):
            article_words = input_list[i]  # [seq_len, hidden_size]
            seq_len = article_words.size(0)
            if seq_len > max_len:
                pad_words = article_words[:max_len, :]
            else:
                pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
                pad_words = torch.cat([article_words, pad_tensor], dim=0)
            enc_sent_input_pad.append(pad_words)
        return enc_sent_input_pad
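
    # Illustration (hedged, not from the original file): with sent_max_len = 5,
    # pad_encoder_input pads a [3, H] sentence with a [2, H] zero tensor and
    # truncates an [8, H] sentence to its first 5 rows, so every entry of the
    # returned list has shape [5, H].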

    def forward(self, inputs, input_masks, enc_sent_len):
        """
        :param inputs: a batch of Example object [batch_size, doc_len=512]
        :param input_masks: 0 or 1, [batch, doc_len=512]
        :param enc_sent_len: sentence original length [batch, N]
        :return: sent_embedding: [batch_size, N, Co * kernel_sizes]
        """

        # Use BERT to get the word embeddings
        batch_size, N = enc_sent_len.size()
        input_pad_list = []
        for i in range(batch_size):
            tokens_id = inputs[i]
            input_mask = input_masks[i]
            sent_len = enc_sent_len[i]
            input_ids = tokens_id.unsqueeze(0)
            input_mask = input_mask.unsqueeze(0)

            out, _ = self._bert(input_ids, token_type_ids=None, attention_mask=input_mask)
            out = torch.cat(out[-4:], dim=-1).squeeze(0)  # concat the last four layers: [doc_len=512, hidden_state=4096]

            _, hidden_size = out.size()

            # restore the sentences (position 0 is [CLS], so token indexing starts at 1)
            last_end = 1
            enc_sent_input = []
            for length in sent_len:
                if length != 0 and last_end < 511:
                    enc_sent_input.append(out[last_end: min(511, last_end + length), :])
                    last_end += length
                else:
                    pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
                    enc_sent_input.append(pad_tensor)

            # pad the sentences
            enc_sent_input_pad = self.pad_encoder_input(enc_sent_input)  # N * [max_len, hidden_state=4096]
            input_pad_list.append(torch.stack(enc_sent_input_pad))

        input_pad = torch.stack(input_pad_list)
        input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
        enc_sent_len = enc_sent_len.view(-1)  # [batch_size*N]
        enc_embed_input = self.word_embedding_proj(input_pad)  # [batch_size*N, L, D]

        sent_pos_list = []
        for sentlen in enc_sent_len:
            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
            for k in range(self.sent_max_len - sentlen):
                sent_pos.append(0)
            sent_pos_list.append(sent_pos)
        input_pos = torch.Tensor(sent_pos_list).long()

        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, L, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # [batch_size*N, Ci, L, D]
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * [batch_size*N, Co, W]
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * [batch_size*N, Co]
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # [batch_size*N, Co * kernel_sizes]
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding
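
# --- Usage sketch (illustrative, not part of the original file) ---
# The "restore the sentences" loop in BertEncoder.forward slices per-sentence
# spans out of the flat document encoding, starting at index 1 to skip [CLS].
# A minimal standalone rehearsal of that indexing with made-up sizes:
def _demo_restore_sentences():
    out = torch.randn(10, 4)  # toy document encoding [doc_len=10, hidden=4]
    sent_len = [3, 4]         # token counts of two sentences
    last_end, spans = 1, []   # position 0 would be [CLS]
    for length in sent_len:
        spans.append(out[last_end: last_end + length, :])
        last_end += length
    assert [s.size(0) for s in spans] == [3, 4]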

class BertTagEncoder(BertEncoder):
    def __init__(self, hps, domaindict):
        super(BertTagEncoder, self).__init__(hps)

        # domain embedding
        self.domain_embedding = nn.Embedding(domaindict.size(), hps.domain_emb_dim)
        self.domain_embedding.weight.requires_grad = True

    def forward(self, inputs, input_masks, enc_sent_len, domain):
        sent_embedding = super().forward(inputs, input_masks, enc_sent_len)

        batch_size, N = enc_sent_len.size()

        enc_domain_input = self.domain_embedding(domain)  # [batch, D]
        enc_domain_input = enc_domain_input.unsqueeze(1).expand(batch_size, N, -1)  # [batch, N, D]
        sent_embedding = torch.cat((sent_embedding, enc_domain_input), dim=2)
        return sent_embedding

class ELMoEndoer(nn.Module):
    def __init__(self, hps):
        super(ELMoEndoer, self).__init__()

        self._hps = hps
        self.sent_max_len = hps.sent_max_len

        from allennlp.modules.elmo import Elmo

        elmo_dim = 1024
        options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
        weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"

        # elmo_dim = 512
        # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
        # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len

        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # elmo embedding
        self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
        self.embed_proj = nn.Linear(elmo_dim, embed_size)

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width))
                                    for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initing W for CNN.......")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")

        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
        return fan_in, fan_out

    def forward(self, input):
        # input: a batch of Example object [batch_size, N, seq_len, character_len]
        batch_size, N, seq_len, _ = input.size()
        input = input.view(batch_size * N, seq_len, -1)  # [batch_size*N, seq_len, character_len]
        input_sent_len = ((input.sum(-1) != 0).sum(dim=1)).int()  # [batch_size*N]
        # logger.debug(input_sent_len.view(batch_size, -1))
        enc_embed_input = self.elmo(input)['elmo_representations'][0]  # [batch_size*N, L, elmo_dim]
        enc_embed_input = self.embed_proj(enc_embed_input)  # [batch_size*N, L, D]

        # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])

        sent_pos_list = []
        for sentlen in input_sent_len:
            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
            for k in range(self.sent_max_len - sentlen):
                sent_pos.append(0)
            sent_pos_list.append(sent_pos)
        input_pos = torch.Tensor(sent_pos_list).long()

        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, L, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # [batch_size*N, Ci, L, D]
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * [batch_size*N, Co, W]
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * [batch_size*N, Co]
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # [batch_size*N, Co * kernel_sizes]
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding

class ELMoEndoer2(nn.Module):
    def __init__(self, hps):
        super(ELMoEndoer2, self).__init__()

        self._hps = hps
        self._cuda = hps.cuda
        self.sent_max_len = hps.sent_max_len

        from allennlp.modules.elmo import Elmo

        elmo_dim = 1024
        options_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
        weight_file = "/remote-home/dqwang/ELMo/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"

        # elmo_dim = 512
        # options_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
        # weight_file = "/remote-home/dqwang/ELMo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"

        embed_size = hps.word_emb_dim
        sent_max_len = hps.sent_max_len

        input_channels = 1
        out_channels = hps.output_channel
        min_kernel_size = hps.min_kernel_size
        max_kernel_size = hps.max_kernel_size
        width = embed_size

        # elmo embedding
        self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
        self.embed_proj = nn.Linear(elmo_dim, embed_size)

        # position embedding
        self.position_embedding = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(sent_max_len + 1, embed_size, padding_idx=0), freeze=True)

        # cnn
        self.convs = nn.ModuleList([nn.Conv2d(input_channels, out_channels, kernel_size=(height, width))
                                    for height in range(min_kernel_size, max_kernel_size + 1)])
        logger.info("[INFO] Initing W for CNN.......")
        for conv in self.convs:
            init_weight_value = 6.0
            init.xavier_normal_(conv.weight.data, gain=np.sqrt(init_weight_value))
            fan_in, fan_out = Encoder.calculate_fan_in_and_fan_out(conv.weight.data)
            std = np.sqrt(init_weight_value) * np.sqrt(2.0 / (fan_in + fan_out))
    @staticmethod
    def calculate_fan_in_and_fan_out(tensor):
        dimensions = tensor.ndimension()
        if dimensions < 2:
            logger.error("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")
            raise ValueError("[Error] Fan in and fan out can not be computed for tensor with less than 2 dimensions")

        if dimensions == 2:  # Linear
            fan_in = tensor.size(1)
            fan_out = tensor.size(0)
        else:
            num_input_fmaps = tensor.size(1)
            num_output_fmaps = tensor.size(0)
            receptive_field_size = 1
            if tensor.dim() > 2:
                receptive_field_size = tensor[0][0].numel()
            fan_in = num_input_fmaps * receptive_field_size
            fan_out = num_output_fmaps * receptive_field_size
        return fan_in, fan_out

    def pad_encoder_input(self, input_list):
        """
        :param input_list: N * [seq_len, hidden_state]
        :return: enc_sent_input_pad: list, N * [max_len, hidden_state]
        """
        max_len = self.sent_max_len
        enc_sent_input_pad = []
        _, hidden_size = input_list[0].size()
        for i in range(len(input_list)):
            article_words = input_list[i]  # [seq_len, hidden_size]
            seq_len = article_words.size(0)
            if seq_len > max_len:
                pad_words = article_words[:max_len, :]
            else:
                pad_tensor = torch.zeros(max_len - seq_len, hidden_size).cuda() if self._cuda else torch.zeros(max_len - seq_len, hidden_size)
                pad_words = torch.cat([article_words, pad_tensor], dim=0)
            enc_sent_input_pad.append(pad_words)
        return enc_sent_input_pad

    def forward(self, inputs, input_masks, enc_sent_len):
        """
        :param inputs: a batch of Example object [batch_size, doc_len=512, character_len=50]
        :param input_masks: 0 or 1, [batch, doc_len=512]
        :param enc_sent_len: sentence original length [batch, N]
        :return:
            sent_embedding: [batch, N, D]
        """

        # Use ELMo to get the word embeddings
        batch_size, N = enc_sent_len.size()
        input_pad_list = []

        elmo_output = self.elmo(inputs)['elmo_representations'][0]  # [batch_size, 512, elmo_dim]
        elmo_output = elmo_output * input_masks.unsqueeze(-1).float()
        # print("END elmo")

        for i in range(batch_size):
            sent_len = enc_sent_len[i]  # [N]
            out = elmo_output[i]
            _, hidden_size = out.size()

            # restore the sentences
            last_end = 0
            enc_sent_input = []
            for length in sent_len:
                if length != 0 and last_end < 512:
                    enc_sent_input.append(out[last_end: min(512, last_end + length), :])
                    last_end += length
                else:
                    pad_tensor = torch.zeros(self.sent_max_len, hidden_size).cuda() if self._hps.cuda else torch.zeros(self.sent_max_len, hidden_size)
                    enc_sent_input.append(pad_tensor)

            # pad the sentences
            enc_sent_input_pad = self.pad_encoder_input(enc_sent_input)  # N * [max_len, elmo_dim]
            input_pad_list.append(torch.stack(enc_sent_input_pad))  # batch * [N, max_len, elmo_dim]

        input_pad = torch.stack(input_pad_list)
        input_pad = input_pad.view(batch_size * N, self.sent_max_len, -1)
        enc_sent_len = enc_sent_len.view(-1)  # [batch_size*N]
        enc_embed_input = self.embed_proj(input_pad)  # [batch_size*N, L, D]

        # input_pos = torch.Tensor([np.hstack((np.arange(1, sentlen + 1), np.zeros(self.sent_max_len - sentlen))) for sentlen in input_sent_len])

        sent_pos_list = []
        for sentlen in enc_sent_len:
            sent_pos = list(range(1, min(self.sent_max_len, sentlen) + 1))
            for k in range(self.sent_max_len - sentlen):
                sent_pos.append(0)
            sent_pos_list.append(sent_pos)
        input_pos = torch.Tensor(sent_pos_list).long()

        if self._hps.cuda:
            input_pos = input_pos.cuda()
        enc_pos_embed_input = self.position_embedding(input_pos.long())  # [batch_size*N, L, D]
        enc_conv_input = enc_embed_input + enc_pos_embed_input
        enc_conv_input = enc_conv_input.unsqueeze(1)  # [batch_size*N, Ci, L, D]
        enc_conv_output = [F.relu(conv(enc_conv_input)).squeeze(3) for conv in self.convs]  # kernel_sizes * [batch_size*N, Co, W]
        enc_maxpool_output = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in enc_conv_output]  # kernel_sizes * [batch_size*N, Co]
        sent_embedding = torch.cat(enc_maxpool_output, 1)  # [batch_size*N, Co * kernel_sizes]
        sent_embedding = sent_embedding.view(batch_size, N, -1)
        return sent_embedding
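
# --- Reference sketch (illustrative, not part of the original file) ---
# get_sinusoid_encoding_table is imported from tools.PositionEmbedding; a common
# implementation, following the fixed positional encoding of "Attention Is All
# You Need", looks like the sketch below. This is an assumption about that
# helper, not its actual source.
def _sinusoid_table_sketch(n_position, d_hid, padding_idx=None):
    position = np.arange(n_position)[:, None]                   # [n_position, 1]
    div = np.power(10000, 2 * (np.arange(d_hid) // 2) / d_hid)  # [d_hid]
    table = position / div                                      # [n_position, d_hid]
    table[:, 0::2] = np.sin(table[:, 0::2])                     # even dims: sine
    table[:, 1::2] = np.cos(table[:, 1::2])                     # odd dims: cosine
    if padding_idx is not None:
        table[padding_idx] = 0.0                                # zero vector for PAD positions
    return torch.FloatTensor(table)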