From f9b08701bc3fb02490eaa0aa405aaff4d1b28afa Mon Sep 17 00:00:00 2001 From: Beini Date: Fri, 31 Dec 2021 02:55:23 +0000 Subject: [PATCH] hgt hidden --- autogl/module/model/dgl/hetero/hgt.py | 14 +++++++++----- test/performance/heterogeneous/dgl/hgt_main.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/autogl/module/model/dgl/hetero/hgt.py b/autogl/module/model/dgl/hetero/hgt.py index b222e6a..72f8806 100644 --- a/autogl/module/model/dgl/hetero/hgt.py +++ b/autogl/module/model/dgl/hetero/hgt.py @@ -157,16 +157,20 @@ class HGT(nn.Module): if not self.num_layers == len(self.args["hidden"]): LOGGER.warn("layer size {} does not match the length of hidden units {}".format(self.num_layers, len(self.args["hidden"]))) + + hidden_size = self.args["hidden"][0] + hidden_size = hidden_size//self.args["heads"]*self.args["heads"] + LOGGER.warn('only use the first hidden size={} (divided exactly the number of heads) for all HGT layers'.format(hidden_size)) self.adapt_ws = nn.ModuleList() for t in range(len(self.node_dict)): - self.adapt_ws.append(nn.Linear(self.args["features_num"], self.args["hidden"][0])) + self.adapt_ws.append(nn.Linear(self.args["features_num"], hidden_size)) - for i in range(1, self.num_layers): - self.gcs.append(HGTLayer(self.args["hidden"][i - 1], self.args["hidden"][i], self.node_dict, self.edge_dict, \ + for i in range(self.num_layers): + self.gcs.append(HGTLayer(hidden_size, hidden_size, self.node_dict, self.edge_dict, \ self.args["heads"], use_norm = self.args["use_norm"], dropout = self.args["dropout"])) - self.out = nn.Linear(self.args["hidden"][-1], self.args["num_class"]) + self.out = nn.Linear(hidden_size, self.args["num_class"]) def forward(self, G): h = {} @@ -250,7 +254,7 @@ class AutoHGT(BaseHeteroModelMaintainer): self.hyper_parameters = { "num_layers": 2, - "hidden": [256], + "hidden": [256,256,256], "heads": 4, "dropout": 0.2, "act": "gelu", diff --git a/test/performance/heterogeneous/dgl/hgt_main.py b/test/performance/heterogeneous/dgl/hgt_main.py index 2e6f458..ff48740 100644 --- a/test/performance/heterogeneous/dgl/hgt_main.py +++ b/test/performance/heterogeneous/dgl/hgt_main.py @@ -83,7 +83,7 @@ if __name__=='__main__': init=False ).from_hyper_parameter({ "num_layers": 2, - "hidden": [256,256,256], + "hidden": [256,256], "heads": 4, "dropout": 0.2, "act": "gelu",