Browse Source

Merge branch 'nas' of git:THUMNLab/AutoGL into nas

tags/v0.3.1
Frozenmad 5 years ago
parent
commit
2b343dca8a
5 changed files with 406 additions and 4 deletions
  1. +2
    -1
      autogl/module/nas/algorithm/__init__.py
  2. +369
    -0
      autogl/module/nas/algorithm/enas.py
  3. +5
    -3
      autogl/module/nas/space/single_path.py
  4. +29
    -0
      examples/test_enas.py
  5. +1
    -0
      examples/test_nas.py

+ 2
- 1
autogl/module/nas/algorithm/__init__.py View File

@@ -4,5 +4,6 @@ NAS algorithms

from .base import BaseNAS
from .darts import Darts
from .enas import Enas

__all__ = ["BaseNAS", "Darts"]
__all__ = ["BaseNAS", "Darts", "Enas"]

+ 369
- 0
autogl/module/nas/algorithm/enas.py View File

@@ -0,0 +1,369 @@
# codes in this file are reproduced from https://github.com/microsoft/nni with some changes.
import copy
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseNAS
from ..space import BaseSpace
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
from nni.nas.pytorch.fixed import apply_fixed_architecture
_logger = logging.getLogger(__name__)
def _get_mask(sampled, total):
multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
return torch.tensor(multihot, dtype=torch.bool) # pylint: disable=not-callable

class PathSamplingLayerChoice(nn.Module):
"""
Mixed module, in which fprop is decided by exactly one or multiple (sampled) module.
If multiple module is selected, the result will be sumed and returned.

Attributes
----------
sampled : int or list of int
Sampled module indices.
mask : tensor
A multi-hot bool 1D-tensor representing the sampled mask.
"""

def __init__(self, layer_choice):
super(PathSamplingLayerChoice, self).__init__()
self.op_names = []
for name, module in layer_choice.named_children():
self.add_module(name, module)
self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.'
self.sampled = None # sampled can be either a list of indices or an index

def forward(self, *args, **kwargs):
assert self.sampled is not None, 'At least one path needs to be sampled before fprop.'
if isinstance(self.sampled, list):
return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled]) # pylint: disable=not-an-iterable
else:
return getattr(self, self.op_names[self.sampled])(*args, **kwargs) # pylint: disable=invalid-sequence-index

def __len__(self):
return len(self.op_names)

@property
def mask(self):
return _get_mask(self.sampled, len(self))


class PathSamplingInputChoice(nn.Module):
"""
Mixed input. Take a list of tensor as input, select some of them and return the sum.

Attributes
----------
sampled : int or list of int
Sampled module indices.
mask : tensor
A multi-hot bool 1D-tensor representing the sampled mask.
"""

def __init__(self, input_choice):
super(PathSamplingInputChoice, self).__init__()
self.n_candidates = input_choice.n_candidates
self.n_chosen = input_choice.n_chosen
self.sampled = None

def forward(self, input_tensors):
if isinstance(self.sampled, list):
return sum([input_tensors[t] for t in self.sampled]) # pylint: disable=not-an-iterable
else:
return input_tensors[self.sampled]

def __len__(self):
return self.n_candidates

@property
def mask(self):
return _get_mask(self.sampled, len(self))


class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])

def forward(self, inputs, hidden):
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# current implementation only supports batch size equals 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_h, next_c


class ReinforceField:
"""
A field with ``name``, with ``total`` choices. ``choose_one`` is true if one and only one is meant to be
selected. Otherwise, any number of choices can be chosen.
"""

def __init__(self, name, total, choose_one):
self.name = name
self.total = total
self.choose_one = choose_one

def __repr__(self):
return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'


class ReinforceController(nn.Module):
"""
A controller that mutates the graph with RL.

Parameters
----------
fields : list of ReinforceField
List of fields to choose.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
skip_target : float
Target probability that skipconnect will appear.
temperature : float
Temperature constant that divides the logits.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""

def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
skip_target=0.4, temperature=None, entropy_reduction='sum'):
super(ReinforceController, self).__init__()
self.fields = fields
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.skip_target = skip_target

self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), # pylint: disable=not-callable
requires_grad=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.soft = nn.ModuleDict({
field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
})
self.embedding = nn.ModuleDict({
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
})

def resample(self):
self._initialize()
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field)
return result

def _initialize(self):
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0

def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

def _sample_single(self, field):
self._lstm_next_step()
logit = self.soft[field.name](self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if field.choose_one:
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, sampled)
self._inputs = self.embedding[field.name](sampled)
else:
logit = logit.view(-1, 1)
logit = torch.cat([-logit, logit], 1) # pylint: disable=invalid-unary-operand-type
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, sampled)
sampled = sampled.nonzero().view(-1)
if sampled.sum().item():
self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
else:
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)

sampled = sampled.detach().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
if len(sampled) == 1:
sampled = sampled[0]
return sampled


class Enas(BaseNAS):
"""
ENAS trainer.

Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
reward_function : callable
Receives logits and ground truth label, return a tensor, which will be feeded to RL controller as reward.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
grad_clip : float
Gradient clipping. Set to 0 to disable. Default: 5.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_lr : float
Learning rate for RL controller.
ctrl_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
ctrl_steps : int
Number of mini-batches for each epoch of RL controller learning.
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`ReinforceController`.
"""

def __init__(self, device='cuda', workers=4,log_frequency=None,
grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None,*args,**kwargs):
super().__init__(*args,**kwargs)
self.device=device
self.num_epochs = kwargs.get("num_epochs", 5)
self.workers = workers
self.log_frequency = log_frequency
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.grad_clip = grad_clip
self.workers = workers
self.ctrl_kwargs=ctrl_kwargs
self.ctrl_lr=ctrl_lr

def search(self, space: BaseSpace, dset, estimator):
self.model = space
self.dataset = dset#.to(self.device)
self.estimator = estimator
self.model_optim = torch.optim.SGD(
self.model.parameters(), lr=0.01, weight_decay=3e-4
)
# replace choice
self.nas_modules = []
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
# to device
self.model = self.model.to(self.device)
# fields
self.nas_fields = [ReinforceField(name, len(module),
isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
for name, module in self.nas_modules]
self.controller = ReinforceController(self.nas_fields, **(self.ctrl_kwargs or {}))
self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr)
# train
for i in range(self.num_epochs):
self._train_model(i)
self._train_controller(i)
selection=self.export()
return space.export(selection,self.device)

def _train_model(self, epoch):
self.model.train()
self.controller.eval()
self.model_optim.zero_grad()
self._resample()
metric,loss=self._infer()
loss.backward()
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
self.model_optim.step()

def _train_controller(self, epoch):
self.model.eval()
self.controller.train()
self.ctrl_optim.zero_grad()
for ctrl_step in range(self.ctrl_steps_aggregate):
self._resample()
with torch.no_grad():
metric,loss=self._infer()
reward =-metric # todo : now metric is loss
if self.entropy_weight:
reward += self.entropy_weight * self.controller.sample_entropy.item()
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
loss = self.controller.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.controller.sample_skip_penalty
loss /= self.ctrl_steps_aggregate
loss.backward()
if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
self.ctrl_optim.step()
self.ctrl_optim.zero_grad()

if self.log_frequency is not None and ctrl_step % self.log_frequency == 0:
_logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs,
ctrl_step + 1, self.ctrl_steps_aggregate)

def _resample(self):
result = self.controller.resample()
for name, module in self.nas_modules:
module.sampled = result[name]

def export(self):
self.controller.eval()
with torch.no_grad():
return self.controller.resample()

def _infer(self):
metric, loss = self.estimator.infer(self.model, self.dataset)
return metric, loss

+ 5
- 3
autogl/module/nas/space/single_path.py View File

@@ -9,6 +9,7 @@ from .base import BaseSpace
from ...model import BaseModel
from ....utils import get_logger

from ...model import AutoGCN

class FixedNodeClassificationModel(BaseModel):
_logger = get_logger("space model")
@@ -56,7 +57,7 @@ class SinglePathNodeClassificationSpace(BaseSpace):
self,
hidden_dim: _typ.Optional[int] = 64,
layer_number: _typ.Optional[int] = 2,
dropout: _typ.Optional[float] = 0.6,
dropout: _typ.Optional[float] = 0.2,
input_dim: _typ.Optional[int] = None,
output_dim: _typ.Optional[int] = None,
ops: _typ.Tuple = None,
@@ -109,9 +110,10 @@ class SinglePathNodeClassificationSpace(BaseSpace):
for layer in range(self.layer_number):
x = getattr(self, f"op_{layer}")(x, edges)
if layer != self.layer_number - 1:
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.leaky_relu(x)
x = F.dropout(x, p=self.dropout, training = self.training)
return F.log_softmax(x, dim=1)

def export(self, selection, device) -> BaseModel:
#return AutoGCN(self.input_dim, self.output_dim, device)
return FixedNodeClassificationModel(self, selection, device)

+ 29
- 0
examples/test_enas.py View File

@@ -0,0 +1,29 @@
from copy import deepcopy
import sys
from nni.nas.pytorch.fixed import apply_fixed_architecture
from torch_geometric.nn.conv.gat_conv import GATConv
from torch_geometric.nn.conv.gcn_conv import GCNConv
sys.path.append('../')
import torch
from autogl.solver import AutoNodeClassifier
from autogl.module.nas.nas import DartsNodeClfEstimator
from autogl.module.nas.space import GraphSpace
from autogl.datasets import build_dataset_from_name
from autogl.module.model import BaseModel
# from autogl.module.nas.darts import Darts
from autogl.utils import get_logger
from autogl.module.nas.enas import Enas
if __name__ == '__main__':
dataset = build_dataset_from_name('cora')
solver = AutoNodeClassifier(
feature_module=None,
graph_models=[],
hpo_module="random",
max_evals=10,
ensemble_module=None,
nas_algorithms=[Enas()],
nas_spaces=[GraphSpace(hidden_dim=64, ops=[GATConv, GCNConv])],
nas_estimators=[DartsNodeClfEstimator()]
)
solver.fit(dataset)
out = solver.predict(dataset)

+ 1
- 0
examples/test_nas.py View File

@@ -7,6 +7,7 @@ from autogl.solver import AutoNodeClassifier
from autogl.module.train import NodeClassificationFullTrainer
from autogl.module.nas import Darts, OneShotEstimator, SinglePathNodeClassificationSpace
from autogl.module.train import Acc
from autogl.module.nas.algorithm.enas import Enas

if __name__ == '__main__':
dataset = build_dataset_from_name('cora')


Loading…
Cancel
Save