|
- """
- Performance check of PYG (model + trainer + dataset)
- """
- import os
- import random
- import numpy as np
- from tqdm import tqdm
-
- import torch
- import torch.nn.functional as F
- from torch_geometric.datasets import Planetoid
- import torch_geometric.transforms as T
- from torch_geometric.nn import GCNConv, GATConv, SAGEConv
- import logging
-
- logging.basicConfig(level=logging.ERROR)
-
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network built from ``GCNConv`` layers.

    Args:
        num_features: Input width handed to the first ``GCNConv``.
        num_classes: Output width of the second ``GCNConv``.
        hidden_channels: Width of the hidden representation. Defaults to 16,
            the value that used to be hard-coded.
        dropout: Dropout probability applied between the two layers.
            Defaults to 0.5, the implicit default of ``F.dropout``.
    """

    def __init__(self, num_features, num_classes, hidden_channels=16, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        # Layers are created in this order on purpose: weight initialisation
        # consumes the global RNG, so the order affects seeded runs.
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        """Return log-softmax scores (over dim 1) for ``data``.

        Reads ``data.x``, ``data.edge_index`` and ``data.edge_attr``; the
        latter is forwarded to both convolutions as their edge weight.
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_weight))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)
-
class GAT(torch.nn.Module):
    """Two-layer graph attention network built from ``GATConv`` layers.

    Args:
        num_features: Input width handed to the first ``GATConv``.
        num_classes: Output width of the second ``GATConv``.
        hidden_channels: Per-head width of the first layer. Defaults to 8,
            the value that used to be hard-coded.
        heads: Number of attention heads in the first layer. Defaults to 8.
        dropout: Probability used both for the feature dropout in
            ``forward`` and for the ``dropout`` argument of each
            ``GATConv``. Defaults to 0.6.
    """

    def __init__(self, num_features, num_classes, hidden_channels=8, heads=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GATConv(num_features, hidden_channels, heads=heads, dropout=dropout)
        # The second layer takes hidden_channels * heads inputs (previously
        # the literal 8 * 8) and uses a single, non-concatenated head.
        self.conv2 = GATConv(hidden_channels * heads, num_classes, heads=1,
                             concat=False, dropout=dropout)

    def forward(self, data):
        """Return log-softmax scores (over the last dim) for ``data``.

        Reads ``data.x`` and ``data.edge_index``. Dropout is applied to the
        input features and again to the hidden representation.
        """
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=-1)
-
class SAGE(torch.nn.Module):
    """Stack of ``num_layers`` ``SAGEConv`` layers.

    The first layer takes ``num_features`` inputs, the last one produces
    ``num_classes`` outputs, and every other interface between layers has
    width ``hidden_channels``.
    """

    def __init__(self, num_features, hidden_channels, num_layers, num_classes):
        super().__init__()
        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()
        final = num_layers - 1
        for depth in range(num_layers):
            in_dim = num_features if depth == 0 else hidden_channels
            out_dim = num_classes if depth == final else hidden_channels
            self.convs.append(SAGEConv(in_dim, out_dim))

    def forward(self, data):
        """Return log-softmax scores (over the last dim) for ``data``.

        Reads ``data.x`` and ``data.edge_index``. ReLU followed by dropout
        (p=0.5) is applied after every layer except the last.
        """
        h = data.x
        edge_index = data.edge_index
        final = self.num_layers - 1
        for depth, layer in enumerate(self.convs):
            h = layer(h, edge_index)
            if depth == final:
                # The output layer feeds log_softmax directly.
                continue
            h = F.dropout(h.relu(), p=0.5, training=self.training)
        return F.log_softmax(h, dim=-1)
-
def test(model, data, mask):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    Switches ``model`` to eval mode (and leaves it there), runs a single
    forward pass on ``data`` and compares the arg-max prediction of every
    selected row with ``data.y``.

    Args:
        model: Callable module; ``model(data)`` must return per-row scores.
        data: Object exposing the targets as ``data.y``.
        mask: Boolean tensor selecting the rows to evaluate.

    Returns:
        Fraction of selected rows predicted correctly, as a Python float.

    Raises:
        ZeroDivisionError: If ``mask`` selects no rows.
    """
    model.eval()
    # Evaluation never backpropagates, so do not build an autograd graph.
    with torch.no_grad():
        pred = model(data)[mask].argmax(dim=1)
    correct = pred.eq(data.y[mask]).sum().item()
    return correct / mask.sum().item()


# The name matches pytest's default ``test*`` pattern; opt out so that test
# runners importing this module do not try to collect it as a test case.
test.__test__ = False
-
def train(model, data, args):
    """Train ``model`` full-batch and restore its best-validation weights.

    Runs ``args.epoch`` Adam steps on the NLL loss over ``data.train_mask``.
    After every step the validation accuracy (``data.val_mask``) is measured
    with ``test``; the weights of the epoch with the highest accuracy are
    loaded back into ``model`` before returning.

    Args:
        model: Module whose ``model(data)`` output is fed to ``F.nll_loss``.
        data: Object exposing ``y``, ``train_mask`` and ``val_mask``.
        args: Namespace providing ``lr``, ``weight_decay`` and ``epoch``.

    Returns:
        The same ``model`` instance, holding the selected weights.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    def snapshot():
        # state_dict() returns references to the live tensors, which the
        # optimizer keeps updating in place. Clone them so the snapshot
        # really freezes the weights of the current epoch.
        return {name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()}

    # Fallback if no epoch ever beats an accuracy of 0: the initial weights.
    best_state = snapshot()
    best_acc = 0.
    for _ in range(args.epoch):
        model.train()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # test() switches to eval mode; model.train() above switches back.
        val_acc = test(model, data, data.val_mask)
        if val_acc > best_acc:
            best_acc = val_acc
            best_state = snapshot()

    model.load_state_dict(best_state)
    return model
-
if __name__ == '__main__':

    import argparse

    # Command-line options; --lr, --weight_decay and --epoch are read by train().
    parser = argparse.ArgumentParser('pyg model')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
    parser.add_argument('--repeat', type=int, default=50)
    parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--epoch', type=int, default=200)

    args = parser.parse_args()

    # Load the dataset once; the same graph is reused for every repetition.
    root = os.path.expanduser('~/.cache-autogl')
    dataset = Planetoid(root, args.dataset, transform=T.NormalizeFeatures())
    data = dataset[0].to(args.device)

    # One factory per --model choice; each call builds a fresh model.
    builders = {
        'gat': lambda: GAT(dataset.num_node_features, dataset.num_classes),
        'gcn': lambda: GCN(dataset.num_node_features, dataset.num_classes),
        'sage': lambda: SAGE(dataset.num_node_features, 64, 2, dataset.num_classes),
    }
    build_model = builders[args.model]

    accs = []
    for seed in tqdm(range(args.repeat)):
        # Seed before constructing the model so its initialisation is
        # determined by the repetition index.
        np.random.seed(seed)
        torch.manual_seed(seed)

        model = build_model()
        model.to(args.device)

        train(model, data, args)
        accs.append(test(model, data, data.test_mask))

    # Mean and standard deviation of the test accuracy over all repetitions.
    print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))
|