| @@ -10,6 +10,7 @@ from .base import BaseNAS | |||
| from ..space import BaseSpace | |||
| from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module | |||
| from nni.nas.pytorch.fixed import apply_fixed_architecture | |||
| from tqdm import tqdm | |||
| _logger = logging.getLogger(__name__) | |||
| def _get_mask(sampled, total): | |||
| multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)] | |||
| @@ -312,9 +313,21 @@ class Enas(BaseNAS): | |||
| self.controller = ReinforceController(self.nas_fields, **(self.ctrl_kwargs or {})) | |||
| self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr) | |||
| # train | |||
| for i in range(self.num_epochs): | |||
| self._train_model(i) | |||
| self._train_controller(i) | |||
| with tqdm(range(self.num_epochs)) as bar: | |||
| for i in bar: | |||
| try: | |||
| l1=self._train_model(i) | |||
| l2=self._train_controller(i) | |||
| except Exception as e: | |||
| print(e) | |||
| nm=self.nas_modules | |||
| for i in range(len(nm)): | |||
| print(nm[i][1].sampled) | |||
| import pdb | |||
| pdb.set_trace() | |||
| bar.set_postfix(loss_model=l1,reward_controller=l2) | |||
| selection=self.export() | |||
| return space.export(selection,self.device) | |||
| @@ -329,16 +342,19 @@ class Enas(BaseNAS): | |||
| if self.grad_clip > 0: | |||
| nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) | |||
| self.model_optim.step() | |||
| return loss.item() | |||
| def _train_controller(self, epoch): | |||
| self.model.eval() | |||
| self.controller.train() | |||
| self.ctrl_optim.zero_grad() | |||
| rewards=[] | |||
| for ctrl_step in range(self.ctrl_steps_aggregate): | |||
| self._resample() | |||
| with torch.no_grad(): | |||
| metric,loss=self._infer() | |||
| reward =-metric # todo : now metric is loss | |||
| rewards.append(reward) | |||
| if self.entropy_weight: | |||
| reward += self.entropy_weight * self.controller.sample_entropy.item() | |||
| self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) | |||
| @@ -357,6 +373,7 @@ class Enas(BaseNAS): | |||
| if self.log_frequency is not None and ctrl_step % self.log_frequency == 0: | |||
| _logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs, | |||
| ctrl_step + 1, self.ctrl_steps_aggregate) | |||
| return (sum(rewards)/len(rewards)).item() | |||
| def _resample(self): | |||
| result = self.controller.resample() | |||
| @@ -43,6 +43,9 @@ class LambdaModule(nn.Module): | |||
| def forward(self, x): | |||
| return self.lambd(x) | |||
| def __repr__(self): | |||
| return '{}({})'.format(self.__class__.__name__,self.lambd) | |||
| class StrModule(nn.Module): | |||
| def __init__(self, lambd): | |||
| super().__init__() | |||
| @@ -50,6 +53,9 @@ class StrModule(nn.Module): | |||
| def forward(self, *args,**kwargs): | |||
| return self.str | |||
| def __repr__(self): | |||
| return '{}({})'.format(self.__class__.__name__,self.str) | |||
| def act_map(act): | |||
| if act == "linear": | |||
| return lambda x: x | |||
| @@ -128,6 +134,15 @@ class LinearConv(nn.Module): | |||
| self.out_channels) | |||
| from torch.autograd import Function | |||
| class ZeroConvFunc(Function): | |||
| @staticmethod | |||
| def forward(ctx,x): | |||
| return x | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| return grad_output | |||
| class ZeroConv(nn.Module): | |||
| def __init__(self, | |||
| in_channels, | |||
| @@ -138,9 +153,8 @@ class ZeroConv(nn.Module): | |||
| self.out_channels = out_channels | |||
| self.out_dim = out_channels | |||
| def forward(self, x, edge_index, edge_weight=None): | |||
| return torch.zeros([x.size(0), self.out_dim]).to(x.device) | |||
| return ZeroConvFunc.apply(torch.zeros([x.size(0), self.out_dim]).to(x.device)) | |||
| def __repr__(self): | |||
| return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, | |||
| @@ -202,7 +216,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): | |||
| node_in = getattr(self, f"in_{layer}")(prev_nodes_out) | |||
| node_out= getattr(self, f"op_{layer}")(node_in,edges) | |||
| prev_nodes_out.append(node_out) | |||
| if self.search_act_con: | |||
| if not self.search_act_con: | |||
| x = torch.cat(prev_nodes_out[2:],dim=1) | |||
| x = F.leaky_relu(x) | |||
| x = F.dropout(x, p=self.dropout, training = self.training) | |||
| @@ -27,6 +27,8 @@ class FixedNodeClassificationModel(BaseModel): | |||
| apply_fixed_architecture(self._model, selection, verbose=False) | |||
| self.params = {"num_class": self.num_classes, "features_num": self.num_features} | |||
| self.device = device | |||
| print(self._model) | |||
| print(selection) | |||
| def to(self, device): | |||
| if isinstance(device, (str, torch.device)): | |||
| @@ -28,9 +28,9 @@ if __name__ == '__main__': | |||
| feval=['acc'], | |||
| loss="nll_loss", | |||
| lr_scheduler_type=None,), | |||
| nas_algorithms=[Enas(num_epochs=10)], | |||
| nas_algorithms=[Enas(num_epochs=100)], | |||
| #nas_algorithms=[Darts(num_epochs=200)], | |||
| nas_spaces=[GraphNasNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GCNConv],search_act_con=True)], | |||
| nas_spaces=[GraphNasNodeClassificationSpace(hidden_dim=16,search_act_con=False)], | |||
| nas_estimators=[OneShotEstimator()] | |||
| ) | |||
| solver.fit(dataset) | |||