======================================
A tutorial on modules and models
======================================

:mod:`~fastNLP.modules` and :mod:`~fastNLP.models` provide the building blocks for the neural network models used in fastNLP,
and they can be combined freely with models from torch.nn.
The three sections below walk through the concrete ways of writing a model.

----------------------------
Using the models in models
----------------------------

fastNLP ships complete, ready-to-use models such as :class:`~fastNLP.models.CNNText` and
:class:`~fastNLP.models.SeqLabeling` in the :mod:`~fastNLP.models` module.
Taking :class:`~fastNLP.models.CNNText` as an example, let us walk through a simple text classification task.
First comes loading and preprocessing the data; this code is identical to the one in :doc:`Quickstart </user/quickstart>`.

.. code-block:: python

    from fastNLP.io import CSVLoader
    from fastNLP import Vocabulary, CrossEntropyLoss, AccuracyMetric

    loader = CSVLoader(headers=('raw_sentence', 'label'), sep='\t')
    dataset = loader.load("./sample_data/tutorial_sample_dataset.csv")

    dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='sentence')
    dataset.apply_field(lambda x: x.split(), field_name='sentence', new_field_name='words', is_input=True)
    dataset.apply(lambda x: int(x['label']), new_field_name='target', is_target=True)

    train_dev_data, test_data = dataset.split(0.1)
    train_data, dev_data = train_dev_data.split(0.1)

    vocab = Vocabulary(min_freq=2).from_dataset(train_data, field_name='words')
    vocab.index_dataset(train_data, dev_data, test_data, field_name='words', new_field_name='words')
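
Before moving on to a model, it can be worth a quick, optional sanity check of what the preprocessing produced.
A minimal sketch; the exact printed layout of an instance depends on your fastNLP version:

.. code-block:: python

    # sizes of the three splits produced by the two 0.1 splits above
    print(len(train_data), len(dev_data), len(test_data))
    # vocabulary size; should match the first argument of the
    # Embedding layer shown in the model printouts below
    print(len(vocab))
    # one processed instance with its 'sentence', 'words' and 'target' fields
    print(train_data[0])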

Next we import the ``CNNText`` model from :mod:`~fastNLP.models` and use it for training.

.. code-block:: python

    from fastNLP.models import CNNText
    from fastNLP import Trainer

    model_cnn = CNNText((len(vocab), 50), num_classes=5, padding=2, dropout=0.1)

    trainer = Trainer(model=model_cnn, train_data=train_data, dev_data=dev_data,
                      loss=CrossEntropyLoss(), metrics=AccuracyMetric())
    trainer.train()
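
After training, the held-out ``test_data`` can be scored in one call with fastNLP's ``Tester``.
A minimal sketch, assuming the common ``Tester(data, model, metrics)`` argument order:

.. code-block:: python

    from fastNLP import Tester

    # evaluate the trained model on the test split;
    # the positional order (data, model, metrics) is assumed here
    tester = Tester(test_data, model_cnn, metrics=AccuracyMetric())
    tester.test()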

Typing `model_cnn` in an IPython session shows the network structure of ``model_cnn``:

.. parsed-literal::

    CNNText(
      (embed): Embedding(
        169, 50
        (dropout): Dropout(p=0.0)
      )
      (conv_pool): ConvMaxpool(
        (convs): ModuleList(
          (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))
          (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))
          (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))
        )
      )
      (dropout): Dropout(p=0.1)
      (fc): Linear(in_features=12, out_features=5, bias=True)
    )

The three convolutions yield 3 + 4 + 5 = 12 max-pooled features in total, which is why the final linear layer has ``in_features=12``.

------------------------------
Writing models with torch.nn
------------------------------

fastNLP fully supports models written in PyTorch. Unlike the usual way of writing a PyTorch model, however,
the forward function of a model used with fastNLP must return a dict, and that dict must contain at least a ``pred`` key.
Below is a text classifier written with PyTorch's torch.nn module; note the tensor shapes annotated in the comments.
Because PyTorch defaults to the sequence-length-first dimension layout, forward has to rearrange the dimension order several times.

.. code-block:: python

    import torch
    import torch.nn as nn

    class LSTMText(nn.Module):
        def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):
            super().__init__()

            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True, dropout=dropout)
            self.fc = nn.Linear(hidden_dim * 2, output_dim)
            self.dropout = nn.Dropout(dropout)

        def forward(self, words):
            # (input) words : (batch_size, seq_len)
            words = words.permute(1, 0)
            # words : (seq_len, batch_size)
            embedded = self.dropout(self.embedding(words))
            # embedded : (seq_len, batch_size, embedding_dim)
            output, (hidden, cell) = self.lstm(embedded)
            # output: (seq_len, batch_size, hidden_dim * 2)
            # hidden: (num_layers * 2, batch_size, hidden_dim)
            # cell: (num_layers * 2, batch_size, hidden_dim)
            hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
            hidden = self.dropout(hidden)
            # hidden: (batch_size, hidden_dim * 2)
            pred = self.fc(hidden)
            # pred: (batch_size, output_dim)
            return {"pred": pred}

We can likewise inspect this model's network structure in an IPython session:

.. parsed-literal::

    LSTMText(
      (embedding): Embedding(169, 50)
      (lstm): LSTM(50, 64, num_layers=2, dropout=0.5, bidirectional=True)
      (fc): Linear(in_features=128, out_features=5, bias=True)
      (dropout): Dropout(p=0.5)
    )
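
Because forward returns a dict with the ``pred`` key, this hand-written model plugs into the exact same training loop
as the built-in ``CNNText``. A minimal sketch, assuming the data, loss, and metric objects from the first section are still in scope:

.. code-block:: python

    # any nn.Module whose forward returns {'pred': ...} works with Trainer
    model_lstm = LSTMText(len(vocab), 50, 5)

    trainer = Trainer(model=model_lstm, train_data=train_data, dev_data=dev_data,
                      loss=CrossEntropyLoss(), metrics=AccuracyMetric())
    trainer.train()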

----------------------------
Writing models with modules
----------------------------

Below we build the same network out of the components in :mod:`fastNLP.modules`. Because fastNLP consistently puts
``batch_size`` in the first dimension, the code becomes noticeably simpler to write.

.. code-block:: python

    from fastNLP.modules import Embedding, LSTM, MLP

    class Model(nn.Module):
        def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):
            super().__init__()

            self.embedding = Embedding((vocab_size, embedding_dim))
            self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)
            self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)

        def forward(self, words):
            embedded = self.embedding(words)
            _, (hidden, _) = self.lstm(embedded)
            pred = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))
            return {"pred": pred}

The network structure of our hand-written model is as follows:

.. parsed-literal::

    Model(
      (embedding): Embedding(
        169, 50
        (dropout): Dropout(p=0.0)
      )
      (lstm): LSTM(
        (lstm): LSTM(50, 64, num_layers=2, batch_first=True, bidirectional=True)
      )
      (mlp): MLP(
        (hiddens): ModuleList()
        (output): Linear(in_features=128, out_features=5, bias=True)
        (dropout): Dropout(p=0.5)
      )
    )
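
Note the empty ``(hiddens): ModuleList()`` above: ``MLP`` takes a list of layer sizes, and with only an input and an
output size there is no hidden layer. A small illustrative sketch of extending that list (behavior beyond the arguments
already used above is an assumption, check the :mod:`~fastNLP.modules` documentation for your version):

.. code-block:: python

    from fastNLP.modules import MLP

    # hypothetical variant: a 64-unit hidden layer between the
    # 128-dim input and the 5 output classes
    mlp = MLP([128, 64, 5], dropout=0.5)
    print(mlp)  # the 'hiddens' ModuleList should no longer be empty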