Browse Source

Merge pull request #134 from THUMNLab/graphcl_ogb

add ogb gnn
develop/0.4/predevelop
Generall GitHub 3 years ago
parent
commit
0f766ca6cf
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 302 additions and 0 deletions
  1. +302
    -0
      examples/ogb_gnn.py

+ 302
- 0
examples/ogb_gnn.py View File

@@ -0,0 +1,302 @@
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_geometric.nn.inits import uniform

# from conv import GNN_node, GNN_node_Virtualnode

from torch_scatter import scatter_mean

import torch
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_add_pool
from ogb.graphproppred.mol_encoder import AtomEncoder,BondEncoder
from torch_geometric.utils import degree

import math

### GIN convolution along the graph structure
class GINConv(MessagePassing):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''

super(GINConv, self).__init__(aggr = "add")

self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
self.eps = torch.nn.Parameter(torch.Tensor([0]))

self.bond_encoder = BondEncoder(emb_dim = emb_dim)

def forward(self, x, edge_index, edge_attr):
edge_embedding = self.bond_encoder(edge_attr)
out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))

return out

def message(self, x_j, edge_attr):
return F.relu(x_j + edge_attr)

def update(self, aggr_out):
return aggr_out

### GCN convolution along the graph structure
class GCNConv(MessagePassing):
def __init__(self, emb_dim):
super(GCNConv, self).__init__(aggr='add')

self.linear = torch.nn.Linear(emb_dim, emb_dim)
self.root_emb = torch.nn.Embedding(1, emb_dim)
self.bond_encoder = BondEncoder(emb_dim = emb_dim)

def forward(self, x, edge_index, edge_attr):
x = self.linear(x)
edge_embedding = self.bond_encoder(edge_attr)

row, col = edge_index

#edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
deg = degree(row, x.size(0), dtype = x.dtype) + 1
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)

def message(self, x_j, edge_attr, norm):
return norm.view(-1, 1) * F.relu(x_j + edge_attr)

def update(self, aggr_out):
return aggr_out


### GNN to generate node embedding
class GNN_node(torch.nn.Module):
"""
Output:
node representations
"""
def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
'''
emb_dim (int): node embedding dimensionality
num_layer (int): number of GNN message passing layers

'''

super(GNN_node, self).__init__()
self.num_layer = num_layer
self.drop_ratio = drop_ratio
self.JK = JK
### add residual connection or not
self.residual = residual

if self.num_layer < 2:
raise ValueError("Number of GNN layers must be greater than 1.")

self.atom_encoder = AtomEncoder(emb_dim)

###List of GNNs
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()

for layer in range(num_layer):
if gnn_type == 'gin':
self.convs.append(GINConv(emb_dim))
elif gnn_type == 'gcn':
self.convs.append(GCNConv(emb_dim))
else:
raise ValueError('Undefined GNN type called {}'.format(gnn_type))

self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

def forward(self, batched_data):
x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

### computing input node embedding

h_list = [self.atom_encoder(x)]
for layer in range(self.num_layer):

h = self.convs[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)

if layer == self.num_layer - 1:
#remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training = self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

if self.residual:
h += h_list[layer]

h_list.append(h)

### Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "sum":
node_representation = 0
for layer in range(self.num_layer + 1):
node_representation += h_list[layer]

return node_representation


### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(torch.nn.Module):
"""
Output:
node representations
"""
def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
'''
emb_dim (int): node embedding dimensionality
'''

super(GNN_node_Virtualnode, self).__init__()
self.num_layer = num_layer
self.drop_ratio = drop_ratio
self.JK = JK
### add residual connection or not
self.residual = residual

if self.num_layer < 2:
raise ValueError("Number of GNN layers must be greater than 1.")

self.atom_encoder = AtomEncoder(emb_dim)

### set the initial virtual node embedding to 0.
self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

### List of GNNs
self.convs = torch.nn.ModuleList()
### batch norms applied to node embeddings
self.batch_norms = torch.nn.ModuleList()

### List of MLPs to transform virtual node at every layer
self.mlp_virtualnode_list = torch.nn.ModuleList()

for layer in range(num_layer):
if gnn_type == 'gin':
self.convs.append(GINConv(emb_dim))
elif gnn_type == 'gcn':
self.convs.append(GCNConv(emb_dim))
else:
raise ValueError('Undefined GNN type called {}'.format(gnn_type))

self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

for layer in range(num_layer - 1):
self.mlp_virtualnode_list.append(torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), \
torch.nn.Linear(2*emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU()))


def forward(self, batched_data):

x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

### virtual node embeddings for graphs
virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))

h_list = [self.atom_encoder(x)]
for layer in range(self.num_layer):
### add message from virtual nodes to graph nodes
h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

### Message passing among graph nodes
h = self.convs[layer](h_list[layer], edge_index, edge_attr)

h = self.batch_norms[layer](h)
if layer == self.num_layer - 1:
#remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training = self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

if self.residual:
h = h + h_list[layer]

h_list.append(h)

### update the virtual nodes
if layer < self.num_layer - 1:
### add message from graph nodes to virtual nodes
virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
### transform virtual nodes using MLP

if self.residual:
virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)
else:
virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)

### Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "sum":
node_representation = 0
for layer in range(self.num_layer + 1):
node_representation += h_list[layer]

return node_representation

class GNN(torch.nn.Module):

def __init__(self, num_tasks, num_layer = 5, emb_dim = 300,
gnn_type = 'gin', virtual_node = False, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
'''
num_tasks (int): number of labels to be predicted
virtual_node (bool): whether to add virtual node or not
'''

super(GNN, self).__init__()

self.num_layer = num_layer
self.drop_ratio = drop_ratio
self.JK = JK
self.emb_dim = emb_dim
self.num_tasks = num_tasks
self.graph_pooling = graph_pooling

if self.num_layer < 2:
raise ValueError("Number of GNN layers must be greater than 1.")

### GNN to generate node embeddings
if virtual_node:
self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)
else:
self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)


### Pooling function to generate whole-graph embeddings
if self.graph_pooling == "sum":
self.pool = global_add_pool
elif self.graph_pooling == "mean":
self.pool = global_mean_pool
elif self.graph_pooling == "max":
self.pool = global_max_pool
elif self.graph_pooling == "attention":
self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1)))
elif self.graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, processing_steps = 2)
else:
raise ValueError("Invalid graph pooling type.")

if graph_pooling == "set2set":
self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_tasks)
else:
self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

def forward(self, batched_data):
h_node = self.gnn_node(batched_data)

h_graph = self.pool(h_node, batched_data.batch)

return self.graph_pred_linear(h_graph)

if __name__ == '__main__':
GNN(num_tasks = 10)

Loading…
Cancel
Save