You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module.py 10 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from torch import nn
  2. import torch
  3. import numpy as np
  class SemiCRFShiftRelay(nn.Module):
      """
      Semi-Markov CRF over segments of length at most ``L``, extended with "relay" nodes.

      Relay nodes let a segment longer than ``L`` be expressed as a chain of relay states
      followed by a final segment scored through ``logits``. The module is a decoder only;
      decoding with tags (labelled segments) is not supported.
      """

      def __init__(self, L):
          """
          :param L: maximum segment length scored directly, not counting relay nodes. Must be >= 2.
          :raises RuntimeError: if ``L < 2``.
          """
          if L < 2:
              raise RuntimeError(f"L must be at least 2, got {L}")
          super().__init__()
          self.L = L

      def forward(self, logits, relay_logits, relay_target, relay_mask, end_seg_mask, seq_len):
          """
          Compute the per-sentence loss: log-partition over all segmentations minus the gold score.

          A relay node at a position means none of the following L positions ends its segment.
          The original notes describe the relay state as sliding forward one position at a time.

          :param logits: batch_size x max_len x L, score of the segment that ends at each position and
              extends to the left; index 0 of the last dim is the length-1 segment (the token itself).
          :param relay_logits: batch_size x max_len, score for the current position being a relay,
              i.e. none of the next L-1 positions is a segment end.
          :param relay_target: batch_size x max_len, for each position, how far back its segment
              started, capped at L-1 once the segment is longer than L. E.g. a word of length 5
              with L=3 gives [0, 1, 2, 2, 2]. Used as an index into the last dim of ``logits``.
          :param relay_mask: batch_size x max_len, 1 where a relay is needed. The original example
              says a word of length 5 with L=3 gives [1, 1, 1, 0, 0].
              NOTE(review): the recursion below feeds the relay at t-L into the segment ending at t,
              which chains only n-L relays for a word of length n ([1, 1, 0, 0, 0] here). Confirm
              against the code that builds this mask.
          :param end_seg_mask: batch_size x max_len, 1 where a segment ends.
          :param seq_len: batch_size, length of each sentence (used as a ``gather`` index).
          :return: loss, shape batch_size.
          """
          batch_size, max_len, L = logits.size()
          # relay_scores[:, t]: log-sum-exp over partial paths that are in a relay state at t
          relay_scores = logits.new_zeros(batch_size, max_len)
          # scores[:, t + 1]: log-sum-exp over paths whose last segment ends at t;
          # scores[:, 0] stays 0 and stands for the empty prefix
          scores = logits.new_zeros(batch_size, max_len + 1)
          # gold path score, accumulated position by position
          gold_scores = (relay_logits[:, 0].masked_fill(relay_mask[:, 0].eq(False), 0)
                         + logits[:, 0, 0].masked_fill(end_seg_mask[:, 0].eq(False), 0))
          # initialisation
          scores[:, 1] = logits[:, 0, 0]
          batch_i = torch.arange(batch_size, device=logits.device)
          relay_scores[:, 0] = relay_logits[:, 0]
          # a relay at t is only consumed by scores[:, t + self.L + 1], so later relays are never used
          last_relay_index = max_len - self.L
          for t in range(1, max_len):
              real_L = min(t + 1, L)
              # after the flip, index 0 is the segment of length real_L
              flip_logits_t = logits[:, t, :real_L].flip(dims=[1])
              # update the relay scores
              if t < last_relay_index:
                  # (1) start a relay right after a segment that ended at t-1
                  tmp1 = relay_logits[:, t] + scores[:, t]  # batch_size
                  # (2) continue the relay from t-1
                  tmp2 = relay_logits[:, t] + relay_scores[:, t - 1]  # batch_size
                  tmp1 = torch.stack([tmp1, tmp2], dim=0)
                  relay_scores[:, t] = torch.logsumexp(tmp1, dim=0)
              # update the segment-end scores
              # (1) from a regular segment of length 1..real_L ending at t
              tmp1 = scores[:, t - real_L + 1:t + 1] + flip_logits_t  # batch_size x real_L
              if t > self.L - 1:
                  # (2) from the relay at t-self.L followed by the longest segment ending at t
                  tmp2 = relay_scores[:, t - self.L]  # batch_size
                  tmp2 = tmp2 + flip_logits_t[:, 0]  # batch_size
                  tmp1 = torch.cat([tmp1, tmp2.unsqueeze(-1)], dim=-1)
              scores[:, t + 1] = torch.logsumexp(tmp1, dim=-1)
              # accumulate the gold path score
              seg_i = relay_target[:, t]  # batch_size
              gold_segment_scores = logits[:, t][(batch_i, seg_i)].masked_fill(end_seg_mask[:, t].eq(False), 0)
              relay_score = relay_logits[:, t].masked_fill(relay_mask[:, t].eq(False), 0)
              gold_scores = gold_scores + relay_score + gold_segment_scores
          all_scores = scores.gather(dim=1, index=seq_len.unsqueeze(1)).squeeze(1)  # batch_size
          return all_scores - gold_scores

      def predict(self, logits, relay_logits, seq_len):
          """
          Viterbi-decode the best segmentation.

          The original notes describe the relay state here as sliding forward L-1 positions.

          :param logits: batch_size x max_len x L, same layout as in ``forward``.
          :param relay_logits: batch_size x max_len, same as in ``forward``.
          :param seq_len: batch_size, length of each sentence; a sentence of length 0 decodes to an
              all-zero row.
          :return: (pred, pred_mask), both batch_size x max_len int64 tensors on ``logits.device``.
              ``pred`` holds, at the position where a segment starts, that segment's (length - 1);
              ``pred_mask`` is 1 where a segment is predicted to start.
          """
          batch_size, max_len, L = logits.size()
          # best score of any partial path in a relay state at t
          max_relay_scores = logits.new_zeros(batch_size, max_len)
          # 1 where the best relay at t continues the relay at t-1 (0: it starts after a segment end)
          relay_bt = seq_len.new_zeros(batch_size, max_len)
          # max_scores[:, t + 1]: best score of any path whose last segment ends at t
          max_scores = logits.new_zeros(batch_size, max_len + 1)
          bt = seq_len.new_zeros(batch_size, max_len)
          # initialisation
          max_scores[:, 1] = logits[:, 0, 0]
          max_relay_scores[:, 0] = relay_logits[:, 0]
          last_relay_index = max_len - self.L
          for t in range(1, max_len):
              real_L = min(t + 1, L)
              # after the flip, index 0 is the segment of length real_L
              flip_logits_t = logits[:, t, :real_L].flip(dims=[1])
              # update the relay scores
              if t < last_relay_index:
                  # (1) start a relay right after a segment that ended at t-1
                  tmp1 = relay_logits[:, t] + max_scores[:, t]
                  # (2) continue the relay from t-1
                  tmp2 = relay_logits[:, t] + max_relay_scores[:, t - 1]  # batch_size
                  # the last L positions of each sample may no longer continue a relay
                  tmp2 = tmp2.masked_fill(seq_len.le(t + L), float('-inf'))
                  mask_i = tmp1.lt(tmp2)  # True where the relay continuation wins
                  relay_bt[:, t].masked_fill_(mask_i, 1)
                  max_relay_scores[:, t] = torch.max(tmp1, tmp2)
              # update the segment-end scores
              # (1) from a regular segment ending at t
              tmp1 = max_scores[:, t - real_L + 1:t + 1] + flip_logits_t  # batch_size x real_L
              tmp1 = tmp1.flip(dims=[1])  # index k now means a segment of length k+1
              if self.L - 1 < t:
                  # (2) from the relay at t-self.L
                  tmp2 = max_relay_scores[:, t - self.L]  # batch_size
                  tmp2 = tmp2 + flip_logits_t[:, 0]
                  tmp1 = torch.cat([tmp1, tmp2.unsqueeze(-1)], dim=-1)
              max_score, pt = torch.max(tmp1, dim=1)
              max_scores[:, t + 1] = max_score
              # with L=3, pt values 0,1,2,3 mean [t, t], [t-1, t], [t-2, t], [t-self.L (relay), t]
              bt[:, t] = pt
          # back-track the best paths on the host
          pred = np.zeros((batch_size, max_len), dtype=int)
          pred_mask = np.zeros((batch_size, max_len), dtype=int)
          seq_len = seq_len.tolist()
          bt = bt.tolist()
          relay_bt = relay_bt.tolist()
          for b in range(batch_size):
              seq_len_i = seq_len[b]
              if seq_len_i <= 0:
                  # empty sentence: nothing to decode, its row stays all zeros
                  continue
              self._backtrack(bt[b][:seq_len_i], relay_bt[b][:seq_len_i], pred[b], pred_mask[b])
          # as_tensor with an explicit dtype avoids depending on the platform width of numpy's `int`
          return (torch.as_tensor(pred, dtype=torch.long, device=logits.device),
                  torch.as_tensor(pred_mask, dtype=torch.long, device=logits.device))

      def _backtrack(self, bt_i, relay_bt_i, pred_i, pred_mask_i):
          """Follow one non-empty sentence's back-pointers from its last token, filling pred_i/pred_mask_i in place."""
          j = len(bt_i) - 1
          # the last token can never continue a relay (its relay continuation is masked in predict)
          assert relay_bt_i[j] != 1
          while j > -1:
              if bt_i[j] == self.L:
                  # the segment ending at j was reached through a relay chain
                  seg_start_pos = j
                  j = j - self.L
                  # check the bound first so a negative j never wraps around via Python indexing
                  while j > -1 and relay_bt_i[j] != 0:
                      j = j - 1
                  pred_i[j] = seg_start_pos - j
                  pred_mask_i[j] = 1
              else:
                  # a regular segment of length bt_i[j] + 1 ending at j
                  length = bt_i[j]
                  j = j - length
                  pred_mask_i[j] = 1
                  pred_i[j] = length
              j = j - 1
  class FeatureFunMax(nn.Module):
      """
      Scoring head that turns per-token features into semi-CRF potentials.

      Input of shape batch_size x max_len x hidden_size is mapped to batch_size x max_len x L
      segment scores plus batch_size x max_len relay scores. A segment's feature vector is the
      element-wise max of its start and end token features. (The original notes defer the
      difference between the two score kinds to a paper reference left as TODO.)
      """

      def __init__(self, hidden_size: int, L: int):
          """
          :param hidden_size: dimensionality of the input features.
          :param L: maximum segment length, not counting relay nodes.
          """
          super().__init__()
          # construction order is kept so random initialisation consumes the RNG identically
          self.end_fc = nn.Linear(hidden_size, 1, bias=False)
          self.whole_w = nn.Parameter(torch.randn(L, hidden_size))
          self.relay_fc = nn.Linear(hidden_size, 1)
          self.length_bias = nn.Parameter(torch.randn(L))
          self.L = L

      def forward(self, logits):
          """
          :param logits: batch_size x max_len x hidden_size token features.
          :return: a pair
              - batch_size x max_len x L segment scores; index k of the last dim scores the segment
                ending at the current position and starting k positions to its left (k=0: length 1);
              - batch_size x max_len relay scores (per the original notes, the score that none of the
                next L-1 positions is a segment end).
          """
          width = self.L
          n_batch, n_pos, n_hidden = logits.size()

          # Left-pad with width-1 zero rows so every position owns a full window of candidate starts.
          padded = logits.new_zeros(n_batch, n_pos + width - 1, n_hidden)
          padded[:, -n_pos:] = logits
          # unfold gives n_batch x n_pos x n_hidden x width; move the window axis in front of the
          # hidden axis and reverse it, so window index k holds the token k steps to the left.
          windows = padded.unfold(dimension=1, size=width, step=1)
          start_feats = windows.permute(0, 1, 3, 2).flip(dims=[2])

          end_scores = self.end_fc(logits)  # n_batch x n_pos x 1

          # Relay features: each token vs. the token `width` steps ahead (zeros past the end).
          ahead = logits.new_zeros(n_batch, n_pos, n_hidden)
          ahead[:, :-width] = logits[:, width:]
          relay_feats = torch.max(ahead, logits)

          # Segment features: element-wise max of the end token and each candidate start token.
          seg_feats = torch.max(logits.unsqueeze(2), start_feats)  # n_batch x n_pos x width x n_hidden
          whole_scores = (seg_feats * self.whole_w).sum(dim=-1)  # n_batch x n_pos x width

          relay_scores = self.relay_fc(relay_feats).squeeze(-1)  # n_batch x n_pos
          segment_scores = whole_scores + end_scores + self.length_bias.view(1, 1, -1)
          return segment_scores, relay_scores