Browse Source

add con ops control for graphnasspace

tags/v0.3.1
wondergo2017 4 years ago
parent
commit
9536982cf3
1 changed files with 22 additions and 8 deletions
  1. +22
    -8
      autogl/module/nas/space/graph_nas.py

+ 22
- 8
autogl/module/nas/space/graph_nas.py View File

@@ -39,6 +39,8 @@ GRAPHNAS_DEFAULT_ACT_OPS = [
"elu",
]

GRAPHNAS_DEFAULT_CON_OPS=["add", "product", "concat"]
# GRAPHNAS_DEFAULT_CON_OPS=[ "concat"] # for darts

class LambdaModule(nn.Module):
def __init__(self, lambd):
@@ -83,6 +85,7 @@ class GraphNasNodeClassificationSpace(BaseSpace):
output_dim: _typ.Optional[int] = None,
gnn_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = GRAPHNAS_DEFAULT_GNN_OPS,
act_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = GRAPHNAS_DEFAULT_ACT_OPS,
con_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = GRAPHNAS_DEFAULT_CON_OPS
):
super().__init__()
self.layer_number = layer_number
@@ -91,6 +94,7 @@ class GraphNasNodeClassificationSpace(BaseSpace):
self.output_dim = output_dim
self.gnn_ops = gnn_ops
self.act_ops = act_ops
self.con_ops = con_ops
self.dropout = dropout

def instantiate(
@@ -102,6 +106,7 @@ class GraphNasNodeClassificationSpace(BaseSpace):
output_dim: _typ.Optional[int] = None,
gnn_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = None,
act_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = None,
con_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = None,
):
super().instantiate()
self.dropout = dropout or self.dropout
@@ -111,6 +116,7 @@ class GraphNasNodeClassificationSpace(BaseSpace):
self.output_dim = output_dim or self.output_dim
self.gnn_ops = gnn_ops or self.gnn_ops
self.act_ops = act_ops or self.act_ops
self.con_ops = con_ops or self.con_ops
self.preproc0 = nn.Linear(self.input_dim, self.hidden_dim)
self.preproc1 = nn.Linear(self.input_dim, self.hidden_dim)
node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY]
@@ -146,13 +152,15 @@ class GraphNasNodeClassificationSpace(BaseSpace):
2 * layer, [act_map_nn(a) for a in self.act_ops], key="act"
),
)
setattr(
self,
"concat",
self.setLayerChoice(
2 * layer + 1, map_nn(["add", "product", "concat"]), key="concat"
),
)
# for DARTS, len(con_ops) can only <=1, for dimension problems
if len(self.con_ops)>1:
setattr(
self,
"concat",
self.setLayerChoice(
2 * layer + 1, map_nn(self.con_ops), key="concat"
),
)
self._initialized = True
self.classifier1 = nn.Linear(
self.hidden_dim * self.layer_number, self.output_dim
@@ -172,7 +180,13 @@ class GraphNasNodeClassificationSpace(BaseSpace):
node_out = bk_gconv(op,data,node_in)
prev_nodes_out.append(node_out)
act = getattr(self, "act")
con = getattr(self, "concat")()
if len(self.con_ops)>1:
con = getattr(self, "concat")()
elif len(self.con_ops)==1:
con=self.con_ops[0]
else:
con="concat"

states = prev_nodes_out
if con == "concat":
x = torch.cat(states[2:], dim=1)


Loading…
Cancel
Save