diff --git a/autogl/module/nas/algorithm/darts.py b/autogl/module/nas/algorithm/darts.py
index 00ac3a5..c7510c7 100644
--- a/autogl/module/nas/algorithm/darts.py
+++ b/autogl/module/nas/algorithm/darts.py
@@ -12,64 +12,11 @@ from ..estimator.base import BaseEstimator
 from ..space import BaseSpace
 from ..utils import replace_layer_choice, replace_input_choice
 from ...model.base import BaseModel
+from nni.retiarii.oneshot.pytorch.darts import DartsLayerChoice, DartsInputChoice
 
 _logger = logging.getLogger(__name__)
 
 
-class DartsLayerChoice(nn.Module):
-    def __init__(self, layer_choice):
-        super(DartsLayerChoice, self).__init__()
-        self.name = layer_choice.key
-        self.op_choices = nn.ModuleDict(layer_choice.named_children())
-        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
-
-    def forward(self, *args, **kwargs):
-        op_results = torch.stack(
-            [op(*args, **kwargs) for op in self.op_choices.values()]
-        )
-        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
-        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
-
-    def parameters(self):
-        for _, p in self.named_parameters():
-            yield p
-
-    def named_parameters(self):
-        for name, p in super(DartsLayerChoice, self).named_parameters():
-            if name == "alpha":
-                continue
-            yield name, p
-
-    def export(self):
-        return torch.argmax(self.alpha).item()
-
-
-class DartsInputChoice(nn.Module):
-    def __init__(self, input_choice):
-        super(DartsInputChoice, self).__init__()
-        self.name = input_choice.key
-        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
-        self.n_chosen = input_choice.n_chosen or 1
-
-    def forward(self, inputs):
-        inputs = torch.stack(inputs)
-        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
-        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
-
-    def parameters(self):
-        for _, p in self.named_parameters():
-            yield p
-
-    def named_parameters(self):
-        for name, p in super(DartsInputChoice, self).named_parameters():
-            if name == "alpha":
-                continue
-            yield name, p
-
-    def export(self):
-        return torch.argsort(-self.alpha).cpu().numpy().tolist()[: self.n_chosen]
-
-
 class Darts(BaseNAS):
     """
     DARTS trainer.
diff --git a/autogl/module/nas/algorithm/enas.py b/autogl/module/nas/algorithm/enas.py
index 3d175aa..57ca42b 100644
--- a/autogl/module/nas/algorithm/enas.py
+++ b/autogl/module/nas/algorithm/enas.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 
 from .base import BaseNAS
 from ..space import BaseSpace
-from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
+from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module
 from nni.nas.pytorch.fixed import apply_fixed_architecture
 _logger = logging.getLogger(__name__)
 def _get_mask(sampled, total):
@@ -297,8 +297,12 @@ class Enas(BaseNAS):
         )
         # replace choice
         self.nas_modules = []
+
+        k2o = get_module_order(self.model)
         replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
         replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
+        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
+
         # to device
         self.model = self.model.to(self.device)
         # fields
diff --git a/autogl/module/nas/space/base.py b/autogl/module/nas/space/base.py
index 1022dce..4c38584 100644
--- a/autogl/module/nas/space/base.py
+++ b/autogl/module/nas/space/base.py
@@ -1,7 +1,21 @@
 from abc import abstractmethod
 from autogl.module.model import BaseModel
 import torch.nn as nn
+from nni.nas.pytorch import mutables
 
+class OrderedMutable():
+    def __init__(self, order):
+        self.order = order
+
+class OrderedLayerChoice(OrderedMutable, mutables.LayerChoice):
+    def __init__(self, order, *args, **kwargs):
+        OrderedMutable.__init__(self, order)
+        mutables.LayerChoice.__init__(self, *args, **kwargs)
+
+class OrderedInputChoice(OrderedMutable, mutables.InputChoice):
+    def __init__(self, order, *args, **kwargs):
+        OrderedMutable.__init__(self, order)
+        mutables.InputChoice.__init__(self, *args, **kwargs)
 
 class BaseSpace(nn.Module):
     """
@@ -21,12 +35,11 @@ class BaseSpace(nn.Module):
         self._initialized = False
 
     @abstractmethod
-    def instantiate(self):
+    def _instantiate(self):
         """
         Instantiate modules in the space
         """
-        if not self._initialized:
-            self._initialized = True
+        raise NotImplementedError()
 
     @abstractmethod
     def forward(self, *args, **kwargs):
@@ -53,3 +66,31 @@ class BaseSpace(nn.Module):
             model to be exported.
         """
         raise NotImplementedError()
+
+    def instantiate(self, *args, **kwargs):
+        self._default_key = 0
+        self._instantiate(*args, **kwargs)
+        if not self._initialized:
+            self._initialized = True
+
+    def setLayerChoice(self, *args, **kwargs):
+        """
+        Give a unique key if not given
+        """
+        if len(args) < 5 and not "key" in kwargs:
+            key = f"default_key_{self._default_key}"
+            self._default_key += 1
+            kwargs["key"] = key
+        layer = OrderedLayerChoice(*args, **kwargs)
+        return layer
+
+    def setInputChoice(self, *args, **kwargs):
+        """
+        Give a unique key if not given
+        """
+        if len(args) < 7 and not "key" in kwargs:
+            key = f"default_key_{self._default_key}"
+            self._default_key += 1
+            kwargs["key"] = key
+        layer = OrderedInputChoice(*args, **kwargs)
+        return layer
diff --git a/autogl/module/nas/space/single_path.py b/autogl/module/nas/space/single_path.py
index 8bbe8b3..c1a2a60 100644
--- a/autogl/module/nas/space/single_path.py
+++ b/autogl/module/nas/space/single_path.py
@@ -71,7 +71,7 @@ class SinglePathNodeClassificationSpace(BaseSpace):
         self.ops = ops
         self.dropout = dropout
 
-    def instantiate(
+    def _instantiate(
         self,
         hidden_dim: _typ.Optional[int] = None,
         layer_number: _typ.Optional[int] = None,
@@ -89,7 +89,8 @@ class SinglePathNodeClassificationSpace(BaseSpace):
             setattr(
                 self,
                 f"op_{layer}",
-                mutables.LayerChoice(
+                self.setLayerChoice(
+                    layer,
                     [
                         op(
                             self.input_dim if layer == 0 else self.hidden_dim,
@@ -99,7 +100,6 @@ class SinglePathNodeClassificationSpace(BaseSpace):
                         )
                         for op in self.ops
                     ],
-                    key=f"{layer}",
                 ),
             )
         self._initialized = True
diff --git a/autogl/module/nas/utils.py b/autogl/module/nas/utils.py
index 4b76d5b..2504cfc 100644
--- a/autogl/module/nas/utils.py
+++ b/autogl/module/nas/utils.py
@@ -7,7 +7,7 @@ from collections import OrderedDict
 import numpy as np
 import torch
 import nni.retiarii.nn.pytorch as nn
-from nni.nas.pytorch.mutables import InputChoice, LayerChoice
+from nni.nas.pytorch.mutables import Mutable, InputChoice, LayerChoice
 
 _logger = logging.getLogger(__name__)
 
@@ -123,6 +123,21 @@ class AverageMeter:
         fmtstr = "{name}: {avg" + self.fmt + "}"
         return fmtstr.format(**self.__dict__)
 
+def get_module_order(root_module):
+    key2order = {}
+    def apply(m):
+        for name, child in m.named_children():
+            if isinstance(child, Mutable):
+                key2order[child.key] = child.order
+            else:
+                apply(child)
+
+    apply(root_module)
+    return key2order
+
+def sort_replaced_module(k2o, modules):
+    modules = sorted(modules, key = lambda x:k2o[x[0]])
+    return modules
 
 def _replace_module_with_type(root_module, init_fn, type_name, modules):
     if modules is None:
diff --git a/examples/test_nas.py b/examples/test_nas.py
index 5e053b4..6d27c26 100644
--- a/examples/test_nas.py
+++ b/examples/test_nas.py
@@ -31,9 +31,9 @@ if __name__ == '__main__':
             feval=['acc'],
             loss="nll_loss",
             lr_scheduler_type=None,),
-        #nas_algorithms=[Enas()],
-        nas_algorithms=[Darts(num_epochs=1)],
-        nas_spaces=[SinglePathNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GATConv])],
+        nas_algorithms=[Enas()],
+        #nas_algorithms=[Darts(num_epochs=1)],
+        nas_spaces=[SinglePathNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GCNConv])],
         nas_estimators=[OneShotEstimator()]
     )
     solver.fit(dataset)