
CharParser.py

from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP import seq_len_to_mask
from fastNLP.embeddings import Embedding


def drop_input_independent(word_embeddings, dropout_emb):
    # Independently drop whole timesteps: each (batch, position) pair keeps its
    # embedding vector with probability 1 - dropout_emb and is zeroed otherwise
    # (no rescaling is applied).
    batch_size, seq_length, _ = word_embeddings.size()
    word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb)
    word_masks = torch.bernoulli(word_masks)
    word_masks = word_masks.unsqueeze(dim=2)
    word_embeddings = word_embeddings * word_masks
    return word_embeddings
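
# A minimal usage sketch (shapes assumed, not taken from the original file):
# >>> emb = torch.randn(2, 5, 100)            # [batch, seq_len, emb_dim]
# >>> out = drop_input_independent(emb, 0.3)  # each timestep zeroed with p=0.3
# >>> out.shape
# torch.Size([2, 5, 100])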


class CharBiaffineParser(BiaffineParser):
    def __init__(self, char_vocab_size,
                 emb_dim,
                 bigram_vocab_size,
                 trigram_vocab_size,
                 num_label,
                 rnn_layers=3,
                 rnn_hidden_size=800,  # hidden size per direction
                 arc_mlp_size=500,
                 label_mlp_size=100,
                 dropout=0.3,
                 encoder='lstm',
                 use_greedy_infer=False,
                 app_index=0,
                 pre_chars_embed=None,
                 pre_bigrams_embed=None,
                 pre_trigrams_embed=None):
        # Intentionally skip BiaffineParser.__init__ and initialize nn.Module
        # directly; this class builds its own character-level modules below.
        super(BiaffineParser, self).__init__()
        rnn_out_size = 2 * rnn_hidden_size
        self.char_embed = Embedding((char_vocab_size, emb_dim))
        self.bigram_embed = Embedding((bigram_vocab_size, emb_dim))
        self.trigram_embed = Embedding((trigram_vocab_size, emb_dim))
        # Optional pretrained embeddings, kept frozen and added to the trainable ones.
        if pre_chars_embed:
            self.pre_char_embed = Embedding(pre_chars_embed)
            self.pre_char_embed.requires_grad = False
        if pre_bigrams_embed:
            self.pre_bigram_embed = Embedding(pre_bigrams_embed)
            self.pre_bigram_embed.requires_grad = False
        if pre_trigrams_embed:
            self.pre_trigram_embed = Embedding(pre_trigrams_embed)
            self.pre_trigram_embed.requires_grad = False
        self.timestep_drop = TimestepDropout(dropout)
        self.encoder_name = encoder
        if encoder == 'var-lstm':
            self.encoder = VarLSTM(input_size=emb_dim * 3,
                                   hidden_size=rnn_hidden_size,
                                   num_layers=rnn_layers,
                                   bias=True,
                                   batch_first=True,
                                   input_dropout=dropout,
                                   hidden_dropout=dropout,
                                   bidirectional=True)
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(input_size=emb_dim * 3,
                                   hidden_size=rnn_hidden_size,
                                   num_layers=rnn_layers,
                                   bias=True,
                                   batch_first=True,
                                   dropout=dropout,
                                   bidirectional=True)
        else:
            raise ValueError('unsupported encoder type: {}'.format(encoder))
        # One shared MLP produces the arc and label representations in a single pass.
        self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
                                 nn.LeakyReLU(0.1),
                                 TimestepDropout(p=dropout),)
        self.arc_mlp_size = arc_mlp_size
        self.label_mlp_size = label_mlp_size
        self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
        self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
        self.use_greedy_infer = use_greedy_infer
        self.reset_parameters()
        self.dropout = dropout
        self.app_index = app_index
        self.num_label = num_label
        if self.app_index != 0:
            raise ValueError("app_index must currently be 0")

    def reset_parameters(self):
        # Leave embeddings and modules that initialize themselves untouched;
        # everything else gets Xavier-normal weights and uniform biases.
        for name, m in self.named_modules():
            if 'embed' in name:
                pass
            elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'):
                pass
            else:
                for p in m.parameters():
                    if len(p.size()) > 1:
                        nn.init.xavier_normal_(p, gain=0.1)
                    else:
                        nn.init.uniform_(p, -0.1, 0.1)

    def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None,
                pre_trigrams=None):
        """
        max_len includes the root token.

        :param chars: batch_size x max_len
        :param bigrams: batch_size x max_len*ngram_per_char
        :param trigrams: batch_size x max_len*ngram_per_char
        :param seq_lens: batch_size
        :param gold_heads: batch_size x max_len
        :param pre_chars: batch_size x max_len
        :param pre_bigrams: batch_size x max_len*ngram_per_char
        :param pre_trigrams: batch_size x max_len*ngram_per_char
        :return dict: parsing results
            arc_pred: [batch_size, seq_len, seq_len]
            label_pred: [batch_size, seq_len, num_label]
            mask: [batch_size, seq_len]
            head_pred: [batch_size, seq_len], present only when heads are predicted
                rather than taken from gold_heads
        """
        # prepare embeddings
        batch_size, seq_len = chars.shape
        # get sequence mask
        mask = seq_len_to_mask(seq_lens).long()
        chars = self.char_embed(chars)  # [N,L] -> [N,L,C_0]
        bigrams = self.bigram_embed(bigrams)  # [N,L] -> [N,L,C_1]
        trigrams = self.trigram_embed(trigrams)
        # add frozen pretrained embeddings to the trainable ones, when provided
        if pre_chars is not None:
            pre_chars = self.pre_char_embed(pre_chars)
            # pre_chars = self.pre_char_fc(pre_chars)
            chars = pre_chars + chars
        if pre_bigrams is not None:
            pre_bigrams = self.pre_bigram_embed(pre_bigrams)
            # pre_bigrams = self.pre_bigram_fc(pre_bigrams)
            bigrams = bigrams + pre_bigrams
        if pre_trigrams is not None:
            pre_trigrams = self.pre_trigram_embed(pre_trigrams)
            # pre_trigrams = self.pre_trigram_fc(pre_trigrams)
            trigrams = trigrams + pre_trigrams
        x = torch.cat([chars, bigrams, trigrams], dim=2)  # -> [N,L,C]
        # encoder, extract features
        if self.training:
            x = drop_input_independent(x, self.dropout)
        # sort by length so sequences can be packed for the (Var)LSTM, then restore order
        sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
        x = x[sort_idx]
        x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
        feat, _ = self.encoder(x)  # -> [N,L,C]
        feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
        _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
        feat = feat[unsort_idx]
        feat = self.timestep_drop(feat)
        # MLP: reduce dimension and split into arc/label head/dependent views
        feat = self.mlp(feat)
        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
        arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
        label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:]
        # biaffine arc classifier
        arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]
        # use gold or predicted arcs to predict labels
        if gold_heads is None or not self.training:
            # use greedy decoding in training
            if self.training or self.use_greedy_infer:
                heads = self.greedy_decoder(arc_pred, mask)
            else:
                heads = self.mst_decoder(arc_pred, mask)
            head_pred = heads
        else:
            assert self.training  # must be in training mode
            if gold_heads is None:
                heads = self.greedy_decoder(arc_pred, mask)
                head_pred = heads
            else:
                head_pred = None
                heads = gold_heads
        # heads: batch_size x max_len
        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1)
        label_head = label_head[batch_range, heads].contiguous()
        label_pred = self.label_predictor(label_head, label_dep)  # [N, max_len, num_label]
        # Restrict the 'app' label: it may only be predicted when a character's
        # head is the immediately following character.
        arange_index = torch.arange(1, seq_len + 1, dtype=torch.long, device=chars.device).unsqueeze(0)\
            .repeat(batch_size, 1)  # batch_size x max_len
        app_masks = heads.ne(arange_index)  # batch_size x max_len; True where 'app' must not be predicted
        app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label)
        app_masks[:, :, 1:] = 0  # only the 'app' label (index 0) is ever masked out
        label_pred = label_pred.masked_fill(app_masks, -np.inf)
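        # A tiny illustration of the restriction above (values assumed): with
        # heads = [[1, 2, 0]] and seq_len = 3, arange_index is [[1, 2, 3]], so
        # positions 0 and 1 attach to the next character and keep label 0 ('app')
        # available, while position 2 gets label 0 filled with -inf.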
        res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
        if head_pred is not None:
            res_dict['head_pred'] = head_pred
        return res_dict

    @staticmethod
    def loss(arc_pred, label_pred, arc_true, label_true, mask):
        """
        Compute loss.

        :param arc_pred: [batch_size, seq_len, seq_len]
        :param label_pred: [batch_size, seq_len, n_tags]
        :param arc_true: [batch_size, seq_len]
        :param label_true: [batch_size, seq_len]
        :param mask: [batch_size, seq_len]
        :return: loss value
        """
        batch_size, seq_len, _ = arc_pred.shape
        flip_mask = mask.eq(False)
        # rule out padded positions as candidate heads
        _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
        # the root position has no gold arc or label; exclude it via ignore_index
        arc_true.data[:, 0].fill_(-1)
        label_true.data[:, 0].fill_(-1)
        arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1)
        label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1)
        return arc_nll + label_nll
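
    # A minimal sketch of wiring up loss (tensor names assumed, not from the
    # original file); note that loss overwrites position 0 of arc_true and
    # label_true in place via .data[:, 0].fill_(-1):
    # >>> out = parser(chars, bigrams, trigrams, seq_lens, gold_heads=heads)
    # >>> total = CharBiaffineParser.loss(out['arc_pred'], out['label_pred'],
    # ...                                 heads, labels, out['mask'])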

    def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams):
        """
        max_len includes the root token.

        :param chars: batch_size x max_len
        :param bigrams: batch_size x max_len*ngram_per_char
        :param trigrams: batch_size x max_len*ngram_per_char
        :param seq_lens: batch_size
        :param pre_chars: batch_size x max_len
        :param pre_bigrams: batch_size x max_len*ngram_per_char
        :param pre_trigrams: batch_size x max_len*ngram_per_char
        :return:
        """
        res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams,
                   pre_trigrams=pre_trigrams, gold_heads=None)
        output = {}
        output['arc_pred'] = res.pop('head_pred')  # key name kept from the original interface; holds head indices
        _, label_pred = res.pop('label_pred').max(2)
        output['label_pred'] = label_pred
        return output
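
    # Example call (argument tensors assumed). Unlike forward, predict takes the
    # pre_* arguments positionally, so pass None when no pretrained embeddings
    # are used:
    # >>> out = parser.predict(chars, bigrams, trigrams, seq_lens, None, None, None)
    # >>> out['arc_pred']  # predicted head indices, [batch_size, max_len]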


class CharParser(nn.Module):
    def __init__(self, char_vocab_size,
                 emb_dim,
                 bigram_vocab_size,
                 trigram_vocab_size,
                 num_label,
                 rnn_layers=3,
                 rnn_hidden_size=400,  # hidden size per direction
                 arc_mlp_size=500,
                 label_mlp_size=100,
                 dropout=0.3,
                 encoder='var-lstm',
                 use_greedy_infer=False,
                 app_index=0,
                 pre_chars_embed=None,
                 pre_bigrams_embed=None,
                 pre_trigrams_embed=None):
        super().__init__()
        self.parser = CharBiaffineParser(char_vocab_size,
                                         emb_dim,
                                         bigram_vocab_size,
                                         trigram_vocab_size,
                                         num_label,
                                         rnn_layers,
                                         rnn_hidden_size,
                                         arc_mlp_size,
                                         label_mlp_size,
                                         dropout,
                                         encoder,
                                         use_greedy_infer,
                                         app_index,
                                         pre_chars_embed=pre_chars_embed,
                                         pre_bigrams_embed=pre_bigrams_embed,
                                         pre_trigrams_embed=pre_trigrams_embed)

    def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None,
                pre_trigrams=None):
        res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars,
                               pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
        arc_pred = res_dict['arc_pred']
        label_pred = res_dict['label_pred']
        masks = res_dict['mask']
        loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks)
        return {'loss': loss}

    def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None):
        res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars,
                          pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
        output = {}
        output['head_preds'] = res.pop('head_pred')
        _, label_preds = res.pop('label_pred').max(2)
        output['label_preds'] = label_preds
        return output
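

# A minimal smoke-test sketch; every size below is a made-up assumption, not a
# value from the original training setup:
if __name__ == '__main__':
    model = CharParser(char_vocab_size=100, emb_dim=30,
                       bigram_vocab_size=200, trigram_vocab_size=300,
                       num_label=40, rnn_layers=1, rnn_hidden_size=50,
                       encoder='lstm')
    chars = torch.randint(1, 100, (2, 7))    # [batch, max_len], max_len incl. root
    bigrams = torch.randint(1, 200, (2, 7))
    trigrams = torch.randint(1, 300, (2, 7))
    seq_lens = torch.tensor([7, 5])
    heads = torch.randint(0, 5, (2, 7))      # gold head indices within both lengths
    labels = torch.randint(1, 40, (2, 7))    # gold labels; index 0 is reserved for 'app'
    print(model(chars, bigrams, trigrams, seq_lens, heads, labels)['loss'])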