
add GraphNAS search space; tested OK with ENAS

commit a69f60750d by wondergo2017, 5 years ago (tag: v0.3.1)
3 changed files with 265 additions and 1 deletion:

  1. autogl/module/nas/algorithm/enas.py  (+1, -1)
  2. autogl/module/nas/space/graph_nas.py (+225, -0)
  3. examples/test_graph_nas.py           (+39, -0)

autogl/module/nas/algorithm/enas.py (+1, -1)

@@ -273,7 +273,7 @@ class Enas(BaseNAS):
     def __init__(self, device='cuda', workers=4, log_frequency=None,
                  grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
                  ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(device)
         self.device = device
         self.num_epochs = kwargs.get("num_epochs", 5)
         self.workers = workers


autogl/module/nas/space/graph_nas.py (+225, -0)

@@ -0,0 +1,225 @@
import typing as _typ
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn.conv import *

from nni.nas.pytorch import mutables
from nni.nas.pytorch.fixed import apply_fixed_architecture

from .base import BaseSpace, OrderedLayerChoice, OrderedInputChoice
from .single_path import FixedNodeClassificationModel
from ...model import BaseModel, AutoGCN
from ....utils import get_logger
gnn_list = [
    "gat_8",   # GAT with 8 heads
    "gat_6",   # GAT with 6 heads
    "gat_4",   # GAT with 4 heads
    "gat_2",   # GAT with 2 heads
    "gat_1",   # GAT with 1 head
    "gcn",     # GCN
    "cheb",    # ChebNet
    "sage",    # GraphSAGE
    "arma",    # ARMA
    "sg",      # simplified GCN (SGC)
    "linear",  # linear transform (skip connection)
    "zero",    # zero op (no connection)
]
act_list = [
    # "sigmoid", "tanh", "relu", "linear",
    # "softplus", "leaky_relu", "relu6", "elu"
    "sigmoid", "tanh", "relu", "linear", "elu"
]

class LambdaModule(nn.Module):
    """Wraps a plain callable as an nn.Module."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class StrModule(nn.Module):
    """Wraps a string as an nn.Module that returns it on forward()."""

    def __init__(self, lambd):
        super().__init__()
        self.str = lambd

    def forward(self, *args, **kwargs):
        return self.str


def act_map(act):
    if act == "linear":
        return lambda x: x
    elif act == "elu":
        return F.elu
    elif act == "sigmoid":
        return torch.sigmoid
    elif act == "tanh":
        return torch.tanh
    elif act == "relu":
        return torch.nn.functional.relu
    elif act == "relu6":
        return torch.nn.functional.relu6
    elif act == "softplus":
        return torch.nn.functional.softplus
    elif act == "leaky_relu":
        return torch.nn.functional.leaky_relu
    else:
        raise Exception("unknown activation function: " + act)


def act_map_nn(act):
    return LambdaModule(act_map(act))


def map_nn(l):
    return [StrModule(x) for x in l]
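
# Illustrative usage of the helpers above (a sketch, not part of the commit):
#   act = act_map_nn("relu")            # LambdaModule applying F.relu
#   y = act(torch.randn(4))             # elementwise ReLU
#   aggs = map_nn(["add", "product"])   # StrModules usable as LayerChoice candidates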

def gnn_map(gnn_name, in_dim, out_dim, concat=False, bias=True) -> nn.Module:
    """Build the GNN operator named by ``gnn_name``.

    :param gnn_name: operator name from ``gnn_list``
    :param in_dim: input feature dimension
    :param out_dim: output feature dimension
    :param concat: for GAT, whether to concatenate multi-head outputs (else average)
    :param bias: whether the operator uses a bias term
    :return: GNN module
    """
    if gnn_name == "gat_8":
        return GATConv(in_dim, out_dim, 8, concat=concat, bias=bias)
    elif gnn_name == "gat_6":
        return GATConv(in_dim, out_dim, 6, concat=concat, bias=bias)
    elif gnn_name == "gat_4":
        return GATConv(in_dim, out_dim, 4, concat=concat, bias=bias)
    elif gnn_name == "gat_2":
        return GATConv(in_dim, out_dim, 2, concat=concat, bias=bias)
    elif gnn_name in ["gat_1", "gat"]:
        return GATConv(in_dim, out_dim, 1, concat=concat, bias=bias)
    elif gnn_name == "gcn":
        return GCNConv(in_dim, out_dim)
    elif gnn_name == "cheb":
        return ChebConv(in_dim, out_dim, K=2, bias=bias)
    elif gnn_name == "sage":
        return SAGEConv(in_dim, out_dim, bias=bias)
    elif gnn_name == "gated":
        return GatedGraphConv(in_dim, out_dim, bias=bias)
    elif gnn_name == "arma":
        return ARMAConv(in_dim, out_dim, bias=bias)
    elif gnn_name == "sg":
        return SGConv(in_dim, out_dim, bias=bias)
    elif gnn_name == "linear":
        return LinearConv(in_dim, out_dim, bias=bias)
    elif gnn_name == "zero":
        return ZeroConv(in_dim, out_dim, bias=bias)
    raise ValueError("unknown GNN operator: " + gnn_name)
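
# Illustrative usage (a sketch, not part of the commit):
#   conv = gnn_map("gat_4", 64, 64)   # 4-head GATConv; heads averaged since concat=False
#   # out = conv(x, edge_index)       # -> [num_nodes, 64]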

class LinearConv(nn.Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super(LinearConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.linear = torch.nn.Linear(in_channels, out_channels, bias)

    def forward(self, x, edge_index, edge_weight=None):
        # Ignores graph structure: a plain linear (skip) transform.
        return self.linear(x)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class ZeroConv(nn.Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super(ZeroConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.out_dim = out_channels

    def forward(self, x, edge_index, edge_weight=None):
        # The "no connection" op: always outputs zeros.
        return torch.zeros([x.size(0), self.out_dim]).to(x.device)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class GraphNasNodeClassificationSpace(BaseSpace):
    def __init__(
        self,
        hidden_dim: _typ.Optional[int] = 64,
        layer_number: _typ.Optional[int] = 2,
        dropout: _typ.Optional[float] = 0.2,
        input_dim: _typ.Optional[int] = None,
        output_dim: _typ.Optional[int] = None,
        ops: _typ.Tuple = None,
        init: bool = False,
    ):
        super().__init__()
        self.layer_number = layer_number
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.ops = ops
        self.dropout = dropout

    def _instantiate(
        self,
        hidden_dim: _typ.Optional[int] = None,
        layer_number: _typ.Optional[int] = None,
        input_dim: _typ.Optional[int] = None,
        output_dim: _typ.Optional[int] = None,
        ops: _typ.Tuple = None,
        dropout=None,
    ):
        self.hidden_dim = hidden_dim or self.hidden_dim
        self.layer_number = layer_number or self.layer_number
        self.input_dim = input_dim or self.input_dim
        self.output_dim = output_dim or self.output_dim
        self.ops = ops or self.ops
        self.dropout = dropout or self.dropout
        self.preproc0 = nn.Linear(self.input_dim, self.hidden_dim)
        self.preproc1 = nn.Linear(self.input_dim, self.hidden_dim)
        node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY]
        for layer in range(2, self.layer_number + 2):
            node_labels.append(f"op_{layer}")
            setattr(
                self,
                f"in_{layer}",
                self.setInputChoice(
                    layer,
                    choose_from=node_labels[:-1],
                    n_chosen=1,
                    return_mask=False,
                    key=f"in_{layer}",
                ),
            )
            setattr(
                self,
                f"op_{layer}",
                self.setLayerChoice(
                    layer,
                    [gnn_map(op, self.hidden_dim, self.hidden_dim) for op in gnn_list],
                    key=f"op_{layer}",
                ),
            )
            # setattr(self, f"act", self.setLayerChoice(2 * layer, [act_map_nn(a) for a in act_list], key=f"act"))
            # setattr(self, f"concat", self.setLayerChoice(2 * layer + 1, map_nn(["add", "product", "concat"]), key=f"concat"))
        self._initialized = True
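
    # Wiring summary: preproc0/preproc1 create nodes 0 and 1; each later node
    # picks exactly one predecessor through in_{layer} (an InputChoice) and one
    # GNN operator through op_{layer} (a LayerChoice over gnn_list), following
    # the GraphNAS-style macro search space.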

    def forward(self, data):
        # e.g., for Cora: x is [2708, 1433], edge_index is [2, 10556]
        x, edges = data.x, data.edge_index
        pprev_, prev_ = self.preproc0(x), self.preproc1(x)
        prev_nodes_out = [pprev_, prev_]
        for layer in range(2, self.layer_number + 2):
            node_in = getattr(self, f"in_{layer}")(prev_nodes_out)
            node_out = getattr(self, f"op_{layer}")(node_in, edges)
            prev_nodes_out.append(node_out)
        x = torch.cat(prev_nodes_out[2:], dim=1)
        x = F.leaky_relu(x)
        # x = F.dropout(x, p=self.dropout, training=self.training)
        if False:  # disabled: searched activation/aggregation (see the commented-out choices above)
            act = getattr(self, f"act")
            con = getattr(self, f"concat")()
            states = prev_nodes_out
            if con == "concat":
                x = torch.cat(states[2:], dim=1)
            else:
                tmp = states[2]
                for i in range(3, len(states)):  # states[2] is already in tmp
                    if con == "add":
                        tmp = torch.add(tmp, states[i])
                    elif con == "product":
                        tmp = torch.mul(tmp, states[i])
                x = tmp
            x = act(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(x, dim=1)

    def export(self, selection, device) -> BaseModel:
        # return AutoGCN(self.input_dim, self.output_dim, device)
        return FixedNodeClassificationModel(self, selection, device)
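
# Standalone sketch (illustrative, not part of the commit; the Cora dimensions
# below are assumptions):
#   space = GraphNasNodeClassificationSpace(hidden_dim=64, layer_number=2)
#   space._instantiate(input_dim=1433, output_dim=7)
#   # logits = space(data)  # data: a PyG Data object with .x and .edge_index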

examples/test_graph_nas.py (+39, -0)

@@ -0,0 +1,39 @@
import sys
sys.path.append('../')

import torch
from torch_geometric.nn import GCNConv

from autogl.datasets import build_dataset_from_name
from autogl.solver import AutoNodeClassifier
from autogl.module.train import NodeClassificationFullTrainer, Acc
from autogl.module.nas import Darts, OneShotEstimator
from autogl.module.nas.space.graph_nas import GraphNasNodeClassificationSpace
from autogl.module.nas.algorithm.enas import Enas

if __name__ == '__main__':
    dataset = build_dataset_from_name('cora')
    solver = AutoNodeClassifier(
        feature_module='PYGNormalizeFeatures',
        graph_models=[],
        hpo_module=None,
        ensemble_module=None,
        default_trainer=NodeClassificationFullTrainer(
            optimizer=torch.optim.Adam,
            lr=0.01,
            max_epoch=200,
            early_stopping_round=200,
            weight_decay=5e-4,
            device="auto",
            init=False,
            feval=['acc'],
            loss="nll_loss",
            lr_scheduler_type=None,
        ),
        nas_algorithms=[Enas(num_epochs=10)],
        # nas_algorithms=[Darts(num_epochs=200)],
        nas_spaces=[GraphNasNodeClassificationSpace(hidden_dim=16, ops=[GCNConv, GCNConv])],
        nas_estimators=[OneShotEstimator()],
    )
    solver.fit(dataset)
    solver.get_leaderboard().show()
    out = solver.predict_proba()
    print('acc on cora',
          Acc.evaluate(out, dataset[0].y[dataset[0].test_mask].detach().numpy()))
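
# Expected flow (a sketch of what this script does, assuming the AutoGL solver
# API shown above): ENAS searches GraphNasNodeClassificationSpace for 10 epochs
# with the one-shot estimator, the exported architecture is retrained by the
# full trainer, and test accuracy on Cora is printed.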
