Browse Source

ordered fields in space

tags/v0.3.1
cluster32 5 years ago
parent
commit
e3a11d5d5e
6 changed files with 72 additions and 65 deletions
  1. +1
    -54
      autogl/module/nas/algorithm/darts.py
  2. +5
    -1
      autogl/module/nas/algorithm/enas.py
  3. +44
    -3
      autogl/module/nas/space/base.py
  4. +3
    -3
      autogl/module/nas/space/single_path.py
  5. +16
    -1
      autogl/module/nas/utils.py
  6. +3
    -3
      examples/test_nas.py

+ 1
- 54
autogl/module/nas/algorithm/darts.py View File

@@ -12,64 +12,11 @@ from ..estimator.base import BaseEstimator
from ..space import BaseSpace
from ..utils import replace_layer_choice, replace_input_choice
from ...model.base import BaseModel
from nni.retiarii.oneshot.pytorch.darts import DartsLayerChoice, DartsInputChoice

_logger = logging.getLogger(__name__)


class DartsLayerChoice(nn.Module):
    """Differentiable (DARTS) relaxation of an NNI ``LayerChoice``.

    Wraps the candidate operations of a layer choice and mixes their
    outputs with a softmax over a learnable architecture parameter
    ``alpha`` (one logit per candidate op).

    NOTE: the pasted original had all indentation flattened and was not
    valid Python; this restores the structure.
    """

    def __init__(self, layer_choice):
        super(DartsLayerChoice, self).__init__()
        # `layer_choice` is expected to expose `.key` and
        # `.named_children()` (NNI mutable interface) -- TODO confirm.
        self.name = layer_choice.key
        self.op_choices = nn.ModuleDict(layer_choice.named_children())
        # Small random init keeps the initial mixture near-uniform.
        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

    def forward(self, *args, **kwargs):
        """Return the softmax(alpha)-weighted sum of all candidate outputs."""
        op_results = torch.stack(
            [op(*args, **kwargs) for op in self.op_choices.values()]
        )
        # Broadcast the per-op weight over all remaining result dims.
        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self):
        # Delegates to named_parameters() so `alpha` is excluded: the
        # architecture parameter is optimized separately from weights.
        for _, p in self.named_parameters():
            yield p

    def named_parameters(self):
        # Hide `alpha` from the ordinary parameter iterator so a model
        # optimizer built from .parameters() does not update it.
        for name, p in super(DartsLayerChoice, self).named_parameters():
            if name == "alpha":
                continue
            yield name, p

    def export(self):
        """Return the index of the strongest candidate op."""
        return torch.argmax(self.alpha).item()


class DartsInputChoice(nn.Module):
    """Differentiable (DARTS) relaxation of an NNI ``InputChoice``.

    Candidate input tensors are combined as a softmax-weighted sum over
    a learnable architecture parameter ``alpha``.

    NOTE: the pasted original had all indentation flattened and was not
    valid Python; this restores the structure.
    """

    def __init__(self, input_choice):
        super(DartsInputChoice, self).__init__()
        # `input_choice` must expose `.key`, `.n_candidates` and
        # `.n_chosen` (NNI mutable interface) -- TODO confirm.
        self.name = input_choice.key
        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
        # Always keep at least one chosen input when n_chosen is None/0.
        self.n_chosen = input_choice.n_chosen or 1

    def forward(self, inputs):
        """Return the softmax(alpha)-weighted sum of candidate inputs."""
        inputs = torch.stack(inputs)
        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self):
        # Exclude `alpha`; see named_parameters().
        for _, p in self.named_parameters():
            yield p

    def named_parameters(self):
        # Hide `alpha` so model optimizers do not touch the
        # architecture parameter.
        for name, p in super(DartsInputChoice, self).named_parameters():
            if name == "alpha":
                continue
            yield name, p

    def export(self):
        """Return the indices of the ``n_chosen`` strongest inputs."""
        return torch.argsort(-self.alpha).cpu().numpy().tolist()[: self.n_chosen]


class Darts(BaseNAS):
"""
DARTS trainer.


+ 5
- 1
autogl/module/nas/algorithm/enas.py View File

@@ -8,7 +8,7 @@ import torch.nn.functional as F

from .base import BaseNAS
from ..space import BaseSpace
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module
from nni.nas.pytorch.fixed import apply_fixed_architecture
_logger = logging.getLogger(__name__)
def _get_mask(sampled, total):
@@ -297,8 +297,12 @@ class Enas(BaseNAS):
)
# replace choice
self.nas_modules = []

k2o = get_module_order(self.model)
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

# to device
self.model = self.model.to(self.device)
# fields


+ 44
- 3
autogl/module/nas/space/base.py View File

@@ -1,7 +1,21 @@
from abc import abstractmethod
from autogl.module.model import BaseModel
import torch.nn as nn
from nni.nas.pytorch import mutables

class OrderedMutable:
    """Mixin that attaches an explicit declaration order to a mutable.

    The recorded order is later read back (key -> order) so that
    replaced NAS modules can be sorted into the sequence in which they
    were declared in the search space.
    """

    def __init__(self, order):
        # Position of this mutable within its space's declaration order.
        self.order = order

class OrderedLayerChoice(OrderedMutable, mutables.LayerChoice):
    """NNI ``LayerChoice`` that also records its declaration order."""

    def __init__(self, order, *args, **kwargs):
        # Explicit two-parent init: OrderedMutable.__init__ does not
        # call super().__init__, so each base is initialized directly.
        OrderedMutable.__init__(self, order)
        mutables.LayerChoice.__init__(self, *args, **kwargs)


class OrderedInputChoice(OrderedMutable, mutables.InputChoice):
    """NNI ``InputChoice`` that also records its declaration order."""

    def __init__(self, order, *args, **kwargs):
        OrderedMutable.__init__(self, order)
        mutables.InputChoice.__init__(self, *args, **kwargs)

class BaseSpace(nn.Module):
"""
@@ -21,12 +35,11 @@ class BaseSpace(nn.Module):
self._initialized = False

@abstractmethod
def instantiate(self):
def _instantiate(self):
"""
Instantiate modules in the space
"""
if not self._initialized:
self._initialized = True
raise NotImplementedError()

@abstractmethod
def forward(self, *args, **kwargs):
@@ -53,3 +66,31 @@ class BaseSpace(nn.Module):
model to be exported.
"""
raise NotImplementedError()

def instantiate(self, *args, **kwargs):
    """Public entry point for building the search space.

    Resets the default-key counter used by ``setLayerChoice`` /
    ``setInputChoice``, delegates the actual construction to the
    subclass hook ``_instantiate``, and marks the space initialized.
    """
    # Counter used to mint unique "default_key_<n>" keys for mutables.
    self._default_key = 0
    self._instantiate(*args, **kwargs)
    if not self._initialized:
        self._initialized = True

def setLayerChoice(self, *args, **kwargs):
    """Build an ``OrderedLayerChoice``, minting a unique key if absent.

    A fresh ``default_key_<n>`` key is assigned unless the caller
    already supplied one, either as keyword ``key`` or positionally
    (the ``len(args) < 5`` guard; presumably the 5th positional slot of
    OrderedLayerChoice is ``key`` -- TODO confirm against signature).
    """
    if len(args) < 5 and "key" not in kwargs:
        key = f"default_key_{self._default_key}"
        self._default_key += 1
        kwargs["key"] = key
    layer = OrderedLayerChoice(*args, **kwargs)
    return layer

def setInputChoice(self, *args, **kwargs):
    """Build an ``OrderedInputChoice``, minting a unique key if absent.

    Mirrors ``setLayerChoice``; the ``len(args) < 7`` guard presumably
    corresponds to ``key``'s positional slot in OrderedInputChoice --
    TODO confirm against signature.
    """
    if len(args) < 7 and "key" not in kwargs:
        key = f"default_key_{self._default_key}"
        self._default_key += 1
        kwargs["key"] = key
    layer = OrderedInputChoice(*args, **kwargs)
    return layer

+ 3
- 3
autogl/module/nas/space/single_path.py View File

@@ -71,7 +71,7 @@ class SinglePathNodeClassificationSpace(BaseSpace):
self.ops = ops
self.dropout = dropout

def instantiate(
def _instantiate(
self,
hidden_dim: _typ.Optional[int] = None,
layer_number: _typ.Optional[int] = None,
@@ -89,7 +89,8 @@ class SinglePathNodeClassificationSpace(BaseSpace):
setattr(
self,
f"op_{layer}",
mutables.LayerChoice(
self.setLayerChoice(
layer,
[
op(
self.input_dim if layer == 0 else self.hidden_dim,
@@ -99,7 +100,6 @@ class SinglePathNodeClassificationSpace(BaseSpace):
)
for op in self.ops
],
key=f"{layer}",
),
)
self._initialized = True


+ 16
- 1
autogl/module/nas/utils.py View File

@@ -7,7 +7,7 @@ from collections import OrderedDict
import numpy as np
import torch
import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice
from nni.nas.pytorch.mutables import Mutable, InputChoice, LayerChoice

_logger = logging.getLogger(__name__)

@@ -123,6 +123,21 @@ class AverageMeter:
fmtstr = "{name}: {avg" + self.fmt + "}"
return fmtstr.format(**self.__dict__)

def get_module_order(root_module):
    """Map each mutable's key to its recorded declaration order.

    Walks the module tree depth-first. Recursion stops at ``Mutable``
    nodes: their children are candidate ops, not further mutables.
    Each matched child is assumed to carry an ``order`` attribute
    (i.e. to be an OrderedMutable) -- TODO confirm.
    """
    key2order = {}

    def visit(module):
        for _, child in module.named_children():
            if isinstance(child, Mutable):
                key2order[child.key] = child.order
            else:
                visit(child)

    visit(root_module)
    return key2order

def sort_replaced_module(k2o, modules):
    """Sort replaced NAS modules into declaration order.

    ``modules`` is a list of ``(key, module)`` pairs; ``k2o`` maps each
    key to the order recorded when the mutable was declared. Returns a
    new sorted list (the input is not mutated).
    """
    return sorted(modules, key=lambda pair: k2o[pair[0]])

def _replace_module_with_type(root_module, init_fn, type_name, modules):
if modules is None:


+ 3
- 3
examples/test_nas.py View File

@@ -31,9 +31,9 @@ if __name__ == '__main__':
feval=['acc'],
loss="nll_loss",
lr_scheduler_type=None,),
#nas_algorithms=[Enas()],
nas_algorithms=[Darts(num_epochs=1)],
nas_spaces=[SinglePathNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GATConv])],
nas_algorithms=[Enas()],
#nas_algorithms=[Darts(num_epochs=1)],
nas_spaces=[SinglePathNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GCNConv])],
nas_estimators=[OneShotEstimator()]
)
solver.fit(dataset)


Loading…
Cancel
Save