
weight_drop.py

import torch
from torch.nn import Parameter
from functools import wraps

class WeightDrop(torch.nn.Module):
    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        # We need to replace flatten_parameters with a nothing function
        # It must be a function rather than a lambda as otherwise pickling explodes
        # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
        # (╯°□°)╯︵ ┻━┻
        return

    def _setup(self):
        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
        if issubclass(type(self.module), torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            w = None
            if self.variational:
                # Variational: one dropout mask per row, shared across that entire row of the weight matrix
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda: mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = mask.expand_as(raw_w) * raw_w
            else:
                # Standard DropConnect: drop individual weights independently
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)
if __name__ == '__main__':
    import torch
    from weight_drop import WeightDrop

    # Input is (seq, batch, input)
    x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda()
    h0 = None

    ###

    print('Testing WeightDrop')
    print('=-=-=-=-=-=-=-=-=-=')

    ###

    print('Testing WeightDrop with Linear')

    lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9)
    lin.cuda()
    run1 = [x.sum() for x in lin(x).data]
    run2 = [x.sum() for x in lin(x).data]

    print('All items should be different')
    print('Run 1:', run1)
    print('Run 2:', run2)

    assert run1[0] != run2[0]
    assert run1[1] != run2[1]

    print('---')

    ###

    print('Testing WeightDrop with LSTM')

    wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9)
    wdrnn.cuda()

    run1 = [x.sum() for x in wdrnn(x, h0)[0].data]
    run2 = [x.sum() for x in wdrnn(x, h0)[0].data]

    print('First timesteps should be equal, all others should differ')
    print('Run 1:', run1)
    print('Run 2:', run2)

    # First time step, not influenced by hidden to hidden weights, should be equal
    assert run1[0] == run2[0]
    # Second step should not
    assert run1[1] != run2[1]

    print('---')
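
For context, here is a minimal sketch of how a wrapper like this is typically used: an nn.LSTM is wrapped so that DropConnect is applied to its hidden-to-hidden weights inside a small model. The TinyRNNModel class, its hyperparameters, and the CPU-only smoke test below are illustrative assumptions, not part of weight_drop.py, and the flatten_parameters monkey-patch above may behave differently across PyTorch versions.

# Illustrative sketch only: assumes weight_drop.py (above) is importable.
import torch
import torch.nn as nn

from weight_drop import WeightDrop


class TinyRNNModel(nn.Module):
    """Hypothetical model: embedding -> weight-dropped LSTM -> decoder."""
    def __init__(self, vocab_size=100, emb_size=32, hidden_size=32, wdrop=0.5):
        super(TinyRNNModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        # Apply DropConnect to the recurrent (hidden-to-hidden) weights only
        self.rnn = WeightDrop(nn.LSTM(emb_size, hidden_size),
                              ['weight_hh_l0'], dropout=wdrop)
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, tokens, hidden=None):
        emb = self.embedding(tokens)            # (seq, batch, emb_size)
        output, hidden = self.rnn(emb, hidden)  # weights are re-dropped on every call
        return self.decoder(output), hidden


if __name__ == '__main__':
    model = TinyRNNModel()
    tokens = torch.randint(0, 100, (5, 2))      # (seq, batch) of token ids
    logits, _ = model(tokens)
    print(logits.shape)                         # expected: torch.Size([5, 2, 100])

Because the dropout mask is sampled in _setweights on every forward pass, two training-mode calls with the same input generally produce different outputs, which is exactly what the assertions in the test block above check.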