You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nodeclf_ogb_proteins.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import argparse
  2. import numpy as np
  3. import torch
  4. import torch.nn.functional as F
  5. import torch_geometric.transforms as T
  6. from torch_sparse import SparseTensor
  7. from torch_geometric.nn import GCNConv, SAGEConv
  8. from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
  9. from autogl import backend
  10. from autogl.datasets import build_dataset_from_name
  11. if backend.DependentBackend.is_dgl():
  12. ylabel = 'label'
  13. else:
  14. ylabel = 'y'
  15. class GCN(torch.nn.Module):
  16. def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
  17. dropout):
  18. super(GCN, self).__init__()
  19. self.convs = torch.nn.ModuleList()
  20. self.convs.append(
  21. GCNConv(in_channels, hidden_channels, normalize=False))
  22. for _ in range(num_layers - 2):
  23. self.convs.append(
  24. GCNConv(hidden_channels, hidden_channels, normalize=False))
  25. self.convs.append(
  26. GCNConv(hidden_channels, out_channels, normalize=False))
  27. self.dropout = dropout
  28. def reset_parameters(self):
  29. for conv in self.convs:
  30. conv.reset_parameters()
  31. def forward(self, x, adj_t):
  32. for conv in self.convs[:-1]:
  33. x = conv(x, adj_t)
  34. x = F.relu(x)
  35. x = F.dropout(x, p=self.dropout, training=self.training)
  36. x = self.convs[-1](x, adj_t)
  37. return x
  38. class SAGE(torch.nn.Module):
  39. def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
  40. dropout):
  41. super(SAGE, self).__init__()
  42. self.convs = torch.nn.ModuleList()
  43. self.convs.append(SAGEConv(in_channels, hidden_channels))
  44. for _ in range(num_layers - 2):
  45. self.convs.append(SAGEConv(hidden_channels, hidden_channels))
  46. self.convs.append(SAGEConv(hidden_channels, out_channels))
  47. self.dropout = dropout
  48. def reset_parameters(self):
  49. for conv in self.convs:
  50. conv.reset_parameters()
  51. def forward(self, x, adj_t):
  52. for conv in self.convs[:-1]:
  53. x = conv(x, adj_t)
  54. x = F.relu(x)
  55. x = F.dropout(x, p=self.dropout, training=self.training)
  56. x = self.convs[-1](x, adj_t)
  57. return x
  58. def train(model, x, y, edge_index, train_idx, optimizer):
  59. model.train()
  60. criterion = torch.nn.BCEWithLogitsLoss()
  61. optimizer.zero_grad()
  62. out = model(x, edge_index)[train_idx]
  63. loss = criterion(out, y[train_idx].to(torch.float))
  64. loss.backward()
  65. optimizer.step()
  66. return loss.item()
  67. @torch.no_grad()
  68. def test(model, x, y, edge_index, split_idx, evaluator):
  69. model.eval()
  70. y_pred = model(x, edge_index)
  71. train_rocauc = evaluator.eval({
  72. 'y_true': y[split_idx['train']],
  73. 'y_pred': y_pred[split_idx['train']],
  74. })['rocauc']
  75. valid_rocauc = evaluator.eval({
  76. 'y_true': y[split_idx['valid']],
  77. 'y_pred': y_pred[split_idx['valid']],
  78. })['rocauc']
  79. test_rocauc = evaluator.eval({
  80. 'y_true': y[split_idx['test']],
  81. 'y_pred': y_pred[split_idx['test']],
  82. })['rocauc']
  83. return train_rocauc, valid_rocauc, test_rocauc
  84. def main():
  85. parser = argparse.ArgumentParser(description='OGBN-Proteins (GNN)')
  86. parser.add_argument('--device', type=int, default=0)
  87. parser.add_argument('--log_steps', type=int, default=1)
  88. parser.add_argument('--use_sage', action='store_true')
  89. parser.add_argument('--num_layers', type=int, default=3)
  90. parser.add_argument('--hidden_channels', type=int, default=256)
  91. parser.add_argument('--dropout', type=float, default=0.0)
  92. parser.add_argument('--lr', type=float, default=0.01)
  93. parser.add_argument('--epochs', type=int, default=1000)
  94. parser.add_argument('--eval_steps', type=int, default=5)
  95. parser.add_argument('--runs', type=int, default=10)
  96. args = parser.parse_args()
  97. print(args)
  98. device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
  99. device = torch.device(device)
  100. autogl_dataset = build_dataset_from_name('ogbn-proteins')
  101. data = autogl_dataset[0]
  102. y = data.nodes.data[ylabel].to(device)
  103. num_nodes = data.nodes.data['species'].shape[0]
  104. edge_index = data.edges.connections
  105. row = edge_index[0].type(torch.long).to(device)
  106. col = edge_index[1].type(torch.long).to(device)
  107. edge_feat = data.edges.data['edge_feat'].to(device)
  108. edge_index = SparseTensor(row=row, col=col, value=edge_feat, sparse_sizes=(num_nodes, num_nodes))
  109. x = edge_index.mean(dim=1).to(device)
  110. edge_index.set_value_(None)
  111. train_mask = data.nodes.data['train_mask']
  112. val_mask = data.nodes.data['val_mask']
  113. test_mask = data.nodes.data['test_mask']
  114. split_idx = {
  115. 'train': train_mask,
  116. 'valid': val_mask,
  117. 'test': test_mask
  118. }
  119. labels = data.nodes.data[ylabel]
  120. num_classes = len(np.unique(labels.numpy()))
  121. train_idx = split_idx['train']
  122. if args.use_sage:
  123. model = SAGE(x.size(1), args.hidden_channels, 112,
  124. args.num_layers, args.dropout).to(device)
  125. else:
  126. model = GCN(x.size(1), args.hidden_channels, 112,
  127. args.num_layers, args.dropout).to(device)
  128. # Pre-compute GCN normalization.
  129. adj_t = edge_index.set_diag()
  130. deg = adj_t.sum(dim=1).to(torch.float)
  131. deg_inv_sqrt = deg.pow(-0.5)
  132. deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
  133. adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1)
  134. edge_index = adj_t
  135. evaluator = Evaluator(name='ogbn-proteins')
  136. for run in range(args.runs):
  137. model.reset_parameters()
  138. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
  139. for epoch in range(1, 1 + args.epochs):
  140. loss = train(model, x, y, edge_index, train_idx, optimizer)
  141. if epoch % args.eval_steps == 0:
  142. result = test(model, x, y, edge_index, split_idx, evaluator)
  143. if epoch % args.log_steps == 0:
  144. train_rocauc, valid_rocauc, test_rocauc = result
  145. print(f'Run: {run + 1:02d}, '
  146. f'Epoch: {epoch:02d}, '
  147. f'Loss: {loss:.4f}, '
  148. f'Train: {100 * train_rocauc:.2f}%, '
  149. f'Valid: {100 * valid_rocauc:.2f}% '
  150. f'Test: {100 * test_rocauc:.2f}%')
  151. if __name__ == "__main__":
  152. main()