|
|
|
@@ -124,22 +124,7 @@ args = parser.parse_args() |
|
|
|
args.device = torch.device('cuda:0') |
|
|
|
device = torch.device('cuda:0') |
|
|
|
|
|
|
|
args.dataset = 'Cora' |
|
|
|
args.model = 'gcn' |
|
|
|
print(args.dataset) |
|
|
|
print(args.model) |
|
|
|
# load the dataset |
|
|
|
|
|
|
|
# path = osp.join('.', 'data', args.dataset) |
|
|
|
path = osp.join('data', args.dataset) |
|
|
|
if args.dataset == 'Cora': |
|
|
|
dataset = Planetoid(path, name='Cora',transform=T.NormalizeFeatures()) |
|
|
|
elif args.dataset == 'CiteSeer': |
|
|
|
dataset = Planetoid(path, name='CiteSeer',transform=T.NormalizeFeatures()) |
|
|
|
elif args.dataset == 'PubMed': |
|
|
|
dataset = Planetoid(path, name='PubMed',transform=T.NormalizeFeatures()) |
|
|
|
else: |
|
|
|
assert False |
|
|
|
dataset = Planetoid(osp.expanduser('~/.cache-autogl'), args.dataset, transform=T.NormalizeFeatures()) |
|
|
|
|
|
|
|
def train(): |
|
|
|
model.train() |
|
|
|
@@ -173,13 +158,13 @@ def test(): |
|
|
|
model.eval() |
|
|
|
perfs = [] |
|
|
|
for prefix in ["val", "test"]: |
|
|
|
print(prefix) |
|
|
|
# print(prefix) |
|
|
|
pos_edge_index = data[f'{prefix}_pos_edge_index'] |
|
|
|
neg_edge_index = data[f'{prefix}_neg_edge_index'] |
|
|
|
|
|
|
|
z = model.encode(data) # encode train |
|
|
|
print("testen_shape",data.x.shape, data.train_pos_edge_index.shape) |
|
|
|
print("testde_shape",z.shape, data.train_pos_edge_index.shape,neg_edge_index.shape) |
|
|
|
# print("testen_shape",data.x.shape, data.train_pos_edge_index.shape) |
|
|
|
# print("testde_shape",z.shape, data.train_pos_edge_index.shape,neg_edge_index.shape) |
|
|
|
# val |
|
|
|
# testen_shape torch.Size([2708, 1433]) torch.Size([2, 8976]) |
|
|
|
# testde_shape torch.Size([2708, 64]) torch.Size([2, 8976]) torch.Size([2, 263]) |
|
|
|
@@ -204,7 +189,6 @@ for seed in tqdm(range(1234, 1234+args.repeat)): |
|
|
|
|
|
|
|
if args.model == 'gcn': |
|
|
|
model = GCN(dataset.num_features, 128).to(device) |
|
|
|
print(model) |
|
|
|
elif args.model == 'gat': |
|
|
|
model = GAT(dataset.num_features, 128).to(device) |
|
|
|
elif args.model == 'sage': |
|
|
|
|