""" Performance check of PYG (model + trainer + dataset) """ import os import random import numpy as np from tqdm import tqdm import torch import torch.nn.functional as F from torch.nn import Sequential, Linear, ReLU import torch_geometric from torch_geometric.datasets import TUDataset if int(torch_geometric.__version__.split(".")[0]) >= 2: from torch_geometric.loader import DataLoader else: from torch_geometric.data import DataLoader from torch_geometric.nn import GINConv, global_add_pool, GraphConv, TopKPooling from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp import logging torch.backends.cudnn.deterministic = True #torch.use_deterministic_algorithms(True) logging.basicConfig(level=logging.ERROR) class GIN(torch.nn.Module): def __init__(self): super(GIN, self).__init__() num_features = dataset.num_features dim = 32 nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim)) self.conv1 = GINConv(nn1) self.bn1 = torch.nn.BatchNorm1d(dim) nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) self.conv2 = GINConv(nn2) self.bn2 = torch.nn.BatchNorm1d(dim) nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) self.conv3 = GINConv(nn3) self.bn3 = torch.nn.BatchNorm1d(dim) nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) self.conv4 = GINConv(nn4) self.bn4 = torch.nn.BatchNorm1d(dim) nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) self.conv5 = GINConv(nn5) self.bn5 = torch.nn.BatchNorm1d(dim) self.fc1 = Linear(dim, dim) self.fc2 = Linear(dim, dataset.num_classes) def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) x = self.bn1(x) x = F.relu(self.conv2(x, edge_index)) x = self.bn2(x) x = F.relu(self.conv3(x, edge_index)) x = self.bn3(x) x = F.relu(self.conv4(x, edge_index)) x = self.bn4(x) x = F.relu(self.conv5(x, edge_index)) x = self.bn5(x) x = global_add_pool(x, batch) x = F.relu(self.fc1(x)) x = F.dropout(x, p=0.5, training=self.training) x = self.fc2(x) return F.log_softmax(x, dim=-1) class TopKPool(torch.nn.Module): def __init__(self): super(TopKPool, self).__init__() self.conv1 = GraphConv(dataset.num_features, 128) self.pool1 = TopKPooling(128, ratio=0.8) self.conv2 = GraphConv(128, 128) self.pool2 = TopKPooling(128, ratio=0.8) self.conv3 = GraphConv(128, 128) self.pool3 = TopKPooling(128, ratio=0.8) self.lin1 = torch.nn.Linear(256, 128) self.lin2 = torch.nn.Linear(128, 64) self.lin3 = torch.nn.Linear(64, dataset.num_classes) def forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch x = F.relu(self.conv1(x, edge_index)) x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch) x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) x = F.relu(self.conv2(x, edge_index)) x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch) x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) x = F.relu(self.conv3(x, edge_index)) x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch) x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) x = x1 + x2 + x3 x = F.relu(self.lin1(x)) x = F.dropout(x, p=0.5, training=self.training) x = F.relu(self.lin2(x)) x = F.log_softmax(self.lin3(x), dim=-1) return x def test(model, loader, args): model.eval() correct = 0 for data in loader: data = data.to(args.device) output = model(data) pred = output.max(dim=1)[1] correct += pred.eq(data.y).sum().item() return correct / len(loader.dataset) def train(model, train_loader, val_loader, args): optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) parameters = model.state_dict() best_acc = 0. for epoch in range(args.epoch): model.train() for data in train_loader: data = data.to(args.device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, data.y) loss.backward() optimizer.step() val_acc = test(model, val_loader, args) if val_acc > best_acc: best_acc = val_acc parameters = model.state_dict() model.load_state_dict(parameters) return model if __name__ == '__main__': import argparse parser = argparse.ArgumentParser('pyg trainer') parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--dataset', type=str, choices=['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1', 'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K'], default='MUTAG') parser.add_argument('--dataset_seed', type=int, default=2021) parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--repeat', type=int, default=50) parser.add_argument('--model', type=str, choices=['gin', 'topkpool'], default='gin') parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument('--epoch', type=int, default=100) args = parser.parse_args() # seed = 100 dataset = TUDataset(os.path.expanduser('~/.pyg'), args.dataset) # 1. split dataset [fix split] dataids = list(range(len(dataset))) random.seed(args.dataset_seed) random.shuffle(dataids) torch.manual_seed(args.dataset_seed) np.random.seed(args.dataset_seed) if args.device == 'cuda': torch.cuda.manual_seed(args.dataset_seed) fold = int(len(dataset) * 0.1) train_index = dataids[:fold * 8] val_index = dataids[fold * 8: fold * 9] test_index = dataids[fold * 9: ] dataset.train_index = train_index dataset.val_index = val_index dataset.test_index = test_index dataset.train_split = dataset[dataset.train_index] dataset.val_split = dataset[dataset.val_index] dataset.test_split = dataset[dataset.test_index] labels = np.array([data.y.item() for data in dataset.test_split]) def seed_worker(worker_id): #seed = torch.initial_seed() torch.manual_seed(args.dataset_seed) np.random.seed(args.dataset_seed) random.seed(args.dataset_seed) g = torch.Generator() g.manual_seed(args.dataset_seed) train_loader = DataLoader(dataset.train_split, batch_size=args.batch_size, worker_init_fn=seed_worker, generator=g) val_loader = DataLoader(dataset.val_split, batch_size=args.batch_size, worker_init_fn=seed_worker, generator=g) test_loader = DataLoader(dataset.test_split, batch_size=args.batch_size, worker_init_fn=seed_worker, generator=g) #train_loader = DataLoader(dataset.train_split, batch_size=args.batch_size, shuffle=False) #val_loader = DataLoader(dataset.val_split, batch_size=args.batch_size, shuffle=False) #test_loader = DataLoader(dataset.test_split, batch_size=args.batch_size, shuffle=False) accs = [] for seed in tqdm(range(args.repeat)): torch.manual_seed(seed) np.random.seed(seed) #random.seed(seed) if args.device == 'cuda': torch.cuda.manual_seed(seed) if args.model == 'gin': model = GIN() elif args.model == 'topkpool': model = TopKPool() model.to(args.device) train(model, train_loader, val_loader, args) acc = test(model, test_loader, args) accs.append(acc) print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))