
modules.py 22 kB

import torch.nn as nn
import torch
from fastNLP.core.utils import seq_len_to_mask
from utils import better_init_rnn
import numpy as np


class WordLSTMCell_yangjie(nn.Module):

    """A basic LSTM cell."""

    def __init__(self, input_size, hidden_size, use_bias=True, debug=False, left2right=True):
        """
        Most parts are copied from torch.nn.LSTMCell.
        """
        super().__init__()
        self.left2right = left2right
        self.debug = debug
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 3 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 3 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        nn.init.orthogonal_(self.weight_ih.data)
        # The recurrent weight is initialized as an identity matrix for each gate.
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 3)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        # The bias is just set to zero vectors.
        if self.use_bias:
            nn.init.constant_(self.bias.data, val=0)

    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
        Returns:
            c_1: Tensor containing the next cell state. Note that this word
                cell produces only a cell state; no hidden state is returned.
        """
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
        wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        return c_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)
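
# Illustrative usage (not part of the original module): a minimal smoke test for
# WordLSTMCell_yangjie with made-up sizes. It shows that the word cell consumes a word
# embedding plus the (h, c) state taken at the word's start position and returns only a
# candidate cell state for the skip edge. The function name and sizes are hypothetical.
def _demo_word_lstm_cell():
    batch_size, word_emb_size, hidden_size = 2, 50, 100  # hypothetical sizes
    cell = WordLSTMCell_yangjie(word_emb_size, hidden_size)
    word_emb = torch.randn(batch_size, word_emb_size)
    h_0 = torch.zeros(batch_size, hidden_size)
    c_0 = torch.zeros(batch_size, hidden_size)
    c_1 = cell(word_emb, (h_0, c_0))  # only the cell state comes back
    assert c_1.size() == (batch_size, hidden_size)
    return c_1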
class MultiInputLSTMCell_V0(nn.Module):

    def __init__(self, char_input_size, hidden_size, use_bias=True, debug=False):
        super().__init__()
        self.char_input_size = char_input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(char_input_size, 3 * hidden_size)
        )
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 3 * hidden_size)
        )
        self.alpha_weight_ih = nn.Parameter(
            torch.FloatTensor(char_input_size, hidden_size)
        )
        self.alpha_weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, hidden_size)
        )
        if self.use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
            self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size))
        else:
            self.register_parameter('bias', None)
            self.register_parameter('alpha_bias', None)
        self.debug = debug
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        nn.init.orthogonal_(self.weight_ih.data)
        nn.init.orthogonal_(self.alpha_weight_ih.data)
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 3)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        alpha_weight_hh_data = torch.eye(self.hidden_size)
        alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1)
        with torch.no_grad():
            self.alpha_weight_hh.set_(alpha_weight_hh_data)
        # The bias is just set to zero vectors.
        if self.use_bias:
            nn.init.constant_(self.bias.data, val=0)
            nn.init.constant_(self.alpha_bias.data, val=0)

    def forward(self, inp, skip_c, skip_count, hx):
        '''
        :param inp: chars, B * hidden
        :param skip_c: cell states produced by the skip (lexicon) edges, B * X * hidden
        :param skip_count: number of skip edges at the current position for each example
            in the batch, used for masking
        :param hx: tuple (h_0, c_0) of the previous step
        :return: (h_1, c_1)
        '''
        max_skip_count = torch.max(skip_count).item()
        # NOTE: the else branch below is unreachable; it keeps a plain-LSTM fallback
        # that was apparently used during debugging.
        if True:
            h_0, c_0 = hx
            batch_size = h_0.size(0)
            bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
            wi = torch.matmul(inp, self.weight_ih)
            wh = torch.matmul(h_0, self.weight_hh)
            i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
            i = torch.sigmoid(i).unsqueeze(1)
            o = torch.sigmoid(o).unsqueeze(1)
            g = torch.tanh(g).unsqueeze(1)
            alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
            alpha_wi.unsqueeze_(1)
            # alpha_wi = alpha_wi.expand(1,skip_count,self.hidden_size)
            alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)
            alpha_bias_batch = self.alpha_bias.unsqueeze(0)
            alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)
            # Mask out padded skip edges before the softmax over {input gate, skip edges}.
            skip_mask = seq_len_to_mask(skip_count, max_len=skip_c.size()[1])
            skip_mask = 1 - skip_mask
            skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)
            skip_mask = (skip_mask).float() * 1e20
            alpha = alpha - skip_mask
            alpha = torch.exp(torch.cat([i, alpha], dim=1))
            alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
            alpha = torch.div(alpha, alpha_sum)
            merge_i_c = torch.cat([g, skip_c], dim=1)
            c_1 = merge_i_c * alpha
            c_1 = c_1.sum(1, keepdim=True)
            # h_1 = o * c_1
            h_1 = o * torch.tanh(c_1)
            return h_1.squeeze(1), c_1.squeeze(1)
        else:
            h_0, c_0 = hx
            batch_size = h_0.size(0)
            bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
            wi = torch.matmul(inp, self.weight_ih)
            wh = torch.matmul(h_0, self.weight_hh)
            i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
            i = torch.sigmoid(i).unsqueeze(1)
            o = torch.sigmoid(o).unsqueeze(1)
            g = torch.tanh(g).unsqueeze(1)
            c_1 = g
            h_1 = o * c_1
            return h_1, c_1
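
# Illustrative usage (not part of the original module): a minimal single-step sketch for
# MultiInputLSTMCell_V0 with made-up sizes, assuming the fastNLP/PyTorch versions this
# module was written against. Here skip_c stands in for cell states produced by
# WordLSTMCell_yangjie for the lexicon words ending at the current position.
def _demo_multi_input_cell_v0():
    batch_size, char_emb_size, hidden_size, max_lexicon = 2, 50, 100, 3  # hypothetical sizes
    cell = MultiInputLSTMCell_V0(char_emb_size, hidden_size)
    char_emb = torch.randn(batch_size, char_emb_size)
    skip_c = torch.randn(batch_size, max_lexicon, hidden_size)  # B * X * hidden
    skip_count = torch.tensor([3, 1])                           # valid skip edges per example
    h_0 = torch.zeros(batch_size, hidden_size)
    c_0 = torch.zeros(batch_size, hidden_size)
    h_1, c_1 = cell(char_emb, skip_c, skip_count, (h_0, c_0))
    assert h_1.size() == (batch_size, hidden_size) and c_1.size() == (batch_size, hidden_size)
    return h_1, c_1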
class MultiInputLSTMCell_V1(nn.Module):

    def __init__(self, char_input_size, hidden_size, use_bias=True, debug=False):
        super().__init__()
        self.char_input_size = char_input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(char_input_size, 3 * hidden_size)
        )
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 3 * hidden_size)
        )
        self.alpha_weight_ih = nn.Parameter(
            torch.FloatTensor(char_input_size, hidden_size)
        )
        self.alpha_weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, hidden_size)
        )
        if self.use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
            self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size))
        else:
            self.register_parameter('bias', None)
            self.register_parameter('alpha_bias', None)
        self.debug = debug
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        nn.init.orthogonal_(self.weight_ih.data)
        nn.init.orthogonal_(self.alpha_weight_ih.data)
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 3)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        alpha_weight_hh_data = torch.eye(self.hidden_size)
        alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1)
        with torch.no_grad():
            self.alpha_weight_hh.set_(alpha_weight_hh_data)
        # The bias is just set to zero vectors.
        if self.use_bias:
            nn.init.constant_(self.bias.data, val=0)
            nn.init.constant_(self.alpha_bias.data, val=0)

    def forward(self, inp, skip_c, skip_count, hx):
        '''
        :param inp: chars, B * hidden
        :param skip_c: cell states produced by the skip (lexicon) edges, B * X * hidden
        :param skip_count: number of skip edges at the current position for each example
            in the batch, used for masking
        :param hx: tuple (h_0, c_0) of the previous step
        :return: (h_1, c_1)
        '''
        max_skip_count = torch.max(skip_count).item()
        if True:
            h_0, c_0 = hx
            batch_size = h_0.size(0)
            bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
            wi = torch.matmul(inp, self.weight_ih)
            wh = torch.matmul(h_0, self.weight_hh)
            i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
            i = torch.sigmoid(i).unsqueeze(1)
            o = torch.sigmoid(o).unsqueeze(1)
            g = torch.tanh(g).unsqueeze(1)
            ## basic lstm start: ordinary LSTM update with a coupled forget gate
            f = 1 - i
            c_1_basic = f * c_0.unsqueeze(1) + i * g
            c_1_basic = c_1_basic.squeeze(1)
            alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
            alpha_wi.unsqueeze_(1)
            alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)
            alpha_bias_batch = self.alpha_bias.unsqueeze(0)
            alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)
            # Mask out padded skip edges before the softmax over {input gate, skip edges}.
            skip_mask = seq_len_to_mask(skip_count, max_len=skip_c.size()[1])
            skip_mask = 1 - skip_mask
            skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)
            skip_mask = (skip_mask).float() * 1e20
            alpha = alpha - skip_mask
            alpha = torch.exp(torch.cat([i, alpha], dim=1))
            alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
            alpha = torch.div(alpha, alpha_sum)
            merge_i_c = torch.cat([g, skip_c], dim=1)
            c_1 = merge_i_c * alpha
            c_1 = c_1.sum(1, keepdim=True)
            # h_1 = o * c_1
            c_1 = c_1.squeeze(1)
            # Positions without any matched lexicon word fall back to the basic lstm cell state.
            count_select = (skip_count != 0).float().unsqueeze(-1)
            c_1 = c_1 * count_select + c_1_basic * (1 - count_select)
            o = o.squeeze(1)
            h_1 = o * torch.tanh(c_1)
            return h_1, c_1
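
# Illustrative usage (not part of the original module): a minimal sketch showing the
# behavior of MultiInputLSTMCell_V1 with made-up sizes, assuming the fastNLP/PyTorch
# versions this module targets. The second example has skip_count == 0, so its cell
# state comes from the plain-LSTM fallback rather than the lexicon attention.
def _demo_multi_input_cell_v1():
    batch_size, char_emb_size, hidden_size, max_lexicon = 2, 50, 100, 2  # hypothetical sizes
    cell = MultiInputLSTMCell_V1(char_emb_size, hidden_size)
    char_emb = torch.randn(batch_size, char_emb_size)
    skip_c = torch.randn(batch_size, max_lexicon, hidden_size)
    skip_count = torch.tensor([2, 0])  # example 0 has two lexicon matches, example 1 has none
    h_0 = torch.zeros(batch_size, hidden_size)
    c_0 = torch.zeros(batch_size, hidden_size)
    h_1, c_1 = cell(char_emb, skip_c, skip_count, (h_0, c_0))
    assert h_1.size() == (batch_size, hidden_size)
    return h_1, c_1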
class LatticeLSTMLayer_sup_back_V0(nn.Module):

    def __init__(self, char_input_size, word_input_size, hidden_size, left2right,
                 bias=True, device=None, debug=False, skip_before_head=False):
        super().__init__()
        self.debug = debug  # must be set before it is passed to the word cell below
        self.skip_before_head = skip_before_head
        self.hidden_size = hidden_size
        self.char_cell = MultiInputLSTMCell_V0(char_input_size, hidden_size, bias, debug)
        self.word_cell = WordLSTMCell_yangjie(word_input_size, hidden_size, bias, debug=self.debug)
        self.word_input_size = word_input_size
        self.left2right = left2right
        self.bias = bias
        self.device = device

    def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None):
        '''
        :param inp: batch * seq_len * embedding, chars
        :param seq_len: batch, length of chars
        :param skip_sources: batch * seq_len * X, start positions of the skip (lexicon) edges
        :param skip_words: batch * seq_len * X * embedding, embeddings of the words on the skip edges
        :param skip_count: batch * seq_len, count of matched lexicon words per example per position
        :param init_state: the hx of rnn
        :return:
        '''
        if self.left2right:
            max_seq_len = max(seq_len)
            batch_size = inp.size(0)
            c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            for i in range(max_seq_len):
                max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
                h_0, c_0 = h_[:, i, :], c_[:, i, :]
                skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
                skip_word_flat = skip_word_flat.view(batch_size * max_lexicon_count, self.word_input_size)
                skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)
                index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size, max_lexicon_count)
                index_1 = skip_source_flat
                if not self.skip_before_head:
                    c_x = c_[[index_0, index_1 + 1]]
                    h_x = h_[[index_0, index_1 + 1]]
                else:
                    c_x = c_[[index_0, index_1]]
                    h_x = h_[[index_0, index_1]]
                c_x_flat = c_x.view(batch_size * max_lexicon_count, self.hidden_size)
                h_x_flat = h_x.view(batch_size * max_lexicon_count, self.hidden_size)
                c_1_flat = self.word_cell(skip_word_flat, (h_x_flat, c_x_flat))
                c_1_skip = c_1_flat.view(batch_size, max_lexicon_count, self.hidden_size)
                h_1, c_1 = self.char_cell(inp[:, i, :], c_1_skip, skip_count[:, i], (h_0, c_0))
                h_ = torch.cat([h_, h_1.unsqueeze(1)], dim=1)
                c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1)
            return h_[:, 1:], c_[:, 1:]
        else:
            mask_for_seq_len = seq_len_to_mask(seq_len)
            max_seq_len = max(seq_len)
            batch_size = inp.size(0)
            c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            for i in reversed(range(max_seq_len)):
                max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
                h_0, c_0 = h_[:, 0, :], c_[:, 0, :]
                skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
                skip_word_flat = skip_word_flat.view(batch_size * max_lexicon_count, self.word_input_size)
                skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)
                index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size, max_lexicon_count)
                index_1 = skip_source_flat - i
                if not self.skip_before_head:
                    c_x = c_[[index_0, index_1 - 1]]
                    h_x = h_[[index_0, index_1 - 1]]
                else:
                    c_x = c_[[index_0, index_1]]
                    h_x = h_[[index_0, index_1]]
                c_x_flat = c_x.view(batch_size * max_lexicon_count, self.hidden_size)
                h_x_flat = h_x.view(batch_size * max_lexicon_count, self.hidden_size)
                c_1_flat = self.word_cell(skip_word_flat, (h_x_flat, c_x_flat))
                c_1_skip = c_1_flat.view(batch_size, max_lexicon_count, self.hidden_size)
                h_1, c_1 = self.char_cell(inp[:, i, :], c_1_skip, skip_count[:, i], (h_0, c_0))
                h_1_mask = h_1.masked_fill(1 - mask_for_seq_len[:, i].unsqueeze(-1), 0)
                c_1_mask = c_1.masked_fill(1 - mask_for_seq_len[:, i].unsqueeze(-1), 0)
                h_ = torch.cat([h_1_mask.unsqueeze(1), h_], dim=1)
                c_ = torch.cat([c_1_mask.unsqueeze(1), c_], dim=1)
            return h_[:, :-1], c_[:, :-1]
class LatticeLSTMLayer_sup_back_V1(nn.Module):
    # V1 differs from V0 in that, when the current position has no lexicon match at all,
    # V1 uses the ordinary LSTM update; the ordinary LSTM update differs from Yang Jie's
    # lattice LSTM implementation when the lexicon count is 0.
    def __init__(self, char_input_size, word_input_size, hidden_size, left2right,
                 bias=True, device=None, debug=False, skip_before_head=False):
        super().__init__()
        self.debug = debug
        self.skip_before_head = skip_before_head
        self.hidden_size = hidden_size
        self.char_cell = MultiInputLSTMCell_V1(char_input_size, hidden_size, bias, debug)
        self.word_cell = WordLSTMCell_yangjie(word_input_size, hidden_size, bias, debug=self.debug)
        self.word_input_size = word_input_size
        self.left2right = left2right
        self.bias = bias
        self.device = device

    def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None):
        '''
        :param inp: batch * seq_len * embedding, chars
        :param seq_len: batch, length of chars
        :param skip_sources: batch * seq_len * X, start positions of the skip (lexicon) edges
        :param skip_words: batch * seq_len * X * embedding_size, embeddings of the words on the skip edges
        :param skip_count: batch * seq_len,
            skip_count[i, j] is the number of lexicon words in example i whose match ends at position j
        :param init_state: the hx of rnn
        :return:
        '''
        if self.left2right:
            max_seq_len = max(seq_len)
            batch_size = inp.size(0)
            c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            for i in range(max_seq_len):
                max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
                h_0, c_0 = h_[:, i, :], c_[:, i, :]
                # To let the word cell consume a B * lexicon_count * embedding_size tensor,
                # it has to be reshaped into a 2-D tensor;
                # the same reshape is needed to match PyTorch's [] advanced indexing.
                skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
                skip_word_flat = skip_word_flat.view(batch_size * max_lexicon_count, self.word_input_size)
                skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)
                index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size, max_lexicon_count)
                index_1 = skip_source_flat
                if not self.skip_before_head:
                    c_x = c_[[index_0, index_1 + 1]]
                    h_x = h_[[index_0, index_1 + 1]]
                else:
                    c_x = c_[[index_0, index_1]]
                    h_x = h_[[index_0, index_1]]
                c_x_flat = c_x.view(batch_size * max_lexicon_count, self.hidden_size)
                h_x_flat = h_x.view(batch_size * max_lexicon_count, self.hidden_size)
                c_1_flat = self.word_cell(skip_word_flat, (h_x_flat, c_x_flat))
                c_1_skip = c_1_flat.view(batch_size, max_lexicon_count, self.hidden_size)
                h_1, c_1 = self.char_cell(inp[:, i, :], c_1_skip, skip_count[:, i], (h_0, c_0))
                h_ = torch.cat([h_, h_1.unsqueeze(1)], dim=1)
                c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1)
            return h_[:, 1:], c_[:, 1:]
        else:
            mask_for_seq_len = seq_len_to_mask(seq_len)
            max_seq_len = max(seq_len)
            batch_size = inp.size(0)
            c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            for i in reversed(range(max_seq_len)):
                max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
                h_0, c_0 = h_[:, 0, :], c_[:, 0, :]
                skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
                skip_word_flat = skip_word_flat.view(batch_size * max_lexicon_count, self.word_input_size)
                skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)
                index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size, max_lexicon_count)
                index_1 = skip_source_flat - i
                if not self.skip_before_head:
                    c_x = c_[[index_0, index_1 - 1]]
                    h_x = h_[[index_0, index_1 - 1]]
                else:
                    c_x = c_[[index_0, index_1]]
                    h_x = h_[[index_0, index_1]]
                c_x_flat = c_x.view(batch_size * max_lexicon_count, self.hidden_size)
                h_x_flat = h_x.view(batch_size * max_lexicon_count, self.hidden_size)
                c_1_flat = self.word_cell(skip_word_flat, (h_x_flat, c_x_flat))
                c_1_skip = c_1_flat.view(batch_size, max_lexicon_count, self.hidden_size)
                h_1, c_1 = self.char_cell(inp[:, i, :], c_1_skip, skip_count[:, i], (h_0, c_0))
                h_1_mask = h_1.masked_fill(1 - mask_for_seq_len[:, i].unsqueeze(-1), 0)
                c_1_mask = c_1.masked_fill(1 - mask_for_seq_len[:, i].unsqueeze(-1), 0)
                h_ = torch.cat([h_1_mask.unsqueeze(1), h_], dim=1)
                c_ = torch.cat([c_1_mask.unsqueeze(1), c_], dim=1)
            return h_[:, :-1], c_[:, :-1]
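
# Illustrative usage (not part of the original module): an end-to-end sketch of driving
# LatticeLSTMLayer_sup_back_V1 on tiny random inputs, assuming the fastNLP/PyTorch
# versions this module targets. All shapes, the CPU device, and the skip_sources padding
# value of -1 (pointing padded edges at the initial zero state) are made up here; real
# inputs come from the lexicon-matching data pipeline.
def _demo_lattice_layer_v1():
    batch_size, max_seq_len, max_lexicon = 2, 4, 2           # hypothetical sizes
    char_emb_size, word_emb_size, hidden_size = 50, 50, 100
    layer = LatticeLSTMLayer_sup_back_V1(char_emb_size, word_emb_size, hidden_size,
                                         left2right=True, device=torch.device('cpu'))
    chars = torch.randn(batch_size, max_seq_len, char_emb_size)
    seq_len = torch.tensor([4, 3])
    # skip_sources[b, j, k]: start position of the k-th lexicon word ending at position j.
    skip_sources = torch.full((batch_size, max_seq_len, max_lexicon), -1, dtype=torch.long)
    skip_words = torch.randn(batch_size, max_seq_len, max_lexicon, word_emb_size)
    skip_count = torch.zeros(batch_size, max_seq_len, dtype=torch.long)
    # Pretend one two-character lexicon word spans positions 0-1 in every example.
    skip_sources[:, 1, 0] = 0
    skip_count[:, 1] = 1
    h, c = layer(chars, seq_len, skip_sources, skip_words, skip_count)
    assert h.size() == (batch_size, max_seq_len, hidden_size)
    return h, c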