
awdlstm_module.py

  1. """
  2. 轻量封装的 Pytorch LSTM 模块.
  3. 可在 forward 时传入序列的长度, 自动对padding做合适的处理.
  4. """
  5. __all__ = [
  6. "LSTM"
  7. ]
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from fastNLP.modules.utils import initial_parameter
from torch import autograd

from .weight_drop import WeightDrop


class LSTM(nn.Module):
    """
    Alias: :class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`

    LSTM module, a lightweight wrapper around the PyTorch LSTM. When ``seq_len`` is provided,
    ``pack_padded_sequence`` is used automatically; the forget-gate bias is initialized to 1 by
    default; and the module copes with the issues of using an LSTM inside ``DataParallel``.

    :param input_size: feature dimension of the input ``x``
    :param hidden_size: feature dimension of the hidden state ``h``
    :param num_layers: number of RNN layers. Default: 1
    :param dropout: dropout probability between layers. Default: 0
    :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
    :param batch_first: if ``True``, input and output ``Tensor`` s have shape
        (batch, seq, feature). Default: ``True``
    :param bias: if ``False``, the model does not use bias weights. Default: ``True``
    :param wdrop: weight-drop (DropConnect) probability applied to the recurrent weights. Default: 0.5
    """

    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
                 bidirectional=False, bias=True, wdrop=0.5):
        super(LSTM, self).__init__()
        self.batch_first = batch_first
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                            dropout=dropout, bidirectional=bidirectional)
        # AWD-LSTM style weight drop: apply DropConnect to the recurrent weights of the first layer.
        self.lstm = WeightDrop(self.lstm, ['weight_hh_l0'], dropout=wdrop)
        self.init_param()

    def init_param(self):
        for name, param in self.named_parameters():
            if 'bias' in name:
                # Zero all biases, then set the forget-gate bias (the second quarter of the
                # bias vector in PyTorch's gate ordering) to 1,
                # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
                param.data.fill_(0)
                n = param.size(0)
                start, end = n // 4, n // 2
                param.data[start:end].fill_(1)
            else:
                nn.init.xavier_uniform_(param)

    def forward(self, x, seq_len=None, h0=None, c0=None):
        """
        :param x: [batch, seq_len, input_size] input sequence
        :param seq_len: [batch, ] sequence lengths; if ``None``, all inputs are treated as equally long. Default: ``None``
        :param h0: [num_layers*num_direction, batch, hidden_size] initial hidden state; if ``None``, it is all zeros. Default: ``None``
        :param c0: [num_layers*num_direction, batch, hidden_size] initial cell state; if ``None``, it is all zeros. Default: ``None``
        :return (output, (ht, ct)): output is the [batch, seq_len, hidden_size*num_direction] output sequence,
            and (ht, ct) are the hidden and cell states at the last time step.
        """
        batch_size, max_len, _ = x.size()
        if h0 is not None and c0 is not None:
            hx = (h0, c0)
        else:
            hx = None
        if seq_len is not None and not isinstance(x, rnn.PackedSequence):
            # Sort the batch by length (descending) so the sequences can be packed.
            sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
            if self.batch_first:
                x = x[sort_idx]
            else:
                x = x[:, sort_idx]
            x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
            output, hx = self.lstm(x, hx)  # -> [N, L, C]
            output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
            # Restore the original order of the batch.
            _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
            if self.batch_first:
                output = output[unsort_idx]
            else:
                output = output[:, unsort_idx]
        else:
            output, hx = self.lstm(x, hx)
        return output, hx
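

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this wrapper is typically called: a padded batch plus a seq_len
# tensor, so forward() packs and unpacks the sequences internally. The shapes and values
# below are arbitrary assumptions, and running it requires the local weight_drop.WeightDrop
# module imported above.
if __name__ == "__main__":
    batch_size, max_len, input_size, hidden_size = 4, 7, 16, 32

    lstm = LSTM(input_size, hidden_size=hidden_size, num_layers=1,
                batch_first=True, bidirectional=False, wdrop=0.5)

    x = torch.randn(batch_size, max_len, input_size)  # padded input batch
    seq_len = torch.tensor([7, 5, 3, 2])              # true length of each sequence

    output, (ht, ct) = lstm(x, seq_len=seq_len)

    # output keeps the padded time dimension; ht/ct are the final hidden and cell states.
    print(output.shape)  # torch.Size([4, 7, 32])
    print(ht.shape)      # torch.Size([1, 4, 32])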