# Code Modified from https://github.com/carpedm20/ENAS-pytorch
"""A module with NAS controller-related code."""
import collections
import os

import torch
import torch.nn.functional as F

import fastNLP.automl.enas_utils as utils
from fastNLP.automl.enas_utils import Node


def _construct_dags(prev_nodes, activations, func_names, num_blocks):
    """Constructs a set of DAGs based on the actions, i.e., previous nodes and
    activation functions, sampled from the controller/policy pi.

    Args:
        prev_nodes: Previous node actions from the policy.
        activations: Activations sampled from the policy.
        func_names: Mapping from activation function names to functions.
        num_blocks: Number of blocks in the target RNN cell.

    Returns:
        A list of DAGs defined by the inputs.

        RNN cell DAGs are represented in the following way:

        1. Each element (node) in a DAG is a list of `Node`s.

        2. The `Node`s in the list dag[i] correspond to the subsequent nodes
           that take the output from node i as their own input.

        3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}.
           dag[-1] always feeds dag[0].
           dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its
           weights.

        4. dag[N - 1] is the node that produces the hidden state passed to
           the next timestep. dag[N - 1] is also always a leaf node, and
           therefore is always averaged with the other leaf nodes and fed to
           the output decoder.
    """
    dags = []
    for nodes, func_ids in zip(prev_nodes, activations):
        dag = collections.defaultdict(list)

        # add first node
        dag[-1] = [Node(0, func_names[func_ids[0]])]
        dag[-2] = [Node(0, func_names[func_ids[0]])]

        # add following nodes
        for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])):
            dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id]))

        leaf_nodes = set(range(num_blocks)) - dag.keys()

        # merge with avg
        for idx in leaf_nodes:
            dag[idx] = [Node(num_blocks, 'avg')]

        # This is actually y^{(t)}. h^{(t)} is node N - 1 in
        # the graph, where N is the number of nodes. I.e., h^{(t)} takes
        # only one other node as its input.

        # last h[t] node
        last_node = Node(num_blocks + 1, 'h[t]')
        dag[num_blocks] = [last_node]
        dags.append(dag)

    return dags
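
# Illustrative sketch (not part of the upstream code): a tiny, hand-checkable
# example of the DAG layout `_construct_dags` produces. With num_blocks=2, one
# sampled architecture whose first block uses function 0 ('tanh') and whose
# second block reads from block 0 and uses function 1 ('ReLU'), i.e. roughly
#
#     _construct_dags(prev_nodes=[[0]], activations=[[0, 1]],
#                     func_names=['tanh', 'ReLU', 'identity', 'sigmoid'],
#                     num_blocks=2)
#
# yields a single DAG of the form
#
#     {-1: [Node(0, 'tanh')],   # x^{(t)} and h^{(t-1)} feed block 0
#      -2: [Node(0, 'tanh')],
#       0: [Node(1, 'ReLU')],   # block 0 feeds block 1
#       1: [Node(2, 'avg')],    # block 1 is a leaf, so it is averaged
#       2: [Node(3, 'h[t]')]}   # the average becomes the next hidden state
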

class Controller(torch.nn.Module):
    """Based on
    https://github.com/pytorch/examples/blob/master/word_language_model/model.py

    RL controllers do not necessarily have much to do with
    language models.

    Base the controller RNN on the GRU from:
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
    """
    def __init__(self, num_blocks=4, controller_hid=100, cuda=False):
        torch.nn.Module.__init__(self)

        # `num_tokens` here is just the activation function
        # for every even step,
        self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid']
        self.num_tokens = [len(self.shared_rnn_activations)]
        self.controller_hid = controller_hid
        self.use_cuda = cuda
        self.num_blocks = num_blocks
        for idx in range(num_blocks):
            self.num_tokens += [idx + 1, len(self.shared_rnn_activations)]
        self.func_names = self.shared_rnn_activations

        num_total_tokens = sum(self.num_tokens)

        self.encoder = torch.nn.Embedding(num_total_tokens, controller_hid)
        self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid)

        # Perhaps these weights in the decoder should be
        # shared? At least for the activation functions, which all have the
        # same size.
        self.decoders = []
        for idx, size in enumerate(self.num_tokens):
            decoder = torch.nn.Linear(controller_hid, size)
            self.decoders.append(decoder)

        self._decoders = torch.nn.ModuleList(self.decoders)

        self.reset_parameters()
        self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

        def _get_default_hidden(key):
            return utils.get_variable(
                torch.zeros(key, self.controller_hid),
                self.use_cuda,
                requires_grad=False)

        self.static_inputs = utils.keydefaultdict(_get_default_hidden)

    def reset_parameters(self):
        init_range = 0.1
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
        for decoder in self.decoders:
            decoder.bias.data.fill_(0)

    def forward(self,  # pylint:disable=arguments-differ
                inputs,
                hidden,
                block_idx,
                is_embed):
        if not is_embed:
            embed = self.encoder(inputs)
        else:
            embed = inputs

        hx, cx = self.lstm(embed, hidden)
        logits = self.decoders[block_idx](hx)

        logits /= 5.0

        # # exploration
        # if self.args.mode == 'train':
        #     logits = (2.5 * F.tanh(logits))

        return logits, (hx, cx)

    def sample(self, batch_size=1, with_details=False, save_dir=None):
        """Samples a set of `args.num_blocks` many computational nodes from the
        controller, where each node is made up of an activation function, and
        each node except the last also includes a previous node.
        """
        if batch_size < 1:
            raise Exception(f'Wrong batch_size: {batch_size} < 1')

        # [B, L, H]
        inputs = self.static_inputs[batch_size]
        hidden = self.static_init_hidden[batch_size]

        activations = []
        entropies = []
        log_probs = []
        prev_nodes = []
        # The RNN controller alternately outputs an activation,
        # followed by a previous node, for each block except the last one,
        # which only gets an activation function. The last node is the output
        # node, and its previous node is the average of all leaf nodes.
        for block_idx in range(2 * (self.num_blocks - 1) + 1):
            logits, hidden = self.forward(inputs,
                                          hidden,
                                          block_idx,
                                          is_embed=(block_idx == 0))

            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            # .mean() for entropy?
            entropy = -(log_prob * probs).sum(1, keepdim=False)

            action = probs.multinomial(num_samples=1).data
            selected_log_prob = log_prob.gather(
                1, utils.get_variable(action, requires_grad=False))

            # why the [:, 0] here? Should it be .squeeze(), or
            # .view()? Same below with `action`.
            entropies.append(entropy)
            log_probs.append(selected_log_prob[:, 0])

            # 0: function, 1: previous node
            mode = block_idx % 2
            inputs = utils.get_variable(
                action[:, 0] + sum(self.num_tokens[:mode]),
                requires_grad=False)

            if mode == 0:
                activations.append(action[:, 0])
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
        activations = torch.stack(activations).transpose(0, 1)

        dags = _construct_dags(prev_nodes,
                               activations,
                               self.func_names,
                               self.num_blocks)

        if save_dir is not None:
            for idx, dag in enumerate(dags):
                utils.draw_network(dag,
                                   os.path.join(save_dir, f'graph{idx}.png'))

        if with_details:
            return dags, torch.cat(log_probs), torch.cat(entropies)

        return dags

    def init_hidden(self, batch_size):
        zeros = torch.zeros(batch_size, self.controller_hid)
        return (utils.get_variable(zeros, self.use_cuda, requires_grad=False),
                utils.get_variable(zeros.clone(), self.use_cuda,
                                   requires_grad=False))
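

# Minimal usage sketch, illustrative only: assuming the helpers imported from
# `fastNLP.automl.enas_utils` above behave like the carpedm20/ENAS-pytorch
# originals, a controller can be instantiated and asked for candidate RNN-cell
# DAGs as below. The REINFORCE update that trains the controller against a
# shared child model lives elsewhere in the ENAS training loop and is not part
# of this module.
if __name__ == '__main__':
    controller = Controller(num_blocks=4, controller_hid=100, cuda=False)
    # Sample two architectures, along with the log-probabilities and entropies
    # a policy-gradient step would need.
    dags, log_probs, entropies = controller.sample(batch_size=2,
                                                   with_details=True)
    for dag in dags:
        print(dict(dag))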