
enas_model.py

# Code Modified from https://github.com/carpedm20/ENAS-pytorch
"""Module containing the shared RNN model."""
import collections

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

import fastNLP.automl.enas_utils as utils
from fastNLP.models.base_model import BaseModel


def _get_dropped_weights(w_raw, dropout_p, is_training):
    """Drops out weights to implement DropConnect.

    Args:
        w_raw: Full, pre-dropout, weights to be dropped out.
        dropout_p: Proportion of weights to drop out.
        is_training: True iff _shared_ model is training.

    Returns:
        The dropped weights.

    Why does torch.nn.functional.dropout() return:
    1. `torch.autograd.Variable()` on the training loop
    2. `torch.nn.Parameter()` on the controller or eval loop, when
       training = False...

    Even though the call to `_setweights` in the Smerity repo's
    `weight_drop.py` does not have this behaviour, and `F.dropout` always
    returns `torch.autograd.Variable` there, even when `training=False`?

    The above TODO is the reason for the hacky check for `torch.nn.Parameter`.
    """
    dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training)
    if isinstance(dropped_w, torch.nn.Parameter):
        dropped_w = dropped_w.clone()
    return dropped_w

class EmbeddingDropout(torch.nn.Embedding):
    """Class for dropping out embeddings by zero'ing out parameters in the
    embedding matrix.

    This is equivalent to dropping out particular words, e.g., in the sentence
    'the quick brown fox jumps over the lazy dog', dropping out 'the' would
    lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the
    embedding vector space).

    See 'A Theoretically Grounded Application of Dropout in Recurrent Neural
    Networks', (Gal and Ghahramani, 2016).
    """
    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 max_norm=None,
                 norm_type=2,
                 scale_grad_by_freq=False,
                 sparse=False,
                 dropout=0.1,
                 scale=None):
        """Embedding constructor.

        Args:
            dropout: Dropout probability.
            scale: Used to scale parameters of embedding weight matrix that are
                not dropped out. Note that this is _in addition_ to the
                `1/(1 - dropout)` scaling.

        See `torch.nn.Embedding` for remaining arguments.
        """
        torch.nn.Embedding.__init__(self,
                                    num_embeddings=num_embeddings,
                                    embedding_dim=embedding_dim,
                                    max_norm=max_norm,
                                    norm_type=norm_type,
                                    scale_grad_by_freq=scale_grad_by_freq,
                                    sparse=sparse)
        self.dropout = dropout
        assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 '
                                                      'and < 1.0')
        self.scale = scale

    def forward(self, inputs):  # pylint:disable=arguments-differ
        """Embeds `inputs` with the dropped out embedding weight matrix."""
        if self.training:
            dropout = self.dropout
        else:
            dropout = 0

        if dropout:
            mask = self.weight.data.new(self.weight.size(0), 1)
            mask.bernoulli_(1 - dropout)
            mask = mask.expand_as(self.weight)
            mask = mask / (1 - dropout)
            masked_weight = self.weight * Variable(mask)
        else:
            masked_weight = self.weight
        if self.scale and self.scale != 1:
            masked_weight = masked_weight * self.scale

        return F.embedding(inputs,
                           masked_weight,
                           max_norm=self.max_norm,
                           norm_type=self.norm_type,
                           scale_grad_by_freq=self.scale_grad_by_freq,
                           sparse=self.sparse)

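# ----------------------------------------------------------------------------
# Illustrative sketch, added to this listing and not part of the original
# module: the helper name `_embedding_dropout_example` and all sizes are
# hypothetical. It shows the intended effect of `EmbeddingDropout`: during
# training, whole rows of the embedding matrix are zeroed, so every occurrence
# of a dropped word maps to the zero vector, while evaluation leaves the
# weights untouched.
# ----------------------------------------------------------------------------
def _embedding_dropout_example():
    emb = EmbeddingDropout(num_embeddings=10, embedding_dim=4, dropout=0.5)
    tokens = torch.LongTensor([[1, 2, 3, 1]])

    emb.train()
    dropped = emb(tokens)   # some word rows zeroed, the rest scaled by 1/(1 - 0.5)

    emb.eval()
    clean = emb(tokens)     # no dropout at evaluation time
    return dropped, clean
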
class LockedDropout(nn.Module):
    # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py
    def __init__(self):
        super().__init__()

    def forward(self, x, dropout=0.5):
        if not self.training or not dropout:
            return x
        m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
        mask = Variable(m, requires_grad=False) / (1 - dropout)
        mask = mask.expand_as(x)
        return mask * x

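# ----------------------------------------------------------------------------
# Illustrative sketch, added to this listing and not part of the original
# module (the helper name and tensor sizes are hypothetical). Unlike plain
# `F.dropout`, `LockedDropout` samples one mask per sequence and reuses it at
# every time step; the input is expected to be time-major,
# i.e. (time_steps, batch_size, hidden_dim).
# ----------------------------------------------------------------------------
def _locked_dropout_example():
    lockdrop = LockedDropout()
    lockdrop.train()
    x = torch.randn(5, 2, 8)        # (time_steps, batch_size, hidden_dim)
    y = lockdrop(x, dropout=0.5)    # zeroed positions agree across all 5 steps
    return y
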
class ENASModel(BaseModel):
    """Shared RNN model."""
    def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000):
        super(ENASModel, self).__init__()

        self.use_cuda = cuda

        self.shared_hid = shared_hid
        self.num_blocks = num_blocks
        self.decoder = nn.Linear(self.shared_hid, num_classes)
        self.encoder = EmbeddingDropout(embed_num,
                                        shared_embed,
                                        dropout=0.1)
        self.lockdrop = LockedDropout()
        self.dag = None

        # Tie weights
        # self.decoder.weight = self.encoder.weight

        # Since W^{x, c} and W^{h, c} are always summed, there
        # is no point duplicating their bias offset parameter. Likewise for
        # W^{x, h} and W^{h, h}.
        self.w_xc = nn.Linear(shared_embed, self.shared_hid)
        self.w_xh = nn.Linear(shared_embed, self.shared_hid)

        # The raw weights are stored here because the hidden-to-hidden weights
        # are weight dropped on the forward pass.
        self.w_hc_raw = torch.nn.Parameter(
            torch.Tensor(self.shared_hid, self.shared_hid))
        self.w_hh_raw = torch.nn.Parameter(
            torch.Tensor(self.shared_hid, self.shared_hid))
        self.w_hc = None
        self.w_hh = None

        self.w_h = collections.defaultdict(dict)
        self.w_c = collections.defaultdict(dict)

        for idx in range(self.num_blocks):
            for jdx in range(idx + 1, self.num_blocks):
                self.w_h[idx][jdx] = nn.Linear(self.shared_hid,
                                               self.shared_hid,
                                               bias=False)
                self.w_c[idx][jdx] = nn.Linear(self.shared_hid,
                                               self.shared_hid,
                                               bias=False)

        self._w_h = nn.ModuleList([self.w_h[idx][jdx]
                                   for idx in self.w_h
                                   for jdx in self.w_h[idx]])
        self._w_c = nn.ModuleList([self.w_c[idx][jdx]
                                   for idx in self.w_c
                                   for jdx in self.w_c[idx]])

        self.batch_norm = None
        # if args.mode == 'train':
        #     self.batch_norm = nn.BatchNorm1d(self.shared_hid)
        # else:
        #     self.batch_norm = None

        self.reset_parameters()
        self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

    def setDAG(self, dag):
        if self.dag is None:
            self.dag = dag
    def forward(self, word_seq, hidden=None):
        inputs = torch.transpose(word_seq, 0, 1)

        time_steps = inputs.size(0)
        batch_size = inputs.size(1)

        self.w_hh = _get_dropped_weights(self.w_hh_raw,
                                         0.5,
                                         self.training)
        self.w_hc = _get_dropped_weights(self.w_hc_raw,
                                         0.5,
                                         self.training)

        # hidden = self.static_init_hidden[batch_size] if hidden is None else hidden
        hidden = self.static_init_hidden[batch_size]

        embed = self.encoder(inputs)

        embed = self.lockdrop(embed, 0.65 if self.training else 0)

        # The norms of the hidden states are clipped here because
        # otherwise ENAS is especially prone to exploding activations on the
        # forward pass. This could probably be fixed in a more elegant way, but
        # it might be exposing a weakness in the ENAS algorithm as currently
        # proposed.
        #
        # For more details, see
        # https://github.com/carpedm20/ENAS-pytorch/issues/6
        clipped_num = 0
        max_clipped_norm = 0
        h1tohT = []
        logits = []
        for step in range(time_steps):
            x_t = embed[step]
            logit, hidden = self.cell(x_t, hidden, self.dag)

            hidden_norms = hidden.norm(dim=-1)
            max_norm = 25.0
            if hidden_norms.data.max() > max_norm:
                # Just directly use the torch slice operations
                # in PyTorch v0.4.
                #
                # This workaround for PyTorch v0.3.1 does everything in numpy,
                # because the PyTorch slicing and slice assignment is too
                # flaky.
                hidden_norms = hidden_norms.data.cpu().numpy()

                clipped_num += 1
                if hidden_norms.max() > max_clipped_norm:
                    max_clipped_norm = hidden_norms.max()

                clip_select = hidden_norms > max_norm
                clip_norms = hidden_norms[clip_select]

                mask = np.ones(hidden.size())
                normalizer = max_norm / clip_norms
                normalizer = normalizer[:, np.newaxis]

                mask[clip_select] = normalizer

                if self.use_cuda:
                    hidden *= torch.autograd.Variable(
                        torch.FloatTensor(mask).cuda(), requires_grad=False)
                else:
                    hidden *= torch.autograd.Variable(
                        torch.FloatTensor(mask), requires_grad=False)
            logits.append(logit)
            h1tohT.append(hidden)

        h1tohT = torch.stack(h1tohT)
        output = torch.stack(logits)
        raw_output = output

        output = self.lockdrop(output, 0.4 if self.training else 0)

        # Pooling
        output = torch.mean(output, 0)

        decoded = self.decoder(output)

        extra_out = {'dropped': decoded,
                     'hiddens': h1tohT,
                     'raw': raw_output}
        return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out}
    def cell(self, x, h_prev, dag):
        """Computes a single pass through the discovered RNN cell."""
        c = {}
        h = {}
        f = {}

        f[0] = self.get_f(dag[-1][0].name)
        c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None))
        h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) +
                (1 - c[0])*h_prev)

        leaf_node_ids = []
        q = collections.deque()
        q.append(0)

        # Computes connections from the parent nodes `node_id`
        # to their child nodes `next_id` recursively, skipping leaf nodes. A
        # leaf node is a node whose id == `self.num_blocks`.
        #
        # Connections between parent i and child j should be computed as
        # h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i,
        # where c_j = \sigmoid{(W^c_{ij}*h_i)}
        #
        # See Training details from Section 3.1 of the paper.
        #
        # The following algorithm does a breadth-first (since `q.popleft()` is
        # used) search over the nodes and computes all the hidden states.
        while True:
            if len(q) == 0:
                break

            node_id = q.popleft()
            nodes = dag[node_id]

            for next_node in nodes:
                next_id = next_node.id
                if next_id == self.num_blocks:
                    leaf_node_ids.append(node_id)
                    assert len(nodes) == 1, ('parent of leaf node should have '
                                             'only one child')
                    continue

                w_h = self.w_h[node_id][next_id]
                w_c = self.w_c[node_id][next_id]

                f[next_id] = self.get_f(next_node.name)
                c[next_id] = torch.sigmoid(w_c(h[node_id]))
                h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) +
                              (1 - c[next_id])*h[node_id])

                q.append(next_id)

        # Instead of averaging loose ends, perhaps there should
        # be a set of separate unshared weights for each "loose" connection
        # between each node in a cell and the output.
        #
        # As it stands, all weights W^h_{ij} are doing double duty by
        # connecting both from i to j, as well as from i to the output.

        # Average all the loose ends.
        leaf_nodes = [h[node_id] for node_id in leaf_node_ids]
        output = torch.mean(torch.stack(leaf_nodes, 2), -1)

        # Stabilize the updates of omega (the shared parameters).
        if self.batch_norm is not None:
            output = self.batch_norm(output)

        return output, h[self.num_blocks - 1]
    def init_hidden(self, batch_size):
        zeros = torch.zeros(batch_size, self.shared_hid)
        return utils.get_variable(zeros, self.use_cuda, requires_grad=False)

    def get_f(self, name):
        name = name.lower()
        if name == 'relu':
            f = torch.relu
        elif name == 'tanh':
            f = torch.tanh
        elif name == 'identity':
            f = lambda x: x
        elif name == 'sigmoid':
            f = torch.sigmoid
        return f

    @property
    def num_parameters(self):
        def size(p):
            return np.prod(p.size())

        return sum([size(param) for param in self.parameters()])

    def reset_parameters(self):
        init_range = 0.025
        # init_range = 0.025 if self.args.mode == 'train' else 0.04
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
        self.decoder.bias.data.fill_(0)
    def predict(self, word_seq):
        """
        :param word_seq: torch.LongTensor, [batch_size, seq_len]
        :return predict: dict of torch.LongTensor, [batch_size] (predicted class index per example)
        """
        output = self(word_seq)
        _, predict = output['pred'].max(dim=1)
        return {'pred': predict}
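
# ----------------------------------------------------------------------------
# Illustrative usage sketch, added to this listing and not part of the original
# module. The hand-built DAG below follows the convention of the upstream
# carpedm20/ENAS-pytorch code this module was adapted from: `dag` maps a parent
# node id to the list of its children, each child carries an `id` and an
# activation `name`, `dag[-1]` holds the activation of node 0, and a child
# whose id equals `num_blocks` marks its parent as a loose end (leaf). In
# fastNLP the DAG is normally sampled by the ENAS controller; the chain-shaped
# DAG, the helper name, and all sizes here are hypothetical, for demonstration
# only.
# ----------------------------------------------------------------------------
def _enas_model_example():
    Node = collections.namedtuple('Node', ['id', 'name'])
    dag = {
        -1: [Node(0, 'tanh')],     # activation applied at node 0
        0: [Node(1, 'relu')],
        1: [Node(2, 'tanh')],
        2: [Node(3, 'sigmoid')],
        3: [Node(4, 'avg')],       # id == num_blocks, so node 3 is a loose end
    }

    model = ENASModel(embed_num=100, num_classes=5, num_blocks=4,
                      cuda=False, shared_hid=32, shared_embed=16)
    model.setDAG(dag)

    word_seq = torch.LongTensor([[1, 2, 3, 4], [5, 6, 7, 8]])  # [batch_size, seq_len]
    out = model(word_seq)            # {'pred', 'hidden', 'extra_out'}
    pred = model.predict(word_seq)   # {'pred': LongTensor of class indices}
    return out, pred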