import os

# Set before autogl is imported below; presumably autogl reads this variable
# at import time to pick its backend -- keep this ordering.
os.environ["AUTOGL_BACKEND"] = "pyg"

from torch_geometric.data import DataLoader, Data
import torch.optim as optim
from tqdm import tqdm
from ogb.graphproppred import Evaluator
import random
import torch
import numpy as np
from autogl.datasets import build_dataset_from_name
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from autogl import backend

# Names of the node-feature and label fields, which differ per backend.
if backend.DependentBackend.is_dgl():
    feat = 'feat'
    label = 'label'
else:
    feat = 'x'
    label = 'y'

cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()

# model
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_geometric.nn.inits import uniform
from torch_scatter import scatter_mean
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.utils import degree


def _combine_layer_outputs(h_list, JK):
    """Reduce the per-layer node embeddings to one node representation.

    h_list (list of Tensor): input embedding followed by one entry per layer
    JK (str): "last" returns the final entry, "sum" adds all entries

    Raises ValueError for any other JK value (the original code fell through
    and failed with an UnboundLocalError instead).
    """
    if JK == "last":
        return h_list[-1]
    if JK == "sum":
        node_representation = 0
        for h in h_list:
            node_representation = node_representation + h
        return node_representation
    raise ValueError('Undefined JK mode called {}'.format(JK))


### GIN convolution along the graph structure
class GINConv(MessagePassing):
    """GIN-style message passing layer that also consumes edge attributes."""

    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr = "add")

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2*emb_dim),
            torch.nn.BatchNorm1d(2*emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2*emb_dim, emb_dim),
        )
        # Learnable weight of the node's own embedding, initialised to 0.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return mlp((1 + eps) * x + aggregated neighbour messages)."""
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))

        return out

    def message(self, x_j, edge_attr):
        """Message of one edge: ReLU of source embedding plus edge embedding."""
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out


### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    """GCN-style message passing layer that also consumes edge attributes."""

    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GCNConv, self).__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return normalised neighbour aggregation plus a root-node term."""
        x = self.linear(x)
        edge_embedding = self.bond_encoder(edge_attr)

        row, col = edge_index

        #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
        # Degree counted over `row`, plus one for the node itself.
        deg = degree(row, x.size(0), dtype = x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        # Per-edge weight: 1 / sqrt(deg[row] * deg[col]).
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)

    def message(self, x_j, edge_attr, norm):
        """Message of one edge, scaled by its normalisation weight."""
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out


### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
            num_layer (int): number of GNN message passing layers
            drop_ratio (float): dropout probability applied after every layer
            JK (str): "last" or "sum", how per-layer outputs are combined
            residual (bool): add each layer's input to its output
            gnn_type (str): 'gin' or 'gcn'

            Raises ValueError if num_layer < 2 or gnn_type is unknown.
        '''

        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ###List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        """Compute node representations for a batch of graphs.

        batched_data must expose `x`, `edge_index`, `edge_attr` and `batch`.
        Raises ValueError if self.JK is neither "last" nor "sum".
        """
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):

            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                # Out-of-place add, consistent with GNN_node_Virtualnode.
                h = h + h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        return _combine_layer_outputs(h_list, self.JK)


### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
            num_layer (int): number of GNN message passing layers
            drop_ratio (float): dropout probability applied after every layer
            JK (str): "last" or "sum", how per-layer outputs are combined
            residual (bool): add each layer's input to its output
            gnn_type (str): 'gin' or 'gcn'

            Raises ValueError if num_layer < 2 or gnn_type is unknown.
        '''

        super(GNN_node_Virtualnode, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        ### batch norms applied to node embeddings
        self.batch_norms = torch.nn.ModuleList()

        ### List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

        # The virtual node is not updated after the last layer, hence one
        # MLP fewer than there are convolutions.
        for layer in range(num_layer - 1):
            self.mlp_virtualnode_list.append(torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2*emb_dim),
                torch.nn.BatchNorm1d(2*emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2*emb_dim, emb_dim),
                torch.nn.BatchNorm1d(emb_dim),
                torch.nn.ReLU(),
            ))

    def forward(self, batched_data):
        """Compute node representations using one virtual node per graph.

        batched_data must expose `x`, `edge_index`, `edge_attr` and `batch`.
        Raises ValueError if self.JK is neither "last" nor "sum".
        """
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### virtual node embeddings for graphs
        # One all-zero index per graph; batch[-1] + 1 is used as the graph count.
        virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)

            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layer - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP
                transformed = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)

                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + transformed
                else:
                    virtualnode_embedding = transformed

        ### Different implementations of Jk-concat
        return _combine_layer_outputs(h_list, self.JK)


class GNN(torch.nn.Module):
    """Graph-level predictor: node GNN, graph pooling, then a linear head."""

    def __init__(self, num_tasks, num_layer = 5, emb_dim = 300,
                 gnn_type = 'gin', virtual_node = False, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            num_tasks (int): number of labels to be predicted
            num_layer (int): number of GNN message passing layers
            emb_dim (int): node embedding dimensionality
            gnn_type (str): 'gin' or 'gcn'
            virtual_node (bool): whether to add virtual node or not
            residual (bool): add each layer's input to its output
            drop_ratio (float): dropout probability inside the node GNN
            JK (str): "last" or "sum"
            graph_pooling (str): "sum", "mean", "max", "attention" or "set2set"

            Raises ValueError if num_layer < 2 or graph_pooling is unknown.
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)
        else:
            self.gnn_node = GNN_node(num_layer, emb_dim, JK = JK, drop_ratio = drop_ratio, residual = residual, gnn_type = gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2*emb_dim),
                torch.nn.BatchNorm1d(2*emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2*emb_dim, 1),
            ))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # The set2set branch feeds a 2*emb_dim wide graph embedding to the head.
        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        """Return one prediction row per graph in `batched_data`."""
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)

        return self.graph_pred_linear(h_graph)


def train(model, device, loader, optimizer, task_type):
    """Run one training epoch over `loader`.

    model: module called as model(batch)
    device: device every batch is moved to
    loader: iterable of batches exposing `x`, `y` and `batch`
    optimizer: optimizer stepped once per processed batch
    task_type (str): BCE-with-logits loss is used when it contains
        "classification", MSE loss otherwise (previously this argument was
        ignored and BCE was always used)
    """
    model.train()

    criterion = cls_criterion if "classification" in task_type else reg_criterion

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        # Skip batches with a single node or a single graph; presumably the
        # BatchNorm layers cannot handle them in training mode -- TODO confirm.
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        pred = model(batch)
        optimizer.zero_grad()
        # NaN != NaN, so this mask drops entries whose label is NaN.
        is_labeled = batch.y == batch.y
        loss = criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
        loss.backward()
        optimizer.step()


def eval(model, device, loader, evaluator):
    """Evaluate `model` on `loader` and return `evaluator.eval(...)`.

    Note: the name shadows the builtin `eval`; it is kept so existing callers
    keep working.

    Predictions and labels of all processed batches are concatenated and
    passed to the evaluator as {"y_true": ..., "y_pred": ...} numpy arrays.
    Batches that contain a single node are skipped.
    """
    model.eval()
    y_true = []
    y_pred = []

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            pred = model(batch)

        y_true.append(batch.y.view(pred.shape).detach().cpu())
        y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim = 0).numpy()
    y_pred = torch.cat(y_pred, dim = 0).numpy()

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)


def trans(dataset):
    """Convert an autogl dataset split into a list of PyG `Data` objects.

    For every graph the node features (`feat`), the label (`label`, reshaped
    to one column), the edge connections and the 'edge_feat' edge data are
    copied into a new `Data` instance.
    """
    ret = []
    for i in range(len(dataset)):
        # Fetch the graph once instead of re-indexing the dataset per field.
        graph = dataset[i]
        ret.append(Data(
            x=graph.nodes.data[feat],
            y=graph.data[label].view(-1, 1),
            edge_index=graph.edges.connections,
            edge_attr=graph.edges.data['edge_feat'],
        ))
    return ret


if __name__ == "__main__":
    parser = ArgumentParser(
        "auto graph classification", formatter_class=ArgumentDefaultsHelpFormatter
    )
    # "ogbg-molhiv" was the default but missing from `choices`, so it could
    # not be passed explicitly.
    # NOTE(review): Evaluator(args.dataset) below probably only accepts the
    # OGB dataset names -- confirm the non-OGB choices are usable.
    parser.add_argument(
        "--dataset",
        default="ogbg-molhiv",
        type=str,
        help="graph classification dataset",
        choices=["mutag", "imdb-b", "imdb-m", "proteins", "collab", "ogbg-molbace", "ogbg-molhiv"],
    )
    parser.add_argument("--device", type=int, default=0, help="device to run on, -1 means cpu")
    parser.add_argument("--seed", type=int, default=0, help="random seed")

    args = parser.parse_args()

    # Resolve the device once; fall back to the CPU when CUDA is unavailable
    # instead of failing later in model.to(...).
    if args.device >= 0 and torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        device = torch.device("cuda:{}".format(args.device))
    else:
        if args.device >= 0:
            print("CUDA is not available, running on cpu instead.")
        device = torch.device("cpu")

    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    dataset = build_dataset_from_name(args.dataset)
    model = GNN(num_tasks=1, gnn_type = 'gcn').to(device)
    evaluator = Evaluator(args.dataset)

    train_dataset = trans(dataset.train_split)
    val_dataset = trans(dataset.val_split)
    test_dataset = trans(dataset.test_split)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    valid_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    valid_curve = []
    test_curve = []
    train_curve = []

    for epoch in range(1, 100 + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer, 'binary classification')

        print('Evaluating...')
        train_perf = eval(model, device, train_loader, evaluator)
        valid_perf = eval(model, device, valid_loader, evaluator)
        test_perf = eval(model, device, test_loader, evaluator)

        print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

        train_curve.append(train_perf['rocauc'])
        valid_curve.append(valid_perf['rocauc'])
        test_curve.append(test_perf['rocauc'])

    # Report the test score of the epoch with the best validation score.
    best_val_epoch = np.argmax(np.array(valid_curve))

    print('Finished training!')
    print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    print('Test score: {}'.format(test_curve[best_val_epoch]))