# codes in this file are reproduced from https://github.com/microsoft/nni with some changes. import torch import torch.nn as nn import torch.nn.functional as F from . import register_nas_algo from .base import BaseNAS from ..space import BaseSpace from ..utils import ( AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module, PathSamplingInputChoice, PathSamplingLayerChoice, ) from nni.nas.pytorch.fixed import apply_fixed_architecture from tqdm import tqdm from datetime import datetime import numpy as np from ....utils import get_logger LOGGER = get_logger("RL_NAS") def _get_mask(sampled, total): multihot = [ i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total) ] return torch.tensor(multihot, dtype=torch.bool) # pylint: disable=not-callable class StackedLSTMCell(nn.Module): def __init__(self, layers, size, bias): super().__init__() self.lstm_num_layers = layers self.lstm_modules = nn.ModuleList( [nn.LSTMCell(size, size, bias=bias) for _ in range(self.lstm_num_layers)] ) def forward(self, inputs, hidden): prev_h, prev_c = hidden next_h, next_c = [], [] for i, m in enumerate(self.lstm_modules): curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i])) next_c.append(curr_c) next_h.append(curr_h) # current implementation only supports batch size equals 1, # but the algorithm does not necessarily have this limitation inputs = curr_h[-1].view(1, -1) return next_h, next_c class ReinforceField: """ A field with ``name``, with ``total`` choices. ``choose_one`` is true if one and only one is meant to be selected. Otherwise, any number of choices can be chosen. """ def __init__(self, name, total, choose_one): self.name = name self.total = total self.choose_one = choose_one def __repr__(self): return f"ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})" class ReinforceController(nn.Module): """ A controller that mutates the graph with RL. Parameters ---------- fields : list of ReinforceField List of fields to choose. lstm_size : int Controller LSTM hidden units. lstm_num_layers : int Number of layers for stacked LSTM. tanh_constant : float Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``. skip_target : float Target probability that skipconnect will appear. temperature : float Temperature constant that divides the logits. entropy_reduction : str Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced. """ def __init__( self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, skip_target=0.4, temperature=None, entropy_reduction="sum", ): super(ReinforceController, self).__init__() self.fields = fields self.lstm_size = lstm_size self.lstm_num_layers = lstm_num_layers self.tanh_constant = tanh_constant self.temperature = temperature self.skip_target = skip_target self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False) self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.v_attn = nn.Linear(self.lstm_size, 1, bias=False) self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1) self.skip_targets = nn.Parameter( torch.tensor( [1.0 - self.skip_target, self.skip_target] ), # pylint: disable=not-callable requires_grad=False, ) assert entropy_reduction in [ "sum", "mean", ], "Entropy reduction must be one of sum and mean." self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none") self.soft = nn.ModuleDict( { field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields } ) self.embedding = nn.ModuleDict( {field.name: nn.Embedding(field.total, self.lstm_size) for field in fields} ) def resample(self): self._initialize() result = dict() for field in self.fields: result[field.name] = self._sample_single(field) return result def _initialize(self): self._inputs = self.g_emb.data self._c = [ torch.zeros( (1, self.lstm_size), dtype=self._inputs.dtype, device=self._inputs.device, ) for _ in range(self.lstm_num_layers) ] self._h = [ torch.zeros( (1, self.lstm_size), dtype=self._inputs.dtype, device=self._inputs.device, ) for _ in range(self.lstm_num_layers) ] self.sample_log_prob = 0 self.sample_entropy = 0 self.sample_skip_penalty = 0 def _lstm_next_step(self): self._h, self._c = self.lstm(self._inputs, (self._h, self._c)) def _sample_single(self, field): self._lstm_next_step() logit = self.soft[field.name](self._h[-1]) if self.temperature is not None: logit /= self.temperature if self.tanh_constant is not None: logit = self.tanh_constant * torch.tanh(logit) if field.choose_one: sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) log_prob = self.cross_entropy_loss(logit, sampled) self._inputs = self.embedding[field.name](sampled) else: logit = logit.view(-1, 1) logit = torch.cat( [-logit, logit], 1 ) # pylint: disable=invalid-unary-operand-type sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) skip_prob = torch.sigmoid(logit) kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets)) self.sample_skip_penalty += kl log_prob = self.cross_entropy_loss(logit, sampled) sampled = sampled.nonzero().view(-1) if sampled.sum().item(): self._inputs = ( torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1.0 + torch.sum(sampled)) ).unsqueeze(0) else: self._inputs = torch.zeros( 1, self.lstm_size, device=self.embedding[field.name].weight.device ) sampled = sampled.detach().numpy().tolist() self.sample_log_prob += self.entropy_reduction(log_prob) entropy = ( log_prob * torch.exp(-log_prob) ).detach() # pylint: disable=invalid-unary-operand-type self.sample_entropy += self.entropy_reduction(entropy) if len(sampled) == 1: sampled = sampled[0] return sampled @register_nas_algo("rl") class RL(BaseNAS): """ RL in GraphNas. Parameters ---------- num_epochs : int Number of epochs planned for training. device : torch.device ``torch.device("cpu")`` or ``torch.device("cuda")``. log_frequency : int Step count per logging. grad_clip : float Gradient clipping. Set to 0 to disable. Default: 5. entropy_weight : float Weight of sample entropy loss. skip_weight : float Weight of skip penalty loss. baseline_decay : float Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``. ctrl_lr : float Learning rate for RL controller. ctrl_steps_aggregate : int Number of steps that will be aggregated into one mini-batch for RL controller. ctrl_steps : int Number of mini-batches for each epoch of RL controller learning. ctrl_kwargs : dict Optional kwargs that will be passed to :class:`ReinforceController`. n_warmup : int Number of epochs for training super network. model_lr : float Learning rate for super network. model_wd : float Weight decay for super network. disable_progress: boolean Control whether show the progress bar. """ def __init__( self, num_epochs=5, device="auto", log_frequency=None, grad_clip=5.0, entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None, n_warmup=100, model_lr=5e-3, model_wd=5e-4, disable_progress=False, ): super().__init__(device) self.num_epochs = num_epochs self.log_frequency = log_frequency self.entropy_weight = entropy_weight self.skip_weight = skip_weight self.baseline_decay = baseline_decay self.baseline = 0.0 self.ctrl_steps_aggregate = ctrl_steps_aggregate self.grad_clip = grad_clip self.ctrl_kwargs = ctrl_kwargs self.ctrl_lr = ctrl_lr self.n_warmup = n_warmup self.model_lr = model_lr self.model_wd = model_wd self.disable_progress = disable_progress def search(self, space: BaseSpace, dset, estimator): self.model = space self.dataset = dset # .to(self.device) self.estimator = estimator # replace choice self.nas_modules = [] k2o = get_module_order(self.model) replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules) replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules) self.nas_modules = sort_replaced_module(k2o, self.nas_modules) # to device self.model = self.model.to(self.device) # fields self.nas_fields = [ ReinforceField( name, len(module), isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1, ) for name, module in self.nas_modules ] self.controller = ReinforceController( self.nas_fields, **(self.ctrl_kwargs or {}) ) self.ctrl_optim = torch.optim.Adam( self.controller.parameters(), lr=self.ctrl_lr ) # train with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar: for i in bar: l2 = self._train_controller(i) bar.set_postfix(reward_controller=l2) selection = self.export() arch = space.parse_model(selection, self.device) # print(selection,arch) return arch def _train_controller(self, epoch): self.model.eval() self.controller.train() self.ctrl_optim.zero_grad() rewards = [] with tqdm( range(self.ctrl_steps_aggregate), disable=self.disable_progress ) as bar: for ctrl_step in bar: self._resample() metric, loss = self._infer(mask="val") reward = metric bar.set_postfix(acc=metric, loss=loss.item()) LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}") rewards.append(reward) if self.entropy_weight: reward += ( self.entropy_weight * self.controller.sample_entropy.item() ) self.baseline = self.baseline * self.baseline_decay + reward * ( 1 - self.baseline_decay ) loss = self.controller.sample_log_prob * (reward - self.baseline) if self.skip_weight: loss += self.skip_weight * self.controller.sample_skip_penalty loss /= self.ctrl_steps_aggregate loss.backward() if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0: if self.grad_clip > 0: nn.utils.clip_grad_norm_( self.controller.parameters(), self.grad_clip ) self.ctrl_optim.step() self.ctrl_optim.zero_grad() if ( self.log_frequency is not None and ctrl_step % self.log_frequency == 0 ): LOGGER.debug( "RL Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs, ctrl_step + 1, self.ctrl_steps_aggregate, ) return sum(rewards) / len(rewards) def _resample(self): result = self.controller.resample() self.arch = self.model.parse_model(result, device=self.device) self.selection = result def export(self): self.controller.eval() with torch.no_grad(): return self.controller.resample() def _infer(self, mask="train"): metric, loss = self.estimator.infer(self.arch._model, self.dataset, mask=mask) return metric[0], loss @register_nas_algo("graphnas") class GraphNasRL(BaseNAS): """ RL in GraphNas. Parameters ---------- device : torch.device ``torch.device("cpu")`` or ``torch.device("cuda")``. num_epochs : int Number of epochs planned for training. log_frequency : int Step count per logging. grad_clip : float Gradient clipping. Set to 0 to disable. Default: 5. entropy_weight : float Weight of sample entropy loss. skip_weight : float Weight of skip penalty loss. baseline_decay : float Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``. ctrl_lr : float Learning rate for RL controller. ctrl_steps_aggregate : int Number of steps that will be aggregated into one mini-batch for RL controller. ctrl_steps : int Number of mini-batches for each epoch of RL controller learning. ctrl_kwargs : dict Optional kwargs that will be passed to :class:`ReinforceController`. n_warmup : int Number of epochs for training super network. model_lr : float Learning rate for super network. model_wd : float Weight decay for super network. topk : int Number of architectures kept in training process. disable_progeress: boolean Control whether show the progress bar. """ def __init__( self, device="auto", num_epochs=10, log_frequency=None, grad_clip=5.0, entropy_weight=0.0001, skip_weight=0, baseline_decay=0.95, ctrl_lr=0.00035, ctrl_steps_aggregate=100, ctrl_kwargs=None, n_warmup=100, model_lr=5e-3, model_wd=5e-4, topk=5, disable_progress=False, hardware_metric_limit=None, ): super().__init__(device) self.num_epochs = num_epochs self.log_frequency = log_frequency self.entropy_weight = entropy_weight self.skip_weight = skip_weight self.baseline_decay = baseline_decay self.ctrl_steps_aggregate = ctrl_steps_aggregate self.grad_clip = grad_clip self.ctrl_kwargs = ctrl_kwargs self.ctrl_lr = ctrl_lr self.n_warmup = n_warmup self.model_lr = model_lr self.model_wd = model_wd self.hist = [] self.topk = topk self.disable_progress = disable_progress self.hardware_metric_limit = hardware_metric_limit self.allhist=[] def search(self, space: BaseSpace, dset, estimator): self.model = space self.dataset = dset # .to(self.device) self.estimator = estimator # replace choice self.nas_modules = [] k2o = get_module_order(self.model) replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules) replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules) self.nas_modules = sort_replaced_module(k2o, self.nas_modules) # to device self.model = self.model.to(self.device) # fields self.nas_fields = [ ReinforceField( name, len(module), isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1, ) for name, module in self.nas_modules ] self.controller = ReinforceController( self.nas_fields, lstm_size=100, temperature=5.0, tanh_constant=2.5, **(self.ctrl_kwargs or {}), ) self.ctrl_optim = torch.optim.Adam( self.controller.parameters(), lr=self.ctrl_lr ) # train with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar: for i in bar: l2 = self._train_controller(i) bar.set_postfix(reward_controller=l2) # selection=self.export() selections = [x[1] for x in self.hist] candidiate_accs = [-x[0] for x in self.hist] # print('candidiate accuracies',candidiate_accs) selection = self._choose_best(selections) arch = space.parse_model(selection, self.device) # print(selection,arch) return arch def _choose_best(self, selections): # graphnas use top 5 models, can evaluate 20 times epoch and choose the best. results = [] for selection in selections: accs = [] for i in tqdm(range(20), disable=self.disable_progress): self.arch = self.model.parse_model(selection, device=self.device) metric, loss, _ = self._infer(mask="val") accs.append(metric) result = np.mean(accs) LOGGER.info( "selection {} \n acc {:.4f} +- {:.4f}".format( selection, np.mean(accs), np.std(accs) / np.sqrt(20) ) ) results.append(result) best_selection = selections[np.argmax(results)] return best_selection def _train_controller(self, epoch): self.model.eval() self.controller.train() self.ctrl_optim.zero_grad() rewards = [] baseline = None # diff: graph nas train 100 and derive 100 for every epoch(10 epochs), we just train 100(20 epochs). totol num of samples are same (2000) with tqdm( range(self.ctrl_steps_aggregate), disable=self.disable_progress ) as bar: for ctrl_step in bar: self._resample() metric, loss, hardware_metric = self._infer(mask="val") reward = metric # bar.set_postfix(acc=metric,loss=loss.item()) LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}") # diff: not do reward shaping as in graphnas code if ( self.hardware_metric_limit is None or hardware_metric[0] < self.hardware_metric_limit ): self.hist.append([-metric, self.selection]) self.allhist.append([-metric, self.selection]) if len(self.hist) > self.topk: self.hist.sort(key=lambda x: x[0]) self.hist.pop() rewards.append(reward) if self.entropy_weight: reward += ( self.entropy_weight * self.controller.sample_entropy.item() ) if not baseline: baseline = reward else: baseline = baseline * self.baseline_decay + reward * ( 1 - self.baseline_decay ) loss = self.controller.sample_log_prob * (reward - baseline) self.ctrl_optim.zero_grad() loss.backward() self.ctrl_optim.step() bar.set_postfix(acc=metric, max_acc=max(rewards)) LOGGER.info("epoch:{}, mean rewards:{}".format(epoch, sum(rewards) / len(rewards))) return sum(rewards) / len(rewards) def _resample(self): result = self.controller.resample() self.arch = self.model.parse_model(result, device=self.device) self.selection = result def export(self): self.controller.eval() with torch.no_grad(): return self.controller.resample() def _infer(self, mask="train"): metric, loss = self.estimator.infer(self.arch._model, self.dataset, mask=mask) return metric[0], loss, metric[1:]