@@ -1,4 +1,7 @@
-from torch_geometric.data import DataLoader
import os
os.environ["AUTOGL_BACKEND"] = "pyg"

+from torch_geometric.data import DataLoader, Data
+import torch.optim as optim
+from tqdm import tqdm
+from ogb.graphproppred import Evaluator
@@ -7,15 +10,311 @@ import torch
import numpy as np
from autogl.datasets import build_dataset_from_name
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
-from ogb_gnn import GNN
-from torch_geometric.data import Data
-from autogl import backend
+from autogl.backend import DependentBackend

backend = DependentBackend.get_backend_name()
if DependentBackend.is_dgl():
    feat = 'feat'
    label = 'label'
else:
    feat = 'x'
    label = 'y'

cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()
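
# Hedged sketch (illustrative only, not part of the original patch): how the two
# criteria above are typically dispatched. OGB task_type strings look like
# 'binary classification' or 'regression'; the substring check mirrors OGB's
# reference code and matches the 'binary classification' literal passed in
# __main__ further down.
def _demo_pick_criterion(task_type: str):
    return cls_criterion if 'classification' in task_type else reg_criterion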

# model
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_geometric.nn.inits import uniform

from torch_scatter import scatter_mean

from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.utils import degree

### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr = "add")

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))

        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
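
# Hedged usage sketch (illustrative, not part of the original patch): exercising
# GINConv on a toy molecular graph. Shapes follow OGB's mol datasets, where raw
# node features are 9 categorical columns and raw bond features are 3.
def _demo_ginconv():
    emb_dim = 64
    conv = GINConv(emb_dim)
    atom_encoder = AtomEncoder(emb_dim)
    x = atom_encoder(torch.zeros(4, 9, dtype=torch.long))    # 4 nodes -> [4, 64]
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])  # 4 directed edges
    edge_attr = torch.zeros(4, 3, dtype=torch.long)          # raw bond features
    out = conv(x, edge_index, edge_attr)                     # -> [4, 64]
    return out.shape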

### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    def __init__(self, emb_dim):
        super(GCNConv, self).__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.bond_encoder(edge_attr)

        row, col = edge_index

        #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
        deg = degree(row, x.size(0), dtype = x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
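
# Hedged sketch (illustrative, not part of the original patch): the symmetric
# normalization computed in GCNConv.forward above. With self-loops folded in via
# `deg + 1`, an edge (u, v) is scaled by 1/sqrt((deg_u + 1) * (deg_v + 1)).
def _demo_gcn_norm():
    edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]])  # node 1 has out-degree 2
    deg = degree(edge_index[0], 3, dtype=torch.float) + 1
    deg_inv_sqrt = deg.pow(-0.5)
    norm = deg_inv_sqrt[edge_index[0]] * deg_inv_sqrt[edge_index[1]]
    return norm  # per-edge scaling coefficients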

### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
            num_layer (int): number of GNN message passing layers
        '''
        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):

            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation
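
# Hedged sketch (illustrative only): with JK="sum", h_list holds num_layer + 1
# tensors (the AtomEncoder output plus one per conv layer), so the returned
# representation is the sum over all depths rather than the last layer alone.
def _demo_jk_sum(h_list):
    # h_list: list of [num_nodes, emb_dim] tensors as built in forward() above
    return torch.stack(h_list, dim=0).sum(dim=0)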

### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GNN_node_Virtualnode, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        ### batch norms applied to node embeddings
        self.batch_norms = torch.nn.ModuleList()

        ### List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

        for layer in range(num_layer - 1):
            self.mlp_virtualnode_list.append(torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), \
                                                                 torch.nn.Linear(2*emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU()))

    def forward(self, batched_data):

        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### virtual node embeddings for graphs
        virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)

            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layer - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP

                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)
                else:
                    virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation
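
# Hedged sketch (illustrative, not part of the patch): the virtual-node trick in
# isolation. One extra embedding per graph is broadcast to its nodes before each
# conv layer, then refreshed from a pooled readout between layers.
def _demo_virtualnode_broadcast():
    batch = torch.tensor([0, 0, 1])      # two graphs: 2 nodes + 1 node
    h = torch.ones(3, 4)                 # node embeddings
    vn = torch.zeros(2, 4)               # one virtual node per graph
    h = h + vn[batch]                    # virtual node -> graph nodes
    vn = global_add_pool(h, batch) + vn  # graph nodes -> virtual node update
    return vn.shape                      # torch.Size([2, 4])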

class GNN(torch.nn.Module):

    def __init__(self, num_tasks, num_layer = 5, emb_dim = 300,
                 gnn_type = 'gin', virtual_node = False, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        '''
        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)
        else:
            self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)

        return self.graph_pred_linear(h_graph)
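
# Hedged end-to-end sketch (illustrative; assumes the ogb package and its mol
# feature dims): build the model the way __main__ below does and push one toy
# batch through it.
def _demo_gnn_forward():
    model = GNN(num_tasks=1, num_layer=2, emb_dim=32, gnn_type='gin')
    model.eval()  # use BatchNorm running stats; safer on tiny toy graphs
    data = Data(x=torch.zeros(4, 9, dtype=torch.long),
                edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]),
                edge_attr=torch.zeros(4, 3, dtype=torch.long))
    loader = DataLoader([data], batch_size=1)
    batch = next(iter(loader))
    return model(batch).shape  # torch.Size([1, 1]) -> one logit per graph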

def train(model, device, loader, optimizer, task_type):
    model.train()

@@ -59,8 +358,8 @@ def eval(model, device, loader, evaluator):
def trans(dataset):
    ret = []
    for i in range(len(dataset)):
-        x = dataset[i].nodes.data['x']
-        y = dataset[i].data['y'].view(-1, 1)
+        x = dataset[i].nodes.data[feat]
+        y = dataset[i].data[label].view(-1, 1)
        edge_index = dataset[i].edges.connections
        edge_attr = dataset[i].edges.data['edge_feat']
        data = Data(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr)
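# Hedged usage note (illustrative, not part of the patch): trans() bridges
# AutoGL's backend-neutral graph objects to PyG Data, reading the feat/label
# keys chosen at the top of the file ('x'/'y' for PyG, 'feat'/'label' for DGL).
# A sketch, assuming an AutoGL dataset object:
#
#     dataset = build_dataset_from_name("ogbg-molbace")
#     pyg_graphs = trans(dataset)  # list of torch_geometric.data.Data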
@@ -78,9 +377,6 @@ if __name__ == "__main__":
        help="graph classification dataset",
        choices=["mutag", "imdb-b", "imdb-m", "proteins", "collab", "ogbg-molbace"],
    )
-    parser.add_argument(
-        "--configs", default="../configs/graphclf_gin_benchmark.yml", help="config files"
-    )
    parser.add_argument("--device", type=int, default=0, help="device to run on, -1 means cpu")
    parser.add_argument("--seed", type=int, default=0, help="random seed")

@@ -120,16 +416,15 @@ if __name__ == "__main__":
    valid_curve = []
    test_curve = []
    train_curve = []
-    device = torch.device("cuda:0")
    for epoch in range(1, 100 + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
-        train(model, device, train_loader, optimizer, 'binary classification')
+        train(model, args.device, train_loader, optimizer, 'binary classification')

        print('Evaluating...')
-        train_perf = eval(model, device, train_loader, evaluator)
-        valid_perf = eval(model, device, valid_loader, evaluator)
-        test_perf = eval(model, device, test_loader, evaluator)
+        train_perf = eval(model, args.device, train_loader, evaluator)
+        valid_perf = eval(model, args.device, valid_loader, evaluator)
+        test_perf = eval(model, args.device, test_loader, evaluator)

        print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})