| @@ -39,6 +39,8 @@ GRAPHNAS_DEFAULT_ACT_OPS = [ | |||
| "elu", | |||
| ] | |||
| GRAPHNAS_DEFAULT_CON_OPS=["add", "product", "concat"] | |||
| # GRAPHNAS_DEFAULT_CON_OPS=[ "concat"] # for darts | |||
| class LambdaModule(nn.Module): | |||
| def __init__(self, lambd): | |||
| @@ -83,6 +85,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): | |||
| output_dim: _typ.Optional[int] = None, | |||
| gnn_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = GRAPHNAS_DEFAULT_GNN_OPS, | |||
| act_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = GRAPHNAS_DEFAULT_ACT_OPS, | |||
| con_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = GRAPHNAS_DEFAULT_CON_OPS | |||
| ): | |||
| super().__init__() | |||
| self.layer_number = layer_number | |||
| @@ -91,6 +94,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): | |||
| self.output_dim = output_dim | |||
| self.gnn_ops = gnn_ops | |||
| self.act_ops = act_ops | |||
| self.con_ops = con_ops | |||
| self.dropout = dropout | |||
| def instantiate( | |||
| @@ -102,6 +106,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): | |||
| output_dim: _typ.Optional[int] = None, | |||
| gnn_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = None, | |||
| act_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = None, | |||
| con_ops: _typ.Sequence[_typ.Union[str, _typ.Any]] = None, | |||
| ): | |||
| super().instantiate() | |||
| self.dropout = dropout or self.dropout | |||
| @@ -111,6 +116,7 @@ class GraphNasNodeClassificationSpace(BaseSpace): | |||
| self.output_dim = output_dim or self.output_dim | |||
| self.gnn_ops = gnn_ops or self.gnn_ops | |||
| self.act_ops = act_ops or self.act_ops | |||
| self.con_ops = con_ops or self.con_ops | |||
| self.preproc0 = nn.Linear(self.input_dim, self.hidden_dim) | |||
| self.preproc1 = nn.Linear(self.input_dim, self.hidden_dim) | |||
| node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY] | |||
| @@ -146,13 +152,15 @@ class GraphNasNodeClassificationSpace(BaseSpace): | |||
| 2 * layer, [act_map_nn(a) for a in self.act_ops], key="act" | |||
| ), | |||
| ) | |||
| setattr( | |||
| self, | |||
| "concat", | |||
| self.setLayerChoice( | |||
| 2 * layer + 1, map_nn(["add", "product", "concat"]), key="concat" | |||
| ), | |||
| ) | |||
| # for DARTS, len(con_ops) can only <=1, for dimension problems | |||
| if len(self.con_ops)>1: | |||
| setattr( | |||
| self, | |||
| "concat", | |||
| self.setLayerChoice( | |||
| 2 * layer + 1, map_nn(self.con_ops), key="concat" | |||
| ), | |||
| ) | |||
| self._initialized = True | |||
| self.classifier1 = nn.Linear( | |||
| self.hidden_dim * self.layer_number, self.output_dim | |||
| @@ -172,7 +180,13 @@ class GraphNasNodeClassificationSpace(BaseSpace): | |||
| node_out = bk_gconv(op,data,node_in) | |||
| prev_nodes_out.append(node_out) | |||
| act = getattr(self, "act") | |||
| con = getattr(self, "concat")() | |||
| if len(self.con_ops)>1: | |||
| con = getattr(self, "concat")() | |||
| elif len(self.con_ops)==1: | |||
| con=self.con_ops[0] | |||
| else: | |||
| con="concat" | |||
| states = prev_nodes_out | |||
| if con == "concat": | |||
| x = torch.cat(states[2:], dim=1) | |||