import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


class DeepLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, recurrent_dropout,
                 use_orthnormal_init=True, fix_mask=True, use_cuda=True):
        super(DeepLSTM, self).__init__()
        self.fix_mask = fix_mask
        self.use_cuda = use_cuda
        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.recurrent_dropout = recurrent_dropout

        self.lstms = nn.ModuleList([None] * self.num_layers)
        self.highway_gate_input = nn.ModuleList([None] * self.num_layers)
        # One gate projection per layer; a single shared nn.Linear would tie the
        # highway-gate weights across all layers.
        self.highway_gate_state = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(self.num_layers)])
        self.highway_linear_input = nn.ModuleList([None] * self.num_layers)

        # self._input_w = nn.Parameter(torch.Tensor(input_size, hidden_size))
        # init.xavier_normal_(self._input_w)

        for l in range(self.num_layers):
            input_dim = input_size if l == 0 else hidden_size
            self.lstms[l] = nn.LSTMCell(input_size=input_dim, hidden_size=hidden_size)
            self.highway_gate_input[l] = nn.Linear(input_dim, hidden_size)
            self.highway_linear_input[l] = nn.Linear(input_dim, hidden_size, bias=False)

        # logger.info("[INFO] Initing W for LSTM .......")
        for l in range(self.num_layers):
            if use_orthnormal_init:
                # logger.info("[INFO] Initing W using orthnormal init .......")
                init.orthogonal_(self.lstms[l].weight_ih)
                init.orthogonal_(self.lstms[l].weight_hh)
                init.orthogonal_(self.highway_gate_input[l].weight.data)
                init.orthogonal_(self.highway_gate_state[l].weight.data)
                init.orthogonal_(self.highway_linear_input[l].weight.data)
            else:
                # logger.info("[INFO] Initing W using xavier_normal .......")
                init_weight_value = 6.0
                init.xavier_normal_(self.lstms[l].weight_ih, gain=np.sqrt(init_weight_value))
                init.xavier_normal_(self.lstms[l].weight_hh, gain=np.sqrt(init_weight_value))
                init.xavier_normal_(self.highway_gate_input[l].weight.data, gain=np.sqrt(init_weight_value))
                init.xavier_normal_(self.highway_gate_state[l].weight.data, gain=np.sqrt(init_weight_value))
                init.xavier_normal_(self.highway_linear_input[l].weight.data, gain=np.sqrt(init_weight_value))

    def init_hidden(self, batch_size, hidden_size):
        # the first is the hidden h, the second is the cell c
        if self.use_cuda:
            return (torch.zeros(batch_size, hidden_size).cuda(),
                    torch.zeros(batch_size, hidden_size).cuda())
        else:
            return (torch.zeros(batch_size, hidden_size),
                    torch.zeros(batch_size, hidden_size))

    def forward(self, inputs, input_masks, Train):
        '''
        inputs: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list)
        input_masks: [[seq_len, batch, Co * kernel_sizes], n_layer * [None]] (list)
        '''
        batch_size, seq_len = inputs[0].size(1), inputs[0].size(0)
        # inputs[0] = torch.matmul(inputs[0], self._input_w)
        # input_masks[0] = input_masks[0].unsqueeze(-1).expand(seq_len, batch_size, self.hidden_size)

        self.inputs = inputs
        self.input_masks = input_masks

        if self.fix_mask:
            # Sample one dropout mask per layer and reuse it at every time step.
            self.output_dropout_layers = [None] * self.num_layers
            for l in range(self.num_layers):
                binary_mask = torch.rand((batch_size, self.hidden_size)) > self.recurrent_dropout
                # This scaling ensures that the expected values and variances of the output of
                # applying this mask and of the original tensor are the same.
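                # Concretely: each unit is kept with probability (1 - p), so E[binary_mask] = 1 - p,
                # and dividing the 0/1 mask by (1 - p) keeps E[mask * h] = h (inverted dropout).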
                # from allennlp.nn.util.py
                self.output_dropout_layers[l] = binary_mask.float().div(1.0 - self.recurrent_dropout)
                if self.use_cuda:
                    self.output_dropout_layers[l] = self.output_dropout_layers[l].cuda()

        for l in range(self.num_layers):
            h, c = self.init_hidden(batch_size, self.hidden_size)
            outputs_list = []
            for t in range(len(self.inputs[l])):
                x = self.inputs[l][t]
                m = self.input_masks[l][t].float()
                h_temp, c_temp = self.lstms[l](x, (h, c))  # [batch, hidden_size]

                # Highway connection: gate r mixes the LSTM output with a linear
                # projection of the layer input.
                r = torch.sigmoid(self.highway_gate_input[l](x) + self.highway_gate_state[l](h))
                lx = self.highway_linear_input[l](x)  # [batch, hidden_size]
                h_temp = r * h_temp + (1 - r) * lx

                if Train:
                    if self.fix_mask:
                        h_temp = self.output_dropout_layers[l] * h_temp
                    else:
                        h_temp = F.dropout(h_temp, p=self.recurrent_dropout)

                # Only update the state at non-padded positions.
                h = m * h_temp + (1 - m) * h
                c = m * c_temp + (1 - m) * c
                outputs_list.append(h)

            outputs = torch.stack(outputs_list, 0)  # [seq_len, batch, hidden_size]
            # Feed the next layer the reversed sequence, so layer directions alternate.
            self.inputs[l + 1] = DeepLSTM.flip(outputs, 0)  # reverse [seq_len, batch, hidden_size]
            self.input_masks[l + 1] = DeepLSTM.flip(self.input_masks[l], 0)

        self.output_state = self.inputs  # num_layers * [seq_len, batch, hidden_size]

        # flip -2 layer
        # self.output_state[-2] = DeepLSTM.flip(self.output_state[-2], 0)
        # concat last two layers
        # self.output_state = torch.cat([self.output_state[-1], self.output_state[-2]], dim=-1).transpose(0, 1)

        self.output_state = self.output_state[-1].transpose(0, 1)
        assert self.output_state.size() == (batch_size, seq_len, self.hidden_size)

        return self.output_state

    @staticmethod
    def flip(x, dim):
        # Reverse a tensor along the given dimension (equivalent to torch.flip(x, dims=[dim])).
        xsize = x.size()
        dim = x.dim() + dim if dim < 0 else dim
        x = x.contiguous()
        x = x.view(-1, *xsize[dim:]).contiguous()
        x = x.view(x.size(0), x.size(1), -1)[
            :, getattr(torch.arange(x.size(1) - 1, -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :]
        return x.view(xsize)
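

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): the shapes and
    # hyperparameters below are assumed values chosen only to exercise forward().
    # inputs/input_masks follow the layout in forward()'s docstring: the first slot holds
    # the first layer's input, and the remaining num_layers slots are placeholders that
    # forward() fills with each layer's (reversed) output.
    seq_len, batch_size, input_size, hidden_size, num_layers = 7, 4, 16, 32, 3

    model = DeepLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                     recurrent_dropout=0.3, use_cuda=False)

    inputs = [torch.randn(seq_len, batch_size, input_size)] + [None] * num_layers
    # Per-time-step padding mask, broadcast over the hidden dimension (1 = token, 0 = pad).
    input_masks = [torch.ones(seq_len, batch_size, 1)] + [None] * num_layers

    output = model(inputs, input_masks, Train=True)
    print(output.size())  # expected: torch.Size([4, 7, 32])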