| @@ -12,64 +12,11 @@ from ..estimator.base import BaseEstimator | |||
| from ..space import BaseSpace | |||
| from ..utils import replace_layer_choice, replace_input_choice | |||
| from ...model.base import BaseModel | |||
| from nni.retiarii.oneshot.pytorch.darts import DartsLayerChoice, DartsInputChoice | |||
| _logger = logging.getLogger(__name__) | |||
class DartsLayerChoice(nn.Module):
    """Differentiable DARTS layer choice.

    Wraps the candidate operations of a ``layer_choice`` and mixes their
    outputs with softmax-normalised architecture weights ``alpha``.
    """

    def __init__(self, layer_choice):
        super(DartsLayerChoice, self).__init__()
        self.name = layer_choice.key
        self.op_choices = nn.ModuleDict(layer_choice.named_children())
        # Small random init keeps the initial mixture close to uniform.
        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

    def forward(self, *args, **kwargs):
        # Run every candidate op on the same input and stack the results.
        stacked = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
        weights = F.softmax(self.alpha, -1)
        # Reshape weights so they broadcast over the stacked op outputs.
        broadcast_shape = [-1] + [1] * (stacked.dim() - 1)
        return (stacked * weights.view(*broadcast_shape)).sum(0)

    def parameters(self):
        # Delegates to named_parameters, which hides ``alpha``.
        for _, param in self.named_parameters():
            yield param

    def named_parameters(self):
        # Exclude the architecture parameter ``alpha`` so a weight
        # optimizer built from parameters() will not update it.
        for name, param in super(DartsLayerChoice, self).named_parameters():
            if name != "alpha":
                yield name, param

    def export(self):
        """Return the index of the candidate op with the largest alpha."""
        return torch.argmax(self.alpha).item()
class DartsInputChoice(nn.Module):
    """Differentiable DARTS input choice.

    Combines candidate input tensors with softmax-normalised architecture
    weights ``alpha``; export selects the strongest ``n_chosen`` inputs.
    """

    def __init__(self, input_choice):
        super(DartsInputChoice, self).__init__()
        self.name = input_choice.key
        # Small random init keeps the initial mixture close to uniform.
        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
        self.n_chosen = input_choice.n_chosen or 1

    def forward(self, inputs):
        stacked = torch.stack(inputs)
        weights = F.softmax(self.alpha, -1)
        # Reshape weights so they broadcast over the stacked inputs.
        broadcast_shape = [-1] + [1] * (stacked.dim() - 1)
        return (stacked * weights.view(*broadcast_shape)).sum(0)

    def parameters(self):
        # Delegates to named_parameters, which hides ``alpha``.
        for _, param in self.named_parameters():
            yield param

    def named_parameters(self):
        # Exclude the architecture parameter ``alpha`` so a weight
        # optimizer built from parameters() will not update it.
        for name, param in super(DartsInputChoice, self).named_parameters():
            if name != "alpha":
                yield name, param

    def export(self):
        """Return indices of the top ``n_chosen`` inputs, best first."""
        return torch.argsort(-self.alpha).cpu().numpy().tolist()[: self.n_chosen]
| class Darts(BaseNAS): | |||
| """ | |||
| DARTS trainer. | |||
| @@ -8,7 +8,7 @@ import torch.nn.functional as F | |||
| from .base import BaseNAS | |||
| from ..space import BaseSpace | |||
| from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice | |||
| from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module | |||
| from nni.nas.pytorch.fixed import apply_fixed_architecture | |||
| _logger = logging.getLogger(__name__) | |||
| def _get_mask(sampled, total): | |||
| @@ -297,8 +297,12 @@ class Enas(BaseNAS): | |||
| ) | |||
| # replace choice | |||
| self.nas_modules = [] | |||
| k2o = get_module_order(self.model) | |||
| replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules) | |||
| replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules) | |||
| self.nas_modules = sort_replaced_module(k2o, self.nas_modules) | |||
| # to device | |||
| self.model = self.model.to(self.device) | |||
| # fields | |||
| @@ -1,7 +1,21 @@ | |||
| from abc import abstractmethod | |||
| from autogl.module.model import BaseModel | |||
| import torch.nn as nn | |||
| from nni.nas.pytorch import mutables | |||
class OrderedMutable:
    """Mixin that records the order in which a mutable was created."""

    def __init__(self, order):
        # Creation order; later used to sort replaced NAS modules.
        self.order = order
class OrderedLayerChoice(OrderedMutable, mutables.LayerChoice):
    """A ``LayerChoice`` that additionally remembers its creation order."""

    def __init__(self, order, *choice_args, **choice_kwargs):
        # The two bases are independent, so each is initialised explicitly
        # rather than through cooperative super().
        OrderedMutable.__init__(self, order)
        mutables.LayerChoice.__init__(self, *choice_args, **choice_kwargs)
class OrderedInputChoice(OrderedMutable, mutables.InputChoice):
    """An ``InputChoice`` that additionally remembers its creation order."""

    def __init__(self, order, *choice_args, **choice_kwargs):
        # The two bases are independent, so each is initialised explicitly
        # rather than through cooperative super().
        OrderedMutable.__init__(self, order)
        mutables.InputChoice.__init__(self, *choice_args, **choice_kwargs)
| class BaseSpace(nn.Module): | |||
| """ | |||
| @@ -21,12 +35,11 @@ class BaseSpace(nn.Module): | |||
| self._initialized = False | |||
| @abstractmethod | |||
| def instantiate(self): | |||
| def _instantiate(self): | |||
| """ | |||
| Instantiate modules in the space | |||
| """ | |||
| if not self._initialized: | |||
| self._initialized = True | |||
| raise NotImplementedError() | |||
| @abstractmethod | |||
| def forward(self, *args, **kwargs): | |||
| @@ -53,3 +66,31 @@ class BaseSpace(nn.Module): | |||
| model to be exported. | |||
| """ | |||
| raise NotImplementedError() | |||
| def instantiate(self, *args, **kwargs): | |||
| self._default_key = 0 | |||
| self._instantiate(*args, **kwargs) | |||
| if not self._initialized: | |||
| self._initialized = True | |||
| def setLayerChoice(self, *args, **kwargs): | |||
| """ | |||
| Give a unique key if not given | |||
| """ | |||
| if len(args) < 5 and not "key" in kwargs: | |||
| key = f"default_key_{self._default_key}" | |||
| self._default_key += 1 | |||
| kwargs["key"] = key | |||
| layer = OrderedLayerChoice(*args, **kwargs) | |||
| return layer | |||
| def setInputChoice(self, *args, **kwargs): | |||
| """ | |||
| Give a unique key if not given | |||
| """ | |||
| if len(args) < 7 and not "key" in kwargs: | |||
| key = f"default_key_{self._default_key}" | |||
| self._default_key += 1 | |||
| kwargs["key"] = key | |||
| layer = OrderedInputChoice(*args, **kwargs) | |||
| return layer | |||
| @@ -71,7 +71,7 @@ class SinglePathNodeClassificationSpace(BaseSpace): | |||
| self.ops = ops | |||
| self.dropout = dropout | |||
| def instantiate( | |||
| def _instantiate( | |||
| self, | |||
| hidden_dim: _typ.Optional[int] = None, | |||
| layer_number: _typ.Optional[int] = None, | |||
| @@ -89,7 +89,8 @@ class SinglePathNodeClassificationSpace(BaseSpace): | |||
| setattr( | |||
| self, | |||
| f"op_{layer}", | |||
| mutables.LayerChoice( | |||
| self.setLayerChoice( | |||
| layer, | |||
| [ | |||
| op( | |||
| self.input_dim if layer == 0 else self.hidden_dim, | |||
| @@ -99,7 +100,6 @@ class SinglePathNodeClassificationSpace(BaseSpace): | |||
| ) | |||
| for op in self.ops | |||
| ], | |||
| key=f"{layer}", | |||
| ), | |||
| ) | |||
| self._initialized = True | |||
| @@ -7,7 +7,7 @@ from collections import OrderedDict | |||
| import numpy as np | |||
| import torch | |||
| import nni.retiarii.nn.pytorch as nn | |||
| from nni.nas.pytorch.mutables import InputChoice, LayerChoice | |||
| from nni.nas.pytorch.mutables import Mutable, InputChoice, LayerChoice | |||
| _logger = logging.getLogger(__name__) | |||
| @@ -123,6 +123,21 @@ class AverageMeter: | |||
| fmtstr = "{name}: {avg" + self.fmt + "}" | |||
| return fmtstr.format(**self.__dict__) | |||
def get_module_order(root_module):
    """Return a mapping from each ``Mutable``'s key to its creation order.

    Walks the module tree depth-first; a child that is itself a ``Mutable``
    is recorded but not descended into.
    """
    order_by_key = {}

    def visit(module):
        for child in module.children():
            if isinstance(child, Mutable):
                order_by_key[child.key] = child.order
            else:
                visit(child)

    visit(root_module)
    return order_by_key
def sort_replaced_module(k2o, modules):
    """Order replaced ``(key, module)`` pairs by each key's recorded order.

    ``k2o`` maps a choice key to its creation order (see get_module_order);
    returns a new sorted list, leaving ``modules`` untouched.
    """
    return sorted(modules, key=lambda entry: k2o[entry[0]])
| def _replace_module_with_type(root_module, init_fn, type_name, modules): | |||
| if modules is None: | |||
| @@ -31,9 +31,9 @@ if __name__ == '__main__': | |||
| feval=['acc'], | |||
| loss="nll_loss", | |||
| lr_scheduler_type=None,), | |||
| #nas_algorithms=[Enas()], | |||
| nas_algorithms=[Darts(num_epochs=1)], | |||
| nas_spaces=[SinglePathNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GATConv])], | |||
| nas_algorithms=[Enas()], | |||
| #nas_algorithms=[Darts(num_epochs=1)], | |||
| nas_spaces=[SinglePathNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GCNConv])], | |||
| nas_estimators=[OneShotEstimator()] | |||
| ) | |||
| solver.fit(dataset) | |||