
dilated_cnn.py
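This file implements an Iterated Dilated CNN (IDCNN) for sequence labeling, after Strubell et al. (2017), on top of fastNLP: a block of stacked dilated 1-D convolutions is applied repeatedly to widen the receptive field, with optional CRF decoding and an optional loss on every block's output.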

import torch
import torch.nn as nn
import torch.nn.functional as F

from fastNLP.modules.decoder import ConditionalRandomField
from fastNLP.embeddings import Embedding
from fastNLP.core.utils import seq_len_to_mask
from fastNLP.core.const import Const as C


class IDCNN(nn.Module):
    """Iterated Dilated CNN for sequence labeling (Strubell et al., 2017).

    A block of stacked dilated 1-D convolutions is applied ``repeats`` times
    to grow the receptive field; decoding uses either a CRF or a per-token
    argmax, with an optional loss on every block's output.
    """

    def __init__(self,
                 init_embed,
                 char_embed,
                 num_cls,
                 repeats, num_layers, num_filters, kernel_size,
                 use_crf=False, use_projection=False, block_loss=False,
                 input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0):
        super(IDCNN, self).__init__()
        self.word_embeddings = Embedding(init_embed)

        if char_embed is None:
            self.char_embeddings = None
            embedding_size = self.word_embeddings.embedding_dim
        else:
            # Word and character embeddings are concatenated along the feature dim.
            self.char_embeddings = Embedding(char_embed)
            embedding_size = self.word_embeddings.embedding_dim + \
                self.char_embeddings.embedding_dim

        # Initial non-dilated convolution projecting embeddings to num_filters.
        self.conv0 = nn.Sequential(
            nn.Conv1d(in_channels=embedding_size,
                      out_channels=num_filters,
                      kernel_size=kernel_size,
                      stride=1, dilation=1,
                      padding=kernel_size // 2,
                      bias=True),
            nn.ReLU(),
        )

        # Dilation doubles per layer (1, 2, 4, ...); the last layer returns to 1.
        # Padding scales with the dilation so the sequence length is preserved.
        block = []
        for layer_i in range(num_layers):
            dilated = 2 ** layer_i if layer_i + 1 < num_layers else 1
            block.append(nn.Conv1d(
                in_channels=num_filters,
                out_channels=num_filters,
                kernel_size=kernel_size,
                stride=1, dilation=dilated,
                padding=(kernel_size // 2) * dilated,
                bias=True))
            block.append(nn.ReLU())
        self.block = nn.Sequential(*block)

        # Optional 1x1 convolution halving the channel count before the output layer.
        if use_projection:
            self.projection = nn.Sequential(
                nn.Conv1d(
                    in_channels=num_filters,
                    out_channels=num_filters // 2,
                    kernel_size=1,
                    bias=True),
                nn.ReLU(),
            )
            encode_dim = num_filters // 2
        else:
            self.projection = None
            encode_dim = num_filters

        self.input_drop = nn.Dropout(input_dropout)
        self.hidden_drop = nn.Dropout(hidden_dropout)
        self.inner_drop = nn.Dropout(inner_dropout)

        self.repeats = repeats
        # 1x1 convolution acting as a per-token linear classifier over num_cls tags.
        self.out_fc = nn.Conv1d(
            in_channels=encode_dim,
            out_channels=num_cls,
            kernel_size=1,
            bias=True)
        self.crf = ConditionalRandomField(
            num_tags=num_cls) if use_crf else None
        self.block_loss = block_loss
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight, gain=1)
                if m.bias is not None:
                    nn.init.normal_(m.bias, mean=0, std=0.01)

    def forward(self, words, seq_len, target=None, chars=None):
        if self.char_embeddings is None:
            x = self.word_embeddings(words)
        else:
            if chars is None:
                raise ValueError('must provide chars for model with char embedding')
            e1 = self.word_embeddings(words)
            e2 = self.char_embeddings(chars)
            x = torch.cat((e1, e2), dim=-1)  # b,l,h
        mask = seq_len_to_mask(seq_len)

        x = x.transpose(1, 2)  # b,h,l
        last_output = self.conv0(x)
        # Apply the same dilated block `repeats` times, scoring after each pass.
        output = []
        for repeat in range(self.repeats):
            last_output = self.block(last_output)
            hidden = self.projection(last_output) if self.projection is not None else last_output
            output.append(self.out_fc(hidden))

        def compute_loss(y, t, mask):
            if self.crf is not None and target is not None:
                loss = self.crf(y.transpose(1, 2), t, mask)
            else:
                # ignore_index applies to the targets, so mask padded positions
                # in the target tensor, not the logits (out of place, since this
                # function may run once per block on the same target tensor).
                t = t.masked_fill(mask.eq(False), -100)
                loss = F.cross_entropy(y, t, ignore_index=-100)
            return loss

        if target is not None:
            if self.block_loss:
                # Supervise every block's output, not just the last one.
                losses = [compute_loss(o, target, mask) for o in output]
                loss = sum(losses)
            else:
                loss = compute_loss(output[-1], target, mask)
        else:
            loss = None

        scores = output[-1]
        if self.crf is not None:
            pred, _ = self.crf.viterbi_decode(scores.transpose(1, 2), mask)
        else:
            pred = scores.max(1)[1] * mask.long()

        return {
            C.LOSS: loss,
            C.OUTPUT: pred,
        }
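For reference, a minimal smoke test of the model above. All sizes and hyperparameters are made-up placeholders, and passing a (vocab_size, dim) tuple as init_embed assumes fastNLP's Embedding accepts that form for randomly initialized weights.

import torch
from fastNLP.core.const import Const as C

model = IDCNN(init_embed=(1000, 100),   # assumed: (vocab_size, dim) tuple -> random init
              char_embed=None,          # word embeddings only
              num_cls=9,                # placeholder tag count, e.g. BIO over 4 entity types
              repeats=4, num_layers=3, num_filters=200, kernel_size=3,
              use_crf=False, block_loss=True)

words = torch.randint(0, 1000, (2, 30))   # batch of 2 padded sequences of length 30
seq_len = torch.tensor([30, 25])          # true lengths before padding
target = torch.randint(0, 9, (2, 30))     # gold tag ids

out = model(words, seq_len, target=target)
loss, pred = out[C.LOSS], out[C.OUTPUT]   # summed block losses; per-token tag predictions

With block_loss=True the loss sums the cross-entropy of every repeated block's scores, which is the deep-supervision trick the output list exists for; with use_crf=True the same dict would carry the CRF loss and Viterbi-decoded paths instead.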