You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

awdlstm_module.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
"""
Lightly wrapped PyTorch LSTM module.

Sequence lengths can be passed in at forward time, and padding is then
handled appropriately (via pack_padded_sequence).
"""
__all__ = [
    "LSTM"
]
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from fastNLP.modules.utils import initial_parameter
from torch import autograd
from .weight_drop import WeightDrop
  14. class LSTM(nn.Module):
  15. """
  16. LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化
  17. 为1; 且可以应对DataParallel中LSTM的使用问题。
  18. :param input_size: 输入 `x` 的特征维度
  19. :param hidden_size: 隐状态 `h` 的特征维度.
  20. :param num_layers: rnn的层数. Default: 1
  21. :param dropout: 层间dropout概率. Default: 0
  22. :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
  23. :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为
  24. :(batch, seq, feature). Default: ``False``
  25. :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True``
  26. """
  27. def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
  28. bidirectional=False, bias=True, wdrop=0.5):
  29. super(LSTM, self).__init__()
  30. self.batch_first = batch_first
  31. self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
  32. dropout=dropout, bidirectional=bidirectional)
  33. self.lstm = WeightDrop(self.lstm, ['weight_hh_l0'], dropout=wdrop)
  34. self.init_param()
  35. def init_param(self):
  36. for name, param in self.named_parameters():
  37. if 'bias' in name:
  38. # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
  39. param.data.fill_(0)
  40. n = param.size(0)
  41. start, end = n // 4, n // 2
  42. param.data[start:end].fill_(1)
  43. else:
  44. nn.init.xavier_uniform_(param)
  45. def forward(self, x, seq_len=None, h0=None, c0=None):
  46. """
  47. :param x: [batch, seq_len, input_size] 输入序列
  48. :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None``
  49. :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None``
  50. :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None``
  51. :return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列
  52. 和 [batch, hidden_size*num_direction] 最后时刻隐状态.
  53. """
  54. batch_size, max_len, _ = x.size()
  55. if h0 is not None and c0 is not None:
  56. hx = (h0, c0)
  57. else:
  58. hx = None
  59. if seq_len is not None and not isinstance(x, rnn.PackedSequence):
  60. sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
  61. if self.batch_first:
  62. x = x[sort_idx]
  63. else:
  64. x = x[:, sort_idx]
  65. x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
  66. output, hx = self.lstm(x, hx) # -> [N,L,C]
  67. output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
  68. _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
  69. if self.batch_first:
  70. output = output[unsort_idx]
  71. else:
  72. output = output[:, unsort_idx]
  73. else:
  74. output, hx = self.lstm(x, hx)
  75. return output, hx