
DeepLSTM.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
class DeepLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, recurrent_dropout,
                 use_orthnormal_init=True, fix_mask=True, use_cuda=True):
        super(DeepLSTM, self).__init__()
        self.fix_mask = fix_mask
        self.use_cuda = use_cuda
        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.recurrent_dropout = recurrent_dropout

        # One LSTM cell plus highway gate/projection layers per stacked layer.
        # Each layer's modules are constructed separately inside the loop so no
        # parameters are shared across layers (a `[module] * n` list would reuse
        # the same module object in every slot).
        self.lstms = nn.ModuleList([None] * self.num_layers)
        self.highway_gate_input = nn.ModuleList([None] * self.num_layers)
        self.highway_gate_state = nn.ModuleList([None] * self.num_layers)
        self.highway_linear_input = nn.ModuleList([None] * self.num_layers)
        # self._input_w = nn.Parameter(torch.Tensor(input_size, hidden_size))
        # init.xavier_normal_(self._input_w)
        for l in range(self.num_layers):
            input_dim = input_size if l == 0 else hidden_size
            self.lstms[l] = nn.LSTMCell(input_size=input_dim, hidden_size=hidden_size)
            self.highway_gate_input[l] = nn.Linear(input_dim, hidden_size)
            self.highway_gate_state[l] = nn.Linear(hidden_size, hidden_size)
            self.highway_linear_input[l] = nn.Linear(input_dim, hidden_size, bias=False)

        # Weight initialization: orthogonal by default, Xavier normal otherwise.
        for l in range(self.num_layers):
            if use_orthnormal_init:
                init.orthogonal_(self.lstms[l].weight_ih)
                init.orthogonal_(self.lstms[l].weight_hh)
                init.orthogonal_(self.highway_gate_input[l].weight)
                init.orthogonal_(self.highway_gate_state[l].weight)
                init.orthogonal_(self.highway_linear_input[l].weight)
            else:
                gain = np.sqrt(6.0)
                init.xavier_normal_(self.lstms[l].weight_ih, gain=gain)
                init.xavier_normal_(self.lstms[l].weight_hh, gain=gain)
                init.xavier_normal_(self.highway_gate_input[l].weight, gain=gain)
                init.xavier_normal_(self.highway_gate_state[l].weight, gain=gain)
                init.xavier_normal_(self.highway_linear_input[l].weight, gain=gain)
    def init_hidden(self, batch_size, hidden_size):
        # Returns the initial hidden state h and cell state c as zero tensors.
        if self.use_cuda:
            return (torch.zeros(batch_size, hidden_size).cuda(),
                    torch.zeros(batch_size, hidden_size).cuda())
        else:
            return (torch.zeros(batch_size, hidden_size),
                    torch.zeros(batch_size, hidden_size))
    def forward(self, inputs, input_masks, Train):
        '''
        inputs: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list)
        input_masks: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list)
        '''
        batch_size, seq_len = inputs[0].size(1), inputs[0].size(0)
        # inputs[0] = torch.matmul(inputs[0], self._input_w)
        # input_masks[0] = input_masks[0].unsqueeze(-1).expand(seq_len, batch_size, self.hidden_size)
        self.inputs = inputs
        self.input_masks = input_masks

        if self.fix_mask:
            # Variational (fixed-mask) dropout: sample one mask per layer and
            # reuse it at every time step.
            self.output_dropout_layers = [None] * self.num_layers
            for l in range(self.num_layers):
                binary_mask = torch.rand((batch_size, self.hidden_size)) > self.recurrent_dropout
                # Scaling by 1 / (1 - p) keeps the expected value of the masked
                # output equal to that of the original tensor (cf. allennlp.nn.util).
                self.output_dropout_layers[l] = binary_mask.float().div(1.0 - self.recurrent_dropout)
                if self.use_cuda:
                    self.output_dropout_layers[l] = self.output_dropout_layers[l].cuda()

        for l in range(self.num_layers):
            h, c = self.init_hidden(batch_size, self.hidden_size)
            outputs_list = []
            for t in range(len(self.inputs[l])):
                x = self.inputs[l][t]
                m = self.input_masks[l][t].float()
                h_temp, c_temp = self.lstms[l](x, (h, c))  # [batch, hidden_size]
                # Highway connection: gate between the LSTM output and a linear
                # projection of the layer input.
                r = torch.sigmoid(self.highway_gate_input[l](x) + self.highway_gate_state[l](h))
                lx = self.highway_linear_input[l](x)  # [batch, hidden_size]
                h_temp = r * h_temp + (1 - r) * lx
                if Train:
                    if self.fix_mask:
                        h_temp = self.output_dropout_layers[l] * h_temp
                    else:
                        h_temp = F.dropout(h_temp, p=self.recurrent_dropout)
                # Keep the previous state wherever the input is padding (mask == 0).
                h = m * h_temp + (1 - m) * h
                c = m * c_temp + (1 - m) * c
                outputs_list.append(h)
            outputs = torch.stack(outputs_list, 0)  # [seq_len, batch, hidden_size]
            # Reverse the sequence so the next layer runs in the opposite direction.
            self.inputs[l + 1] = DeepLSTM.flip(outputs, 0)  # [seq_len, batch, hidden_size]
            self.input_masks[l + 1] = DeepLSTM.flip(self.input_masks[l], 0)

        self.output_state = self.inputs  # num_layers * [seq_len, batch, hidden_size]
        # flip the -2 layer
        # self.output_state[-2] = DeepLSTM.flip(self.output_state[-2], 0)
        # concatenate the last two layers
        # self.output_state = torch.cat([self.output_state[-1], self.output_state[-2]], dim=-1).transpose(0, 1)
        self.output_state = self.output_state[-1].transpose(0, 1)
        assert self.output_state.size() == (batch_size, seq_len, self.hidden_size)
        return self.output_state
    @staticmethod
    def flip(x, dim):
        # Reverse a tensor along the given dimension (equivalent to torch.flip
        # on recent PyTorch versions).
        xsize = x.size()
        dim = x.dim() + dim if dim < 0 else dim
        x = x.contiguous().view(-1, *xsize[dim:])
        indices = torch.arange(x.size(1) - 1, -1, -1, device=x.device).long()
        x = x.view(x.size(0), x.size(1), -1)[:, indices, :]
        return x.view(xsize)
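
A minimal usage sketch, assuming the conventions in the forward docstring: inputs and input_masks are lists of length num_layers + 1 with only the first slot filled, the mask is all ones (no padding) expanded to the hidden size, and the concrete sizes below are illustrative rather than taken from the original project.

if __name__ == "__main__":
    seq_len, batch, input_size, hidden_size, num_layers = 7, 2, 16, 8, 3
    model = DeepLSTM(input_size, hidden_size, num_layers,
                     recurrent_dropout=0.3, use_cuda=False)
    # Only the first slot is filled; forward() populates the remaining slots
    # layer by layer (and mutates these lists in place).
    inputs = [torch.randn(seq_len, batch, input_size)] + [None] * num_layers
    # All-ones mask = no padded positions (assumed shape, broadcast against h/c).
    masks = [torch.ones(seq_len, batch, hidden_size)] + [None] * num_layers
    out = model(inputs, masks, Train=True)
    print(out.size())  # torch.Size([2, 7, 8]), i.e. [batch, seq_len, hidden_size]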