From 0bb9909b020931777abab3e4e896bd83f585e017 Mon Sep 17 00:00:00 2001 From: Frozenmad Date: Tue, 28 Dec 2021 16:57:35 +0800 Subject: [PATCH] adjust performance --- .../pyg/link_prediction_base.py | 24 +++------------ .../pyg/link_prediction_model.py | 29 +++---------------- 2 files changed, 8 insertions(+), 45 deletions(-) diff --git a/test/performance/link_prediction/pyg/link_prediction_base.py b/test/performance/link_prediction/pyg/link_prediction_base.py index 8435272..01211d8 100644 --- a/test/performance/link_prediction/pyg/link_prediction_base.py +++ b/test/performance/link_prediction/pyg/link_prediction_base.py @@ -124,22 +124,7 @@ args = parser.parse_args() args.device = torch.device('cuda:0') device = torch.device('cuda:0') -args.dataset = 'Cora' -args.model = 'gcn' -print(args.dataset) -print(args.model) -# load the dataset - -# path = osp.join('.', 'data', args.dataset) -path = osp.join('data', args.dataset) -if args.dataset == 'Cora': - dataset = Planetoid(path, name='Cora',transform=T.NormalizeFeatures()) -elif args.dataset == 'CiteSeer': - dataset = Planetoid(path, name='CiteSeer',transform=T.NormalizeFeatures()) -elif args.dataset == 'PubMed': - dataset = Planetoid(path, name='PubMed',transform=T.NormalizeFeatures()) -else: - assert False +dataset = Planetoid(osp.expanduser('~/.cache-autogl'), args.dataset, transform=T.NormalizeFeatures()) def train(): model.train() @@ -173,13 +158,13 @@ def test(): model.eval() perfs = [] for prefix in ["val", "test"]: - print(prefix) + # print(prefix) pos_edge_index = data[f'{prefix}_pos_edge_index'] neg_edge_index = data[f'{prefix}_neg_edge_index'] z = model.encode(data) # encode train - print("testen_shape",data.x.shape, data.train_pos_edge_index.shape) - print("testde_shape",z.shape, data.train_pos_edge_index.shape,neg_edge_index.shape) + # print("testen_shape",data.x.shape, data.train_pos_edge_index.shape) + # print("testde_shape",z.shape, data.train_pos_edge_index.shape,neg_edge_index.shape) # val # testen_shape torch.Size([2708, 1433]) torch.Size([2, 8976]) # testde_shape torch.Size([2708, 64]) torch.Size([2, 8976]) torch.Size([2, 263]) @@ -204,7 +189,6 @@ for seed in tqdm(range(1234, 1234+args.repeat)): if args.model == 'gcn': model = GCN(dataset.num_features, 128).to(device) - print(model) elif args.model == 'gat': model = GAT(dataset.num_features, 128).to(device) elif args.model == 'sage': diff --git a/test/performance/link_prediction/pyg/link_prediction_model.py b/test/performance/link_prediction/pyg/link_prediction_model.py index 40b7eaa..604d594 100644 --- a/test/performance/link_prediction/pyg/link_prediction_model.py +++ b/test/performance/link_prediction/pyg/link_prediction_model.py @@ -49,22 +49,7 @@ args = parser.parse_args() args.device = torch.device('cuda:0') device = torch.device('cuda:0') -args.dataset = 'Cora' -args.model = 'gcn' -print(args.dataset) -print(args.model) -# load the dataset - -# path = osp.join('.', 'data', args.dataset) -path = osp.join('data', args.dataset) -if args.dataset == 'Cora': - dataset = Planetoid(path, name='Cora',transform=T.NormalizeFeatures()) -elif args.dataset == 'CiteSeer': - dataset = Planetoid(path, name='CiteSeer',transform=T.NormalizeFeatures()) -elif args.dataset == 'PubMed': - dataset = Planetoid(path, name='PubMed',transform=T.NormalizeFeatures()) -else: - assert False +dataset = Planetoid(osp.expanduser('~/.cache-autogl'), args.dataset, transform=T.NormalizeFeatures()) def train(data): model.train() @@ -125,7 +110,7 @@ for seed in tqdm(range(1234, 1234+args.repeat)): data.train_mask = data.val_mask = data.test_mask = data.y = None data = train_test_split_edges(data).to(device) if args.model == 'gcn': - model = AutoGCN(dataset=dataset, + model = AutoGCN( num_features=dataset.num_features, num_classes=2, # num_class对linkpre任务似乎没有用? device=args.device, @@ -134,10 +119,7 @@ for seed in tqdm(range(1234, 1234+args.repeat)): 'num_layers': 3, 'hidden': [128,64], 'dropout': 0.0, - 'act': 'relu', # 对linkpre任务似乎没有用? - 'agg': 'mean', - 'add_self_loops': 'False', - 'normalize': 'False', + 'act': '' }).model elif args.model == 'gat': model = AutoGAT(dataset=dataset, @@ -149,10 +131,7 @@ for seed in tqdm(range(1234, 1234+args.repeat)): 'num_layers': 3, 'hidden': [128,64], 'dropout': 0.0, - 'act': 'relu', - 'agg': 'mean', - 'add_self_loops': 'False', - 'normalize': 'False', + 'act': 'relu' }).model elif args.model == 'sage': model = AutoSAGE(dataset=dataset,