from functools import wraps

import torch
from torch.nn import Parameter


class WeightDrop(torch.nn.Module):
    """Apply dropout to selected weights of a wrapped module on every forward.

    At construction each weight named in ``weights`` is removed from the
    wrapped module's parameters and re-registered as ``<name>_raw``.  Before
    each forward pass a freshly dropped-out tensor computed from
    ``<name>_raw`` is assigned back to the module under the original name.

    Args:
        module: The ``torch.nn.Module`` to wrap.
        weights: Iterable of parameter names on ``module`` to apply dropout to.
        dropout: Dropout probability passed to ``torch.nn.functional.dropout``.
        variational: If true, sample one mask entry per index of the first
            dimension of the weight and broadcast it over the remaining
            dimensions, instead of dropping individual elements.
    """

    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        """Do nothing; stand-in for ``flatten_parameters`` on wrapped RNNs."""
        # We need to replace flatten_parameters with a nothing function.
        # It must be a function rather than a lambda as otherwise pickling
        # explodes.  Every argument (including the bound instance) is
        # swallowed by *args / **kwargs.
        return

    def _setup(self):
        """Rename each targeted parameter to ``<name>_raw`` on the module."""
        # Temporary workaround for an issue regarding compacting weights
        # re: CUDNN RNN -- disable flatten_parameters on the wrapped RNN.
        if isinstance(self.module, torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        """Recompute every dropped weight from its ``_raw`` parameter."""
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            if self.variational:
                # One mask entry per index of dim 0, broadcast across the
                # other dims.  new_ones keeps the mask on the weight's own
                # device and dtype.
                mask_shape = (raw_w.size(0),) + (1,) * (raw_w.dim() - 1)
                mask = raw_w.new_ones(mask_shape)
                # Respect eval mode: with training=False the mask stays all
                # ones, so the weight is used unchanged.
                mask = torch.nn.functional.dropout(
                    mask, p=self.dropout, training=self.training)
                w = mask * raw_w
            else:
                w = torch.nn.functional.dropout(
                    raw_w, p=self.dropout, training=self.training)
            if isinstance(w, Parameter):
                # When dropout is a no-op it can hand back the raw Parameter
                # itself.  Assigning a Parameter would re-register name_w as
                # a parameter and make the next plain-tensor assignment
                # raise, so assign a plain-tensor view instead (gradients
                # still reach the raw parameter).
                w = w.view_as(w)
            setattr(self.module, name_w, w)

    def forward(self, *args, **kwargs):
        """Refresh the dropped weights, then run the wrapped module.

        Positional and keyword arguments are passed through unchanged and
        the wrapped module's output is returned.
        """
        self._setweights()
        # Call the module (not .forward) so hooks on the wrapped module run.
        return self.module(*args, **kwargs)


if __name__ == '__main__':
    # Smoke test; uses the GPU when one is available, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Input is (seq, batch, input)
    x = torch.randn(2, 1, 10).to(device)
    h0 = None

    ###

    print('Testing WeightDrop')
    print('=-=-=-=-=-=-=-=-=-=')

    ###

    print('Testing WeightDrop with Linear')

    lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9)
    lin.to(device)
    run1 = [t.sum() for t in lin(x).data]
    run2 = [t.sum() for t in lin(x).data]

    print('All items should be different')
    print('Run 1:', run1)
    print('Run 2:', run2)

    assert run1[0] != run2[0]
    assert run1[1] != run2[1]

    print('---')

    ###

    print('Testing WeightDrop with LSTM')

    wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9)
    wdrnn.to(device)

    run1 = [t.sum() for t in wdrnn(x, h0)[0].data]
    run2 = [t.sum() for t in wdrnn(x, h0)[0].data]

    print('First timesteps should be equal, all others should differ')
    print('Run 1:', run1)
    print('Run 2:', run2)

    # First time step, not influenced by hidden to hidden weights, should be equal
    assert run1[0] == run2[0]
    # Second step should not
    assert run1[1] != run2[1]

    print('---')