You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

char_cnn.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. '''
  2. @author: https://github.com/ahmedbesbes/character-based-cnn
  3. 这里借鉴了上述链接中char-cnn model的代码,改动主要为将其改动为符合fastnlp的pipline
  4. '''
  5. import torch
  6. import torch.nn as nn
  7. from fastNLP.core.const import Const as C
  8. class CharacterLevelCNN(nn.Module):
  9. def __init__(self, args,embedding):
  10. super(CharacterLevelCNN, self).__init__()
  11. self.config=args.char_cnn_config
  12. self.embedding=embedding
  13. conv_layers = []
  14. for i, conv_layer_parameter in enumerate(self.config['model_parameters'][args.model_size]['conv']):
  15. if i == 0:
  16. #in_channels = args.number_of_characters + len(args.extra_characters)
  17. in_channels = args.embedding_dim
  18. out_channels = conv_layer_parameter[0]
  19. else:
  20. in_channels, out_channels = conv_layer_parameter[0], conv_layer_parameter[0]
  21. if conv_layer_parameter[2] != -1:
  22. conv_layer = nn.Sequential(nn.Conv1d(in_channels,
  23. out_channels,
  24. kernel_size=conv_layer_parameter[1], padding=0),
  25. nn.ReLU(),
  26. nn.MaxPool1d(conv_layer_parameter[2]))
  27. else:
  28. conv_layer = nn.Sequential(nn.Conv1d(in_channels,
  29. out_channels,
  30. kernel_size=conv_layer_parameter[1], padding=0),
  31. nn.ReLU())
  32. conv_layers.append(conv_layer)
  33. self.conv_layers = nn.ModuleList(conv_layers)
  34. input_shape = (args.batch_size, args.max_length,
  35. args.number_of_characters + len(args.extra_characters))
  36. dimension = self._get_conv_output(input_shape)
  37. print('dimension :', dimension)
  38. fc_layer_parameter = self.config['model_parameters'][args.model_size]['fc'][0]
  39. fc_layers = nn.ModuleList([
  40. nn.Sequential(
  41. nn.Linear(dimension, fc_layer_parameter), nn.Dropout(0.5)),
  42. nn.Sequential(nn.Linear(fc_layer_parameter,
  43. fc_layer_parameter), nn.Dropout(0.5)),
  44. nn.Linear(fc_layer_parameter, args.num_classes),
  45. ])
  46. self.fc_layers = fc_layers
  47. if args.model_size == 'small':
  48. self._create_weights(mean=0.0, std=0.05)
  49. elif args.model_size == 'large':
  50. self._create_weights(mean=0.0, std=0.02)
  51. def _create_weights(self, mean=0.0, std=0.05):
  52. for module in self.modules():
  53. if isinstance(module, nn.Conv1d) or isinstance(module, nn.Linear):
  54. module.weight.data.normal_(mean, std)
  55. def _get_conv_output(self, shape):
  56. input = torch.rand(shape)
  57. output = input.transpose(1, 2)
  58. # forward pass through conv layers
  59. for i in range(len(self.conv_layers)):
  60. output = self.conv_layers[i](output)
  61. output = output.view(output.size(0), -1)
  62. n_size = output.size(1)
  63. return n_size
  64. def forward(self, chars):
  65. input=self.embedding(chars)
  66. output = input.transpose(1, 2)
  67. # forward pass through conv layers
  68. for i in range(len(self.conv_layers)):
  69. output = self.conv_layers[i](output)
  70. output = output.view(output.size(0), -1)
  71. # forward pass through fc layers
  72. for i in range(len(self.fc_layers)):
  73. output = self.fc_layers[i](output)
  74. return {C.OUTPUT: output}