From 3cd0367a236db9ec6db493fa2cf361ea1295947b Mon Sep 17 00:00:00 2001 From: wondergo2017 Date: Thu, 6 May 2021 03:42:15 +0000 Subject: [PATCH] fix zeroconv bug --- autogl/module/nas/algorithm/enas.py | 23 ++++++++++++++++++++--- autogl/module/nas/space/graph_nas.py | 20 +++++++++++++++++--- autogl/module/nas/space/single_path.py | 2 ++ examples/test_graph_nas.py | 4 ++-- 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/autogl/module/nas/algorithm/enas.py b/autogl/module/nas/algorithm/enas.py index 715139a..8b66f5e 100644 --- a/autogl/module/nas/algorithm/enas.py +++ b/autogl/module/nas/algorithm/enas.py @@ -10,6 +10,7 @@ from .base import BaseNAS from ..space import BaseSpace from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module from nni.nas.pytorch.fixed import apply_fixed_architecture +from tqdm import tqdm _logger = logging.getLogger(__name__) def _get_mask(sampled, total): multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)] @@ -312,9 +313,21 @@ class Enas(BaseNAS): self.controller = ReinforceController(self.nas_fields, **(self.ctrl_kwargs or {})) self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr) # train - for i in range(self.num_epochs): - self._train_model(i) - self._train_controller(i) + with tqdm(range(self.num_epochs)) as bar: + for i in bar: + try: + l1=self._train_model(i) + l2=self._train_controller(i) + except Exception as e: + print(e) + nm=self.nas_modules + for i in range(len(nm)): + print(nm[i][1].sampled) + import pdb + pdb.set_trace() + + + bar.set_postfix(loss_model=l1,reward_controller=l2) selection=self.export() return space.export(selection,self.device) @@ -329,16 +342,19 @@ class Enas(BaseNAS): if self.grad_clip > 0: nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) self.model_optim.step() + return loss.item() def _train_controller(self, epoch): self.model.eval() self.controller.train() self.ctrl_optim.zero_grad() + rewards=[] for ctrl_step in range(self.ctrl_steps_aggregate): self._resample() with torch.no_grad(): metric,loss=self._infer() reward =-metric # todo : now metric is loss + rewards.append(reward) if self.entropy_weight: reward += self.entropy_weight * self.controller.sample_entropy.item() self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) @@ -357,6 +373,7 @@ class Enas(BaseNAS): if self.log_frequency is not None and ctrl_step % self.log_frequency == 0: _logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs, ctrl_step + 1, self.ctrl_steps_aggregate) + return (sum(rewards)/len(rewards)).item() def _resample(self): result = self.controller.resample() diff --git a/autogl/module/nas/space/graph_nas.py b/autogl/module/nas/space/graph_nas.py index 447ccc9..9e02169 100644 --- a/autogl/module/nas/space/graph_nas.py +++ b/autogl/module/nas/space/graph_nas.py @@ -43,6 +43,9 @@ class LambdaModule(nn.Module): def forward(self, x): return self.lambd(x) + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__,self.lambd) class StrModule(nn.Module): def __init__(self, lambd): super().__init__() @@ -50,6 +53,9 @@ class StrModule(nn.Module): def forward(self, *args,**kwargs): return self.str + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__,self.str) def act_map(act): if act == "linear": return lambda x: x @@ -128,6 +134,15 @@ class LinearConv(nn.Module): self.out_channels) +from torch.autograd import Function +class ZeroConvFunc(Function): + @staticmethod + def forward(ctx,x): + return x + + @staticmethod + def backward(ctx, grad_output): + return grad_output class ZeroConv(nn.Module): def __init__(self, in_channels, @@ -138,9 +153,8 @@ class ZeroConv(nn.Module): self.out_channels = out_channels self.out_dim = out_channels - def forward(self, x, edge_index, edge_weight=None): - return torch.zeros([x.size(0), self.out_dim]).to(x.device) + return ZeroConvFunc.apply(torch.zeros([x.size(0), self.out_dim]).to(x.device)) def __repr__(self): return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, @@ -202,7 +216,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): node_in = getattr(self, f"in_{layer}")(prev_nodes_out) node_out= getattr(self, f"op_{layer}")(node_in,edges) prev_nodes_out.append(node_out) - if self.search_act_con: + if not self.search_act_con: x = torch.cat(prev_nodes_out[2:],dim=1) x = F.leaky_relu(x) x = F.dropout(x, p=self.dropout, training = self.training) diff --git a/autogl/module/nas/space/single_path.py b/autogl/module/nas/space/single_path.py index fcbffa2..75c9746 100644 --- a/autogl/module/nas/space/single_path.py +++ b/autogl/module/nas/space/single_path.py @@ -27,6 +27,8 @@ class FixedNodeClassificationModel(BaseModel): apply_fixed_architecture(self._model, selection, verbose=False) self.params = {"num_class": self.num_classes, "features_num": self.num_features} self.device = device + print(self._model) + print(selection) def to(self, device): if isinstance(device, (str, torch.device)): diff --git a/examples/test_graph_nas.py b/examples/test_graph_nas.py index d1e297c..49207d8 100644 --- a/examples/test_graph_nas.py +++ b/examples/test_graph_nas.py @@ -28,9 +28,9 @@ if __name__ == '__main__': feval=['acc'], loss="nll_loss", lr_scheduler_type=None,), - nas_algorithms=[Enas(num_epochs=10)], + nas_algorithms=[Enas(num_epochs=100)], #nas_algorithms=[Darts(num_epochs=200)], - nas_spaces=[GraphNasNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GCNConv],search_act_con=True)], + nas_spaces=[GraphNasNodeClassificationSpace(hidden_dim=16,search_act_con=False)], nas_estimators=[OneShotEstimator()] ) solver.fit(dataset)