Browse Source

fix zeroconv bug

tags/v0.3.1
wondergo2017 5 years ago
parent
commit
3cd0367a23
4 changed files with 41 additions and 8 deletions
  1. +20
    -3
      autogl/module/nas/algorithm/enas.py
  2. +17
    -3
      autogl/module/nas/space/graph_nas.py
  3. +2
    -0
      autogl/module/nas/space/single_path.py
  4. +2
    -2
      examples/test_graph_nas.py

+ 20
- 3
autogl/module/nas/algorithm/enas.py View File

@@ -10,6 +10,7 @@ from .base import BaseNAS
from ..space import BaseSpace
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module
from nni.nas.pytorch.fixed import apply_fixed_architecture
from tqdm import tqdm
_logger = logging.getLogger(__name__)
def _get_mask(sampled, total):
multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
@@ -312,9 +313,21 @@ class Enas(BaseNAS):
self.controller = ReinforceController(self.nas_fields, **(self.ctrl_kwargs or {}))
self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr)
# train
for i in range(self.num_epochs):
self._train_model(i)
self._train_controller(i)
with tqdm(range(self.num_epochs)) as bar:
for i in bar:
try:
l1=self._train_model(i)
l2=self._train_controller(i)
except Exception as e:
print(e)
nm=self.nas_modules
for i in range(len(nm)):
print(nm[i][1].sampled)
import pdb
pdb.set_trace()

bar.set_postfix(loss_model=l1,reward_controller=l2)
selection=self.export()
return space.export(selection,self.device)
@@ -329,16 +342,19 @@ class Enas(BaseNAS):
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
self.model_optim.step()
return loss.item()

def _train_controller(self, epoch):
self.model.eval()
self.controller.train()
self.ctrl_optim.zero_grad()
rewards=[]
for ctrl_step in range(self.ctrl_steps_aggregate):
self._resample()
with torch.no_grad():
metric,loss=self._infer()
reward =-metric # todo : now metric is loss
rewards.append(reward)
if self.entropy_weight:
reward += self.entropy_weight * self.controller.sample_entropy.item()
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
@@ -357,6 +373,7 @@ class Enas(BaseNAS):
if self.log_frequency is not None and ctrl_step % self.log_frequency == 0:
_logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs,
ctrl_step + 1, self.ctrl_steps_aggregate)
return (sum(rewards)/len(rewards)).item()

def _resample(self):
result = self.controller.resample()


+ 17
- 3
autogl/module/nas/space/graph_nas.py View File

@@ -43,6 +43,9 @@ class LambdaModule(nn.Module):

def forward(self, x):
    """Apply the wrapped callable ``self.lambd`` to ``x`` and return its result."""
    return self.lambd(x)
def __repr__(self):
    """Return ``ClassName(<wrapped callable>)`` so printed models show what is wrapped."""
    return f"{type(self).__name__}({self.lambd})"
class StrModule(nn.Module):
def __init__(self, lambd):
super().__init__()
@@ -50,6 +53,9 @@ class StrModule(nn.Module):

def forward(self, *args, **kwargs):
    """Return the stored ``self.str`` value; all positional and keyword arguments are ignored."""
    return self.str

def __repr__(self):
    """Return ``ClassName(<stored value>)`` so printed models show the held ``self.str``."""
    return f"{type(self).__name__}({self.str})"
def act_map(act):
if act == "linear":
return lambda x: x
@@ -128,6 +134,15 @@ class LinearConv(nn.Module):
self.out_channels)


from torch.autograd import Function
class ZeroConvFunc(Function):
    """Identity autograd function.

    ``forward`` hands its input back unchanged and ``backward`` passes the
    incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``x`` as-is; nothing is saved on ``ctx``."""
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` unchanged as the gradient with respect to ``x``."""
        return grad_output
class ZeroConv(nn.Module):
def __init__(self,
in_channels,
@@ -138,9 +153,8 @@ class ZeroConv(nn.Module):
self.out_channels = out_channels
self.out_dim = out_channels


def forward(self, x, edge_index, edge_weight=None):
    """Return an all-zero tensor of shape ``(x.size(0), self.out_dim)`` on ``x``'s device.

    ``edge_index`` and ``edge_weight`` are accepted but not used by this body.
    The zeros are routed through ``ZeroConvFunc.apply`` rather than returned
    directly; previously an earlier plain ``return`` made that call unreachable.
    """
    # Allocate directly on the target device instead of on CPU followed by a copy.
    zeros = torch.zeros([x.size(0), self.out_dim], device=x.device)
    # NOTE(review): ``zeros`` does not require grad, so ``apply`` presumably
    # yields an output without a grad_fn as well -- confirm this is the
    # intended effect of the ZeroConv fix.
    return ZeroConvFunc.apply(zeros)

def __repr__(self):
return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
@@ -202,7 +216,7 @@ class GraphNasNodeClassificationSpace(BaseSpace):
node_in = getattr(self, f"in_{layer}")(prev_nodes_out)
node_out= getattr(self, f"op_{layer}")(node_in,edges)
prev_nodes_out.append(node_out)
if self.search_act_con:
if not self.search_act_con:
x = torch.cat(prev_nodes_out[2:],dim=1)
x = F.leaky_relu(x)
x = F.dropout(x, p=self.dropout, training = self.training)


+ 2
- 0
autogl/module/nas/space/single_path.py View File

@@ -27,6 +27,8 @@ class FixedNodeClassificationModel(BaseModel):
apply_fixed_architecture(self._model, selection, verbose=False)
self.params = {"num_class": self.num_classes, "features_num": self.num_features}
self.device = device
print(self._model)
print(selection)

def to(self, device):
if isinstance(device, (str, torch.device)):


+ 2
- 2
examples/test_graph_nas.py View File

@@ -28,9 +28,9 @@ if __name__ == '__main__':
feval=['acc'],
loss="nll_loss",
lr_scheduler_type=None,),
nas_algorithms=[Enas(num_epochs=10)],
nas_algorithms=[Enas(num_epochs=100)],
#nas_algorithms=[Darts(num_epochs=200)],
nas_spaces=[GraphNasNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GCNConv],search_act_con=True)],
nas_spaces=[GraphNasNodeClassificationSpace(hidden_dim=16,search_act_con=False)],
nas_estimators=[OneShotEstimator()]
)
solver.fit(dataset)


Loading…
Cancel
Save