
mwan.py

# mwan.py -- PyTorch implementation of a Multiway Attention Network (MwAN)
# for sentence-pair classification. fastNLP's Const is used for the output key.

import torch as tc
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import math
from fastNLP.core.const import Const


class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bidrect, dropout):
        super(RNNModel, self).__init__()

        # inter-layer dropout only applies to stacked GRUs
        if num_layers <= 1:
            dropout = 0.0

        self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                          batch_first=True, dropout=dropout, bidirectional=bidrect)

        self.number = (2 if bidrect else 1) * num_layers

    def forward(self, x, mask):
        '''
        mask: (batch_size, seq_len)
        x: (batch_size, seq_len, input_size)
        '''
        lens = mask.long().sum(dim=1)
        lens, idx_sort = tc.sort(lens, descending=True)
        _, idx_unsort = tc.sort(idx_sort)

        x = x[idx_sort]
        x = nn.utils.rnn.pack_padded_sequence(x, lens, batch_first=True)

        self.rnn.flatten_parameters()
        y, h = self.rnn(x)

        y, lens = nn.utils.rnn.pad_packed_sequence(y, batch_first=True)
        h = h.transpose(0, 1).contiguous()  # make batch size first

        y = y[idx_unsort]  # (batch_size, seq_len, bid * hid_size)
        h = h[idx_unsort]  # (batch_size, number, hid_size)

        return y, h
class Contexualizer(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.3):
        super(Contexualizer, self).__init__()

        self.rnn = RNNModel(input_size, hidden_size, num_layers, True, dropout)
        self.output_size = hidden_size * 2

        self.reset_parameters()

    def reset_parameters(self):
        weights = self.rnn.rnn.all_weights
        for w1 in weights:
            for w2 in w1:
                if len(list(w2.size())) <= 1:
                    w2.data.fill_(0)
                else:
                    nn.init.xavier_normal_(w2.data, gain=1.414)

    def forward(self, s, mask):
        y = self.rnn(s, mask)[0]  # (batch_size, seq_len, 2 * hidden_size)
        return y
class ConcatAttention_Param(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(ConcatAttention_Param, self).__init__()
        self.ln = nn.Linear(input_size + hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.vq = nn.Parameter(tc.rand(hidden_size))
        self.drop = nn.Dropout(dropout)
        self.output_size = input_size

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def forward(self, h, mask):
        '''
        h: (batch_size, len, input_size)
        mask: (batch_size, len)
        '''
        vq = self.vq.view(1, 1, -1).expand(h.size(0), h.size(1), self.vq.size(0))

        s = self.v(tc.tanh(self.ln(tc.cat([h, vq], -1)))).squeeze(-1)  # (batch_size, len)
        s = s - ((mask == 0).float() * 10000)
        a = tc.softmax(s, dim=1)

        r = a.unsqueeze(-1) * h  # (batch_size, len, input_size)
        r = tc.sum(r, dim=1)     # (batch_size, input_size)

        return self.drop(r)
def get_2dmask(mask_hq, mask_hp, siz=None):
    if siz is None:
        siz = (mask_hq.size(0), mask_hq.size(1), mask_hp.size(1))

    mask_mat = 1
    if mask_hq is not None:
        mask_mat = mask_mat * mask_hq.unsqueeze(2).expand(siz)
    if mask_hp is not None:
        mask_mat = mask_mat * mask_hp.unsqueeze(1).expand(siz)
    return mask_mat


def Attention(hq, hp, mask_hq, mask_hp, my_method):
    standard_size = (hq.size(0), hq.size(1), hp.size(1), hq.size(-1))
    mask_mat = get_2dmask(mask_hq, mask_hp, standard_size[:-1])

    hq_mat = hq.unsqueeze(2).expand(standard_size)
    hp_mat = hp.unsqueeze(1).expand(standard_size)

    s = my_method(hq_mat, hp_mat)  # (batch_size, len_q, len_p)
    s = s - ((mask_mat == 0).float() * 10000)
    a = tc.softmax(s, dim=1)

    q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
    q = tc.sum(q, dim=1)          # (batch_size, len_p, input_size)

    return q
class ConcatAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2, input_size_2=-1):
        super(ConcatAttention, self).__init__()

        if input_size_2 < 0:
            input_size_2 = input_size
        self.ln = nn.Linear(input_size + input_size_2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.drop = nn.Dropout(dropout)
        self.output_size = input_size

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq_mat, hp_mat):
        s = tc.cat([hq_mat, hp_mat], dim=-1)
        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p)
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        '''
        hq: (batch_size, len_q, input_size)
        mask_hq: (batch_size, len_q)
        '''
        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))
class MinusAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(MinusAttention, self).__init__()
        self.ln = nn.Linear(input_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

        self.drop = nn.Dropout(dropout)
        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq_mat, hp_mat):
        s = hq_mat - hp_mat
        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p) s[j,t]
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))


class DotProductAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(DotProductAttention, self).__init__()
        self.ln = nn.Linear(input_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

        self.drop = nn.Dropout(dropout)
        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq_mat, hp_mat):
        s = hq_mat * hp_mat
        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p) s[j,t]
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))
class BiLinearAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2, input_size_2=-1):
        super(BiLinearAttention, self).__init__()

        input_size_2 = input_size if input_size_2 < 0 else input_size_2
        self.ln = nn.Linear(input_size_2, input_size)

        self.drop = nn.Dropout(dropout)
        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq, hp, mask_p):
        # hq, hp: (batch_size, len, input_size)
        hp = self.ln(hp)
        hp = hp * mask_p.unsqueeze(-1)
        s = tc.matmul(hq, hp.transpose(-1, -2))
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        standard_size = (hq.size(0), hq.size(1), hp.size(1), hq.size(-1))
        mask_mat = get_2dmask(mask_hq, mask_hp, standard_size[:-1])

        s = self.my_method(hq, hp, mask_hp)  # (batch_size, len_q, len_p)
        s = s - ((mask_mat == 0).float() * 10000)
        a = tc.softmax(s, dim=1)

        hq_mat = hq.unsqueeze(2).expand(standard_size)
        q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
        q = tc.sum(q, dim=1)          # (batch_size, len_p, input_size)

        return self.drop(q)
class AggAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(AggAttention, self).__init__()
        self.ln = nn.Linear(input_size + hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.vq = nn.Parameter(tc.rand(hidden_size, 1))
        self.drop = nn.Dropout(dropout)
        self.output_size = input_size

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.vq.data)
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

        self.vq.data = self.vq.data[:, 0]

    def forward(self, hs, mask):
        '''
        hs: [(batch_size, len_q, input_size), ...]
        mask: (batch_size, len_q)
        '''
        hs = tc.cat([h.unsqueeze(0) for h in hs], dim=0)  # (4, batch_size, len_q, input_size)

        vq = self.vq.view(1, 1, 1, -1).expand(hs.size(0), hs.size(1), hs.size(2), self.vq.size(0))

        s = self.v(tc.tanh(self.ln(tc.cat([hs, vq], -1)))).squeeze(-1)  # (4, batch_size, len_q)
        s = s - ((mask.unsqueeze(0) == 0).float() * 10000)
        a = tc.softmax(s, dim=0)

        x = a.unsqueeze(-1) * hs
        x = tc.sum(x, dim=0)  # (batch_size, len_q, input_size)

        return self.drop(x)
class Aggragator(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.3):
        super(Aggragator, self).__init__()

        now_size = input_size
        self.ln = nn.Linear(2 * input_size, 2 * input_size)

        now_size = 2 * input_size
        self.rnn = Contexualizer(now_size, hidden_size, 2, dropout)

        now_size = self.rnn.output_size
        self.agg_att = AggAttention(now_size, now_size, dropout)

        now_size = self.agg_att.output_size
        self.agg_rnn = Contexualizer(now_size, hidden_size, 2, dropout)

        self.drop = nn.Dropout(dropout)
        self.output_size = self.agg_rnn.output_size

    def forward(self, qs, hp, mask):
        '''
        qs: [(batch_size, len_p, input_size), ...]
        hp: (batch_size, len_p, input_size)
        mask is the same as hp's mask
        '''
        hs = [0 for _ in range(len(qs))]

        for i in range(len(qs)):
            q = qs[i]
            x = tc.cat([q, hp], dim=-1)
            g = tc.sigmoid(self.ln(x))
            x_star = x * g
            h = self.rnn(x_star, mask)
            hs[i] = h

        x = self.agg_att(hs, mask)  # (batch_size, len_p, output_size)
        h = self.agg_rnn(x, mask)   # (batch_size, len_p, output_size)
        return self.drop(h)
class Mwan_Imm(nn.Module):
    def __init__(self, input_size, hidden_size, num_class=3, dropout=0.2, use_allennlp=False):
        super(Mwan_Imm, self).__init__()

        now_size = input_size
        self.enc_s1 = Contexualizer(now_size, hidden_size, 2, dropout)
        self.enc_s2 = Contexualizer(now_size, hidden_size, 2, dropout)

        now_size = self.enc_s1.output_size
        self.att_c = ConcatAttention(now_size, hidden_size, dropout)
        self.att_b = BiLinearAttention(now_size, hidden_size, dropout)
        self.att_d = DotProductAttention(now_size, hidden_size, dropout)
        self.att_m = MinusAttention(now_size, hidden_size, dropout)

        now_size = self.att_c.output_size
        self.agg = Aggragator(now_size, hidden_size, dropout)

        now_size = self.enc_s1.output_size
        self.pred_1 = ConcatAttention_Param(now_size, hidden_size, dropout)
        now_size = self.agg.output_size
        self.pred_2 = ConcatAttention(now_size, hidden_size, dropout,
                                      input_size_2=self.pred_1.output_size)

        now_size = self.pred_2.output_size
        self.ln1 = nn.Linear(now_size, hidden_size)
        self.ln2 = nn.Linear(hidden_size, num_class)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.ln1.weight.data)
        nn.init.xavier_uniform_(self.ln2.weight.data)
        self.ln1.bias.data.fill_(0)
        self.ln2.bias.data.fill_(0)

    def forward(self, s1, s2, mas_s1, mas_s2):
        # note: both sentences are encoded with the shared encoder enc_s1;
        # enc_s2 is defined above but not used here
        hq = self.enc_s1(s1, mas_s1)  # (batch_size, len_q, output_size)
        hp = self.enc_s1(s2, mas_s2)

        mas_s1 = mas_s1[:, :hq.size(1)]
        mas_s2 = mas_s2[:, :hp.size(1)]
        mas_q, mas_p = mas_s1, mas_s2

        qc = self.att_c(hq, hp, mas_s1, mas_s2)  # (batch_size, len_p, output_size)
        qb = self.att_b(hq, hp, mas_s1, mas_s2)
        qd = self.att_d(hq, hp, mas_s1, mas_s2)
        qm = self.att_m(hq, hp, mas_s1, mas_s2)

        ho = self.agg([qc, qb, qd, qm], hp, mas_s2)  # (batch_size, len_p, output_size)

        rq = self.pred_1(hq, mas_q)                   # (batch_size, output_size)
        rp = self.pred_2(ho, rq.unsqueeze(1), mas_p)  # (batch_size, 1, output_size)
        rp = rp.squeeze(1)                            # (batch_size, output_size)

        rp = F.relu(self.ln1(rp))
        rp = self.ln2(rp)
        return rp
class MwanModel(nn.Module):
    def __init__(self, num_class, EmbLayer, args_of_imm={}, ElmoLayer=None):
        super(MwanModel, self).__init__()

        self.emb = EmbLayer

        if ElmoLayer is not None:
            self.elmo = ElmoLayer
            self.elmo_preln = nn.Linear(3 * self.elmo.emb_size, self.elmo.emb_size)
            self.elmo_ln = nn.Linear(args_of_imm["input_size"] +
                                     self.elmo.emb_size, args_of_imm["input_size"])
        else:
            self.elmo = None

        self.imm = Mwan_Imm(num_class=num_class, **args_of_imm)
        self.drop = nn.Dropout(args_of_imm["dropout"])

    def forward(self, words1, words2, str_s1=None, str_s2=None, *pargs, **kwargs):
        '''
        str_s1 / str_s2 are only needed when an ELMo layer is provided; by default ELMo is not used.
        str_s: (batch_size, seq_len, word_len)
        '''
        s1, s2 = words1, words2

        mas_s1 = (s1 != 0).float()  # mask: (batch_size, seq_len)
        mas_s2 = (s2 != 0).float()  # mask: (batch_size, seq_len)
        mas_s1.requires_grad = False
        mas_s2.requires_grad = False

        s1_emb = self.emb(s1)
        s2_emb = self.emb(s2)

        if self.elmo is not None:
            s1_elmo = self.elmo(str_s1)
            s2_elmo = self.elmo(str_s2)

            s1_elmo = tc.tanh(self.elmo_preln(tc.cat(s1_elmo, dim=-1)))
            s2_elmo = tc.tanh(self.elmo_preln(tc.cat(s2_elmo, dim=-1)))

            s1_emb = tc.cat([s1_emb, s1_elmo], dim=-1)
            s2_emb = tc.cat([s2_emb, s2_elmo], dim=-1)

            s1_emb = tc.tanh(self.elmo_ln(s1_emb))
            s2_emb = tc.tanh(self.elmo_ln(s2_emb))

        s1_emb = self.drop(s1_emb)
        s2_emb = self.drop(s2_emb)

        y = self.imm(s1_emb, s2_emb, mas_s1, mas_s2)

        return {
            Const.OUTPUT: y,
        }
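

# --- Minimal usage sketch (not part of the original module) -------------------
# Illustrative smoke test only: it assumes a plain nn.Embedding as EmbLayer and
# no ELMo layer, and the vocabulary size, dimensions, batch size, and random
# inputs below are made up for the example. Padding id 0 matches the masks
# built in MwanModel.forward.
if __name__ == "__main__":
    vocab_size, emb_dim = 1000, 100
    emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)

    model = MwanModel(
        num_class=3,
        EmbLayer=emb,
        args_of_imm={"input_size": emb_dim, "hidden_size": 128, "dropout": 0.2},
        ElmoLayer=None,
    )
    model.eval()  # disable dropout for the smoke test

    # two batches of word-id sequences, padded with 0
    words1 = tc.randint(1, vocab_size, (4, 20))
    words2 = tc.randint(1, vocab_size, (4, 25))
    words1[:, 15:] = 0
    words2[:, 18:] = 0

    out = model(words1, words2)
    print(out[Const.OUTPUT].shape)  # expected: torch.Size([4, 3])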