From 9536982cf3aa472dcda8a5046e41e6a69b091af1 Mon Sep 17 00:00:00 2001 From: wondergo2017 Date: Thu, 19 Aug 2021 15:46:14 +0800 Subject: [PATCH] add con ops control for graphnasspace --- autogl/module/nas/space/graph_nas.py | 30 ++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/autogl/module/nas/space/graph_nas.py b/autogl/module/nas/space/graph_nas.py index e70ef2d..7bc32d7 100644 --- a/autogl/module/nas/space/graph_nas.py +++ b/autogl/module/nas/space/graph_nas.py @@ -39,6 +39,8 @@ GRAPHNAS_DEFAULT_ACT_OPS = [ "elu", ] +GRAPHNAS_DEFAULT_CON_OPS=["add", "product", "concat"] +# GRAPHNAS_DEFAULT_CON_OPS=[ "concat"] # for darts class LambdaModule(nn.Module): def __init__(self, lambd): @@ -83,6 +85,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): output_dim: _typ.Optional[int] = None, gnn_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = GRAPHNAS_DEFAULT_GNN_OPS, act_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = GRAPHNAS_DEFAULT_ACT_OPS, + con_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = GRAPHNAS_DEFAULT_CON_OPS ): super().__init__() self.layer_number = layer_number @@ -91,6 +94,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): self.output_dim = output_dim self.gnn_ops = gnn_ops self.act_ops = act_ops + self.con_ops = con_ops self.dropout = dropout def instantiate( @@ -102,6 +106,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): output_dim: _typ.Optional[int] = None, gnn_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = None, act_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = None, + con_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = None, ): super().instantiate() self.dropout = dropout or self.dropout @@ -111,6 +116,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): self.output_dim = output_dim or self.output_dim self.gnn_ops = gnn_ops or self.gnn_ops self.act_ops = act_ops or self.act_ops + self.con_ops = con_ops or self.con_ops self.preproc0 = nn.Linear(self.input_dim, self.hidden_dim) self.preproc1 = 
nn.Linear(self.input_dim, self.hidden_dim) node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY] @@ -146,13 +152,15 @@ class GraphNasNodeClassificationSpace(BaseSpace): 2 * layer, [act_map_nn(a) for a in self.act_ops], key="act" ), ) - setattr( - self, - "concat", - self.setLayerChoice( - 2 * layer + 1, map_nn(["add", "product", "concat"]), key="concat" - ), - ) + # for DARTS, len(con_ops) must be <= 1 due to dimension problems + if len(self.con_ops)>1: + setattr( + self, + "concat", + self.setLayerChoice( + 2 * layer + 1, map_nn(self.con_ops), key="concat" + ), + ) self._initialized = True self.classifier1 = nn.Linear( self.hidden_dim * self.layer_number, self.output_dim @@ -172,7 +180,13 @@ class GraphNasNodeClassificationSpace(BaseSpace): node_out = bk_gconv(op,data,node_in) prev_nodes_out.append(node_out) act = getattr(self, "act") - con = getattr(self, "concat")() + if len(self.con_ops)>1: + con = getattr(self, "concat")() + elif len(self.con_ops)==1: + con=self.con_ops[0] + else: + con="concat" + states = prev_nodes_out if con == "concat": x = torch.cat(states[2:], dim=1)