|
- """
- Performance check of DGL original dataset, model, trainer setting
-
- Borrowed from DGL official examples: https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin
-
- TopkPool is not supported currently
- """
-
- # from dgl.dataloading.pytorch.dataloader import GraphDataLoader
- import pickle
- from dgl.dataloading import GraphDataLoader
- import numpy as np
- from tqdm import tqdm
-
- import random
-
- import torch
- import torch.nn as nn
- import torch.optim as optim
-
- from dgl.data import GINDataset
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from dgl.nn.pytorch.conv import GINConv
- from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
-
def set_seed(seed=None):
    """Seed the Python, NumPy and PyTorch RNGs for a reproducible run.

    Parameters
    ----------
    seed: int or None
        Seed to apply. If None, one is drawn with ``random.randint(0, 5000)``.

    Returns
    -------
    int
        The seed actually applied. Returning it lets a randomly drawn seed
        be logged so the run can be reproduced later.
    """
    if seed is None:
        seed = random.randint(0, 5000)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned (possibly
    # non-deterministic) ones; these flags are inert without CUDA.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return seed
-
class DatasetAbstraction:
    """Minimal list-backed graph-classification dataset.

    Holds parallel lists of graphs and labels. It supports ``len()``,
    integer indexing (returns a ``(graph, label)`` pair) and subsetting.
    A subset is selected with an index list, an index tensor or a mask and
    is returned as a new ``DatasetAbstraction``.
    """

    def __init__(self, graphs, labels):
        """
        Parameters
        ----------
        graphs: iterable
            Graph objects whose ``ndata`` carries node features under
            ``'attr'``. That entry is aliased to ``'feat'`` in place, so the
            input graphs are mutated.
        labels: iterable
            Per-graph integer class labels (Python ints or 0-dim tensors),
            paired with ``graphs`` position by position.
        """
        # Materialise first so generator inputs are not exhausted by the
        # aliasing pass before they are stored.
        graphs = list(graphs)
        labels = list(labels)
        for g in graphs:
            g.ndata['feat'] = g.ndata['attr']
        self.graphs, self.labels = [], []
        for g, l in zip(graphs, labels):
            self.graphs.append(g)
            self.labels.append(l)
        # Number of classes inferred from the largest label; 0 for an empty
        # subset (the previous max() over an empty list raised ValueError).
        self.gclasses = int(max(self.labels)) + 1 if self.labels else 0
        self.graph = self.graphs  # alias kept for callers using `.graph`

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return one ``(graph, label)`` pair, or a subset dataset.

        Parameters
        ----------
        idx: int, numpy integer, 0-dim tensor, list of ints, or 1-D tensor
            Scalars select a single item. A bool tensor is a mask. An integer
            tensor is treated as a 0/1 mask only when its length equals
            ``len(self)`` and all of its values are 0 or 1; otherwise it is
            a list of indices. Any other sequence is a list of indices.
        """
        if isinstance(idx, (int, np.integer)):
            return self.graphs[idx], self.labels[idx]
        if isinstance(idx, torch.Tensor):
            if idx.dim() == 0:
                position = int(idx)
                return self.graphs[position], self.labels[position]
            values = idx.tolist()
            # The old heuristic (`unique()[0] == 1`) only tested the minimum
            # value, so index tensors like [1, 3] were misread as masks.
            is_mask = idx.dtype == torch.bool or (
                len(values) == len(self) and set(values) <= {0, 1})
            if is_mask:
                idx = [i for i, keep in enumerate(values) if keep]
            else:
                idx = values
        return DatasetAbstraction([self.graphs[i] for i in idx],
                                  [self.labels[i] for i in idx])
-
class ApplyNodeFunc(nn.Module):
    """Node-update function for a GIN layer: MLP, then BatchNorm, then ReLU.

    The wrapped ``mlp`` must expose an ``output_dim`` attribute; it sizes
    the BatchNorm layer.
    """

    def __init__(self, mlp):
        super().__init__()
        self.mlp = mlp
        self.bn = nn.BatchNorm1d(mlp.output_dim)

    def forward(self, h):
        """Return ``relu(bn(mlp(h)))``."""
        return F.relu(self.bn(self.mlp(h)))
-
-
class MLP(nn.Module):
    """Multi-layer perceptron whose final layer is linear (no activation).

    With ``num_layers == 1`` it is a single ``nn.Linear``. Otherwise each
    hidden layer is Linear -> BatchNorm1d -> ReLU, followed by a plain
    Linear output layer.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """Build the layers.

        Parameters
        ----------
        num_layers: int
            Number of linear layers; must be at least 1.
        input_dim: int
            Dimensionality of the input features.
        hidden_dim: int
            Dimensionality of every hidden layer.
        output_dim: int
            Dimensionality of the output (number of classes, or the hidden
            size when used inside a GIN layer).

        Raises
        ------
        ValueError
            If ``num_layers`` is smaller than 1.
        """
        super().__init__()
        self.linear_or_not = True  # flipped to False for the multi-layer case
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        if num_layers == 1:
            self.linear = nn.Linear(input_dim, output_dim)
            return

        self.linear_or_not = False
        # Widths: input -> hidden (num_layers - 1 times) -> output.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = torch.nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(widths, widths[1:])])
        self.batch_norms = torch.nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)])

    def forward(self, x):
        """Apply the perceptron to ``x``."""
        if self.linear_or_not:
            return self.linear(x)
        h = x
        # zip stops at the shorter list, i.e. before the output layer.
        for linear, norm in zip(self.linears, self.batch_norms):
            h = F.relu(norm(linear(h)))
        return self.linears[-1](h)
-
-
class GIN(nn.Module):
    """Graph Isomorphism Network for graph classification.

    Stacks ``num_layers - 1`` GINConv layers. Each layer's node
    representation (including the raw input) is pooled into a per-graph
    vector, and the per-layer linear predictions are summed.
    """

    def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim,
                 output_dim, final_dropout, learn_eps, graph_pooling_type,
                 neighbor_pooling_type):
        """Build the network.

        Parameters
        ----------
        num_layers: int
            Number of layers including the input layer; ``num_layers - 1``
            GINConv layers and ``num_layers`` prediction heads are built.
        num_mlp_layers: int
            Number of linear layers in each GIN layer's MLP.
        input_dim: int
            Dimensionality of the input node features.
        hidden_dim: int
            Dimensionality of hidden units at all layers.
        output_dim: int
            Number of classes for prediction.
        final_dropout: float
            Dropout ratio applied to each prediction head's output.
        learn_eps: bool
            If True, learn epsilon to distinguish center nodes from
            neighbors; if False, aggregate neighbors and center nodes
            together.
        graph_pooling_type: str
            How to pool all nodes of a graph: 'sum', 'mean' or 'max'.
        neighbor_pooling_type: str
            How to aggregate neighbors (sum, mean, or max).

        Raises
        ------
        NotImplementedError
            If ``graph_pooling_type`` is not one of the supported names.
        """
        super().__init__()
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for depth in range(num_layers - 1):
            mlp_in = input_dim if depth == 0 else hidden_dim
            mlp = MLP(num_mlp_layers, mlp_in, hidden_dim, hidden_dim)
            # Positional args after the aggregator: init_eps=0, learn_eps.
            self.ginlayers.append(
                GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # One prediction head per representation (input + each GIN layer),
        # mapping its pooled graph vector to class scores.
        self.linears_prediction = torch.nn.ModuleList(
            [nn.Linear(input_dim if depth == 0 else hidden_dim, output_dim)
             for depth in range(num_layers)])

        self.drop = nn.Dropout(final_dropout)

        for name, pool_cls in (('sum', SumPooling),
                               ('mean', AvgPooling),
                               ('max', MaxPooling)):
            if graph_pooling_type == name:
                self.pool = pool_cls()
                break
        else:
            raise NotImplementedError

    def forward(self, g, h):
        """Return the sum over layers of the dropped-out per-layer scores.

        ``h`` holds the node features of (batched) graph ``g``.
        """
        reps = [h]
        for conv, norm in zip(self.ginlayers, self.batch_norms):
            h = F.relu(norm(conv(g, h)))
            reps.append(h)
        # Pool every representation per graph, score it, and accumulate.
        return sum(self.drop(head(self.pool(g, rep)))
                   for head, rep in zip(self.linears_prediction, reps))
-
-
def train(net, trainloader, validloader, optimizer, criterion, epoch, device):
    """Train ``net`` and return it loaded with its best validation weights.

    Parameters
    ----------
    net: nn.Module
        Model called as ``net(graphs, feat)``.
    trainloader, validloader: iterable of (graphs, labels)
        Batches whose node features are read from, and removed from,
        ``graphs.ndata['attr']``.
    optimizer: torch.optim.Optimizer
        Optimizer over ``net``'s parameters.
    criterion: callable
        Loss, called as ``criterion(outputs, labels)``.
    epoch: int
        Number of passes over ``trainloader``.
    device: torch.device or str
        Device the batches are moved to.

    Returns
    -------
    nn.Module
        The same ``net`` object, holding the weights of the epoch with the
        highest validation accuracy. It keeps the initial weights if no
        epoch scores above 0, and the last epoch's weights if
        ``validloader`` yields nothing.
    """
    # Serialised snapshot of the best weights. Pickling copies the tensors,
    # so later optimizer steps do not alter it. Only data produced right
    # here is unpickled, so pickle is safe in this context.
    best_model = pickle.dumps(net.state_dict())

    best_acc = 0.
    for _ in range(epoch):
        net.train()
        for graphs, labels in trainloader:
            labels = labels.to(device)
            graphs = graphs.to(device)
            feat = graphs.ndata.pop('attr')
            outputs = net(graphs, feat)

            loss = criterion(outputs, labels)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc = _validation_accuracy(net, validloader, device)
        if acc is None:
            # No validation data to select on: keep the latest weights
            # (previously torch.cat on an empty list raised here).
            best_model = pickle.dumps(net.state_dict())
        elif acc > best_acc:
            best_acc = acc
            best_model = pickle.dumps(net.state_dict())

    net.load_state_dict(pickle.loads(best_model))

    return net


def _validation_accuracy(net, loader, device):
    """Return ``net``'s accuracy on ``loader``, or None if it is empty; leaves ``net`` in eval mode."""
    net.eval()
    correct = 0
    total = 0
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for graphs, labels in loader:
            labels = labels.to(device)
            graphs = graphs.to(device)
            feat = graphs.ndata.pop('attr')
            predicted = net(graphs, feat).argmax(1)
            correct += (predicted == labels).sum().item()
            total += labels.numel()
    return correct / total if total else None
-
def eval_net(net, dataloader, device):
    """Return the classification accuracy of ``net`` over ``dataloader``.

    Parameters
    ----------
    net: nn.Module
        Model called as ``net(graphs, feat)``; returns per-class scores.
    dataloader: iterable of (graphs, labels)
        Batches whose node features are read from, and removed from,
        ``graphs.ndata['attr']``.
    device: torch.device or str
        Device the batches are moved to.

    Returns
    -------
    float
        Fraction of correctly classified samples. ``net`` is left in
        training mode.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no samples (accuracy is undefined).
    """
    net.eval()

    total = 0
    total_correct = 0

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for graphs, labels in dataloader:
            graphs = graphs.to(device)
            labels = labels.to(device)
            feat = graphs.ndata.pop('attr')
            total += len(labels)
            predicted = net(graphs, feat).argmax(1)
            total_correct += (predicted == labels).sum().item()

    net.train()

    if total == 0:
        raise ValueError('eval_net: dataloader yielded no samples')
    return 1.0 * total_correct / total
-
-
def main():
    """Benchmark GIN on a GIN/TU dataset with a fixed 80/10/10 split.

    Trains ``--repeat`` independently seeded models (seeds 0..repeat-1).
    Each model is selected on validation accuracy, and the test accuracy is
    printed as ``mean ~ std`` in percent.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--repeat", type=int, default=10)
    parser.add_argument('--dataset', type=str, choices=['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1', 'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K'], default='MUTAG')
    parser.add_argument('--epochs', type=int, default=100,
                        help='training epochs per repeat')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Adam learning rate')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='graphs per mini-batch')
    parser.add_argument('--device', type=str, default=None,
                        help="torch device, e.g. 'cuda' or 'cpu' "
                             "(default: cuda if available, else cpu)")

    args = parser.parse_args()

    # Previously hard-coded to CUDA, which crashed on CPU-only machines.
    if args.device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    # Second positional argument is presumably DGL's `self_loop` flag.
    dataset_ = GINDataset(args.dataset, False)
    dataset = DatasetAbstraction([g[0] for g in dataset_], [g[1] for g in dataset_])

    # 1. Fixed split: shuffled once with seed 2021, independent of the
    #    per-repeat model seeds below.
    dataids = list(range(len(dataset)))
    random.seed(2021)
    random.shuffle(dataids)

    fold = int(len(dataset) * 0.1)
    train_dataset = dataset[dataids[:fold * 8]]
    val_dataset = dataset[dataids[fold * 8:fold * 9]]
    test_dataset = dataset[dataids[fold * 9:]]

    trainloader = GraphDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    valloader = GraphDataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    testloader = GraphDataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    accs = []
    for seed in tqdm(range(args.repeat)):
        # Each repeat uses seed = repeat index.
        set_seed(seed)

        # 5 layers, 2-layer MLPs, 64 hidden units, dropout 0.5,
        # learn_eps=False, sum graph pooling, sum neighbor pooling.
        model = GIN(
            5, 2, dataset_.dim_nfeats, 64, dataset_.gclasses, 0.5, False,
            "sum", "sum").to(device)

        criterion = nn.CrossEntropyLoss()  # default reduction is 'mean'
        optimizer = optim.Adam(model.parameters(), lr=args.lr)

        model = train(model, trainloader, valloader, optimizer, criterion, args.epochs, device)
        acc = eval_net(model, testloader, device)
        accs.append(acc)

    print('{:.2f} ~ {:.2f}'.format(np.mean(accs) * 100, np.std(accs) * 100))


if __name__ == '__main__':
    main()
|