|
- import os
- os.environ["AUTOGL_BACKEND"] = "pyg"
-
- from torch_geometric.data import DataLoader, Data
- import torch.optim as optim
- from tqdm import tqdm
- from ogb.graphproppred import Evaluator
- import random
- import torch
- import numpy as np
- from autogl.datasets import build_dataset_from_name
- from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
- from autogl import backend
-
-
# Names under which node features and graph labels are stored depend on the
# active AutoGL backend (DGL vs. PyG).
feat, label = ('feat', 'label') if backend.DependentBackend.is_dgl() else ('x', 'y')

# Loss on raw logits for classification targets, squared error for regression targets.
cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()
-
- # model
- from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
- import torch.nn.functional as F
- from torch_geometric.nn.inits import uniform
-
- from torch_scatter import scatter_mean
-
- from ogb.graphproppred.mol_encoder import AtomEncoder,BondEncoder
- from torch_geometric.utils import degree
-
- ### GIN convolution along the graph structure
class GINConv(MessagePassing):
    """GIN-style convolution that injects bond (edge) features into messages.

    Each node i computes
        MLP((1 + eps) * x_i + sum_j ReLU(x_j + BondEncoder(edge_attr_ji)))
    where the sum runs over incoming edges ("add" aggregation) and ``eps``
    is a learnable scalar initialised to 0.
    """

    def __init__(self, emb_dim):
        '''
        emb_dim (int): node embedding dimensionality (input and output width)
        '''
        super().__init__(aggr="add")

        hidden_dim = 2 * emb_dim
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, emb_dim),
        )
        # Learnable weight on the node's own embedding, starts at 0.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        self_term = (1 + self.eps) * x
        neighbour_sum = self.propagate(edge_index, x=x, edge_attr=edge_embedding)
        return self.mlp(self_term + neighbour_sum)

    # NOTE: argument names below are matched by MessagePassing; do not rename.
    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        # Aggregated messages are passed through unchanged.
        return aggr_out
-
- ### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    """GCN-style convolution with bond features and symmetric degree normalisation.

    Neighbour messages ReLU(x_j + bond_emb) are scaled by
    1 / sqrt(deg_src * deg_dst) and summed; the node's own contribution is a
    separate "root" term ReLU(x + root_emb) / deg.
    """

    def __init__(self, emb_dim):
        '''
        emb_dim (int): node embedding dimensionality (input and output width)
        '''
        super().__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        # Single learnable vector added to every node for the root (self) term.
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.bond_encoder(edge_attr)

        src, dst = edge_index

        # Degree counted over edge_index[0]; the +1 counts the node itself,
        # whose contribution is the root term below.
        deg = degree(src, x.size(0), dtype = x.dtype) + 1
        inv_sqrt_deg = deg.pow(-0.5)
        inv_sqrt_deg.masked_fill_(inv_sqrt_deg == float('inf'), 0)

        norm = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

        neighbour_term = self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm)
        root_term = F.relu(x + self.root_emb.weight) / deg.view(-1, 1)
        return neighbour_term + root_term

    # NOTE: argument names below are matched by MessagePassing; do not rename.
    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        # Aggregated messages are passed through unchanged.
        return aggr_out
-
-
- ### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Stack of GIN/GCN message-passing layers over atom-encoded node features.

    Output:
        node representations (one row per node in the batch)
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
        num_layer (int): number of GNN message passing layers (must be >= 2)
        emb_dim (int): node embedding dimensionality
        drop_ratio (float): dropout probability applied after every layer
        JK (str): how layer outputs are combined: "last" (final layer only)
            or "sum" (input embedding plus every layer output)
        residual (bool): add each layer's input to its output
        gnn_type (str): 'gin' or 'gcn'

        Raises ValueError for num_layer < 2, an unknown JK mode or an unknown gnn_type.
        '''

        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Validate eagerly: an unsupported mode would otherwise only surface in
        # forward() as an UnboundLocalError.
        if self.JK not in ("last", "sum"):
            raise ValueError('Undefined JK mode called {}'.format(JK))

        self.atom_encoder = AtomEncoder(emb_dim)

        ###List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):

            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                # Out-of-place on purpose: with drop_ratio == 0 F.dropout can
                # return its input, and mutating a ReLU output in place would
                # invalidate the tensor autograd saved for the backward pass.
                h = h + h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            return h_list[-1]
        # JK == "sum": input embedding plus every layer output.
        return sum(h_list)
-
-
- ### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(torch.nn.Module):
    """
    GNN_node variant with one virtual node per graph.

    Before every layer the graph's virtual-node embedding is added to all of
    its nodes; between layers the virtual node is updated from the summed
    node embeddings of its graph through a per-layer MLP.

    Output:
        node representations (one row per node in the batch)
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
        num_layer (int): number of GNN message passing layers (must be >= 2)
        emb_dim (int): node embedding dimensionality
        drop_ratio (float): dropout probability for node and virtual-node updates
        JK (str): how layer outputs are combined: "last" or "sum"
        residual (bool): residual connections for nodes and virtual nodes
        gnn_type (str): 'gin' or 'gcn'

        Raises ValueError for num_layer < 2, an unknown JK mode or an unknown gnn_type.
        '''

        super(GNN_node_Virtualnode, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Validate eagerly: an unsupported mode would otherwise only surface in
        # forward() as an UnboundLocalError.
        if self.JK not in ("last", "sum"):
            raise ValueError('Undefined JK mode called {}'.format(JK))

        self.atom_encoder = AtomEncoder(emb_dim)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        ### batch norms applied to node embeddings
        self.batch_norms = torch.nn.ModuleList()

        ### List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

        # One MLP per transition between consecutive layers (none after the last).
        for layer in range(num_layer - 1):
            self.mlp_virtualnode_list.append(torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(),
                torch.nn.Linear(2*emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU()))

    def forward(self, batched_data):

        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### virtual node embeddings for graphs
        # Presumably batch holds sorted graph indices 0..G-1 (PyG batching), so
        # the last entry + 1 is the number of graphs -- confirm for other loaders.
        num_graphs = batch[-1].item() + 1
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device))

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)

            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layer - 1:
                ### add message from graph nodes to virtual nodes
                # Pools this layer's input (which already includes the virtual-node message).
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP
                update = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)
                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + update
                else:
                    virtualnode_embedding = update

        ### Different implementations of Jk-concat
        if self.JK == "last":
            return h_list[-1]
        # JK == "sum": input embedding plus every layer output.
        return sum(h_list)
-
class GNN(torch.nn.Module):
    """Whole-graph predictor: node encoder -> graph pooling -> linear head.

    The node encoder is GNN_node_Virtualnode when ``virtual_node`` is true and
    GNN_node otherwise; ``graph_pooling`` selects how node embeddings are
    reduced to one vector per graph.
    """

    def __init__(self, num_tasks, num_layer = 5, emb_dim = 300,
                    gnn_type = 'gin', virtual_node = False, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
        num_tasks (int): number of labels to be predicted (width of the output)
        num_layer (int): number of message-passing layers (must be >= 2)
        emb_dim (int): node embedding dimensionality
        gnn_type (str): 'gin' or 'gcn'
        virtual_node (bool): whether to add virtual node or not
        residual (bool): residual connections inside the node encoder
        drop_ratio (float): dropout probability inside the node encoder
        JK (str): layer-combination mode forwarded to the node encoder
        graph_pooling (str): 'sum', 'mean', 'max', 'attention' or 'set2set'

        Raises ValueError for num_layer < 2 or an unknown graph_pooling.
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        node_encoder_cls = GNN_node_Virtualnode if virtual_node else GNN_node
        self.gnn_node = node_encoder_cls(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)

        # Factories rather than instances, so only the selected pooling module
        # is ever constructed.
        pool_factories = {
            "sum": lambda: global_add_pool,
            "mean": lambda: global_mean_pool,
            "max": lambda: global_max_pool,
            "attention": lambda: GlobalAttention(gate_nn = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim),
                torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1))),
            "set2set": lambda: Set2Set(emb_dim, processing_steps = 2),
        }
        make_pool = pool_factories.get(self.graph_pooling)
        if make_pool is None:
            raise ValueError("Invalid graph pooling type.")
        self.pool = make_pool()

        # Set2Set yields 2 * emb_dim features per graph; the others yield emb_dim.
        pooled_dim = 2 * self.emb_dim if graph_pooling == "set2set" else self.emb_dim
        self.graph_pred_linear = torch.nn.Linear(pooled_dim, self.num_tasks)

    def forward(self, batched_data):
        """Return raw predictions: one row per graph, ``num_tasks`` columns."""
        node_embeddings = self.gnn_node(batched_data)
        graph_embeddings = self.pool(node_embeddings, batched_data.batch)
        return self.graph_pred_linear(graph_embeddings)
-
-
def train(model, device, loader, optimizer, task_type):
    """Run one training epoch over ``loader``.

    model: graph-level model called as ``model(batch)``
    device: target passed to ``batch.to``
    loader: iterable of batched graphs exposing ``x``, ``y`` and ``batch``
    optimizer: optimizer over ``model``'s parameters
    task_type (str): OGB-style task type; any value containing
        "classification" uses ``cls_criterion`` (BCE on logits), anything
        else uses ``reg_criterion`` (MSE).

    Targets equal to NaN are treated as missing and excluded from the loss.
    """
    model.train()
    # Previously task_type was ignored and classification loss was always used.
    criterion = cls_criterion if "classification" in task_type else reg_criterion

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        # Skip batches with a single node or a single graph (last graph index
        # is 0); presumably these would break BatchNorm in training mode.
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        pred = model(batch)
        optimizer.zero_grad()
        # y == y is False exactly for NaN entries, i.e. unlabeled targets.
        is_labeled = batch.y == batch.y
        loss = criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
        loss.backward()
        optimizer.step()
-
def eval(model, device, loader, evaluator):
    """Predict on every batch of ``loader`` and score with an OGB-style evaluator.

    Batches consisting of a single node are skipped. Labels are reshaped to
    the prediction's shape and both are concatenated into NumPy arrays.

    Returns whatever ``evaluator.eval`` returns for
    ``{"y_true": ..., "y_pred": ...}``.
    """
    model.eval()
    true_chunks = []
    pred_chunks = []

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            pred = model(batch)

        true_chunks.append(batch.y.view(pred.shape).detach().cpu())
        pred_chunks.append(pred.detach().cpu())

    return evaluator.eval({
        "y_true": torch.cat(true_chunks, dim = 0).numpy(),
        "y_pred": torch.cat(pred_chunks, dim = 0).numpy(),
    })
-
def trans(dataset):
    """Convert an indexable AutoGL dataset split into a list of PyG ``Data`` objects.

    For each graph, copies node features (``feat`` key), the label (``label``
    key, reshaped to a single column), the edge connections and the
    ``'edge_feat'`` edge attributes.
    """
    def _to_data(graph):
        # Build one Data object from a single AutoGL graph.
        return Data(
            x=graph.nodes.data[feat],
            y=graph.data[label].view(-1, 1),
            edge_index=graph.edges.connections,
            edge_attr=graph.edges.data['edge_feat'],
        )

    return [_to_data(dataset[idx]) for idx in range(len(dataset))]
-
if __name__ == "__main__":
    parser = ArgumentParser(
        "auto graph classification", formatter_class=ArgumentDefaultsHelpFormatter
    )
    # The model uses OGB's AtomEncoder/BondEncoder and results are scored with
    # ogb's Evaluator, so only OGB molecular datasets can run. The default must
    # also be listed, otherwise passing it explicitly is rejected.
    parser.add_argument(
        "--dataset",
        default="ogbg-molhiv",
        type=str,
        help="graph classification dataset",
        choices=["ogbg-molhiv", "ogbg-molbace"],
    )
    parser.add_argument("--device", type=int, default=0, help="device to run on, -1 means cpu")
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Adam learning rate")
    parser.add_argument("--batch_size", type=int, default=32, help="mini-batch size")

    args = parser.parse_args()
    if args.epochs < 1:
        parser.error("--epochs must be at least 1")

    # Resolve the compute device; fall back to CPU when CUDA is unavailable
    # instead of crashing later in model.to().
    if args.device == -1:
        device = torch.device("cpu")
    elif torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        device = torch.device("cuda", args.device)
    else:
        print("CUDA is not available, running on cpu")
        device = torch.device("cpu")

    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dataset = build_dataset_from_name(args.dataset)
    model = GNN(num_tasks=1, gnn_type = 'gcn').to(device)
    evaluator = Evaluator(args.dataset)
    # Metric key the evaluator reports for this dataset (e.g. 'rocauc').
    metric = evaluator.eval_metric

    train_dataset = trans(dataset.train_split)
    val_dataset = trans(dataset.val_split)
    test_dataset = trans(dataset.test_split)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                            num_workers=0)
    valid_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=0)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    valid_curve = []
    test_curve = []
    train_curve = []
    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer, 'binary classification')

        print('Evaluating...')
        train_perf = eval(model, device, train_loader, evaluator)
        valid_perf = eval(model, device, valid_loader, evaluator)
        test_perf = eval(model, device, test_loader, evaluator)

        print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

        train_curve.append(train_perf[metric])
        valid_curve.append(valid_perf[metric])
        test_curve.append(test_perf[metric])

    # Model selection on validation score (assumes higher is better, as for
    # ROC-AUC); the test score is reported at that same epoch.
    best_val_epoch = int(np.argmax(np.array(valid_curve)))

    print('Finished training!')
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))
|