You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

nodeclf_ogb.py 7.0 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. import os
  2. import tqdm
  3. import argparse
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. from torch_geometric.nn import GCNConv, SAGEConv
  8. from ogb.nodeproppred import Evaluator
  9. from autogl.datasets import build_dataset_from_name
  10. from autogl import backend
  11. if backend.DependentBackend.is_dgl():
  12. feat = 'feat'
  13. label = 'label'
  14. else:
  15. feat = 'x'
  16. label = 'y'
  17. class GCN(torch.nn.Module):
  18. def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
  19. dropout):
  20. super(GCN, self).__init__()
  21. self.convs = torch.nn.ModuleList()
  22. self.convs.append(GCNConv(in_channels, hidden_channels, cached=True))
  23. self.bns = torch.nn.ModuleList()
  24. self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
  25. for _ in range(num_layers - 2):
  26. self.convs.append(
  27. GCNConv(hidden_channels, hidden_channels, cached=True))
  28. self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
  29. self.convs.append(GCNConv(hidden_channels, out_channels, cached=True))
  30. self.dropout = dropout
  31. def reset_parameters(self):
  32. for conv in self.convs:
  33. conv.reset_parameters()
  34. for bn in self.bns:
  35. bn.reset_parameters()
  36. def forward(self, x, adj_t):
  37. for i, conv in enumerate(self.convs[:-1]):
  38. x = conv(x, adj_t)
  39. x = self.bns[i](x)
  40. x = F.relu(x)
  41. x = F.dropout(x, p=self.dropout, training=self.training)
  42. x = self.convs[-1](x, adj_t)
  43. return x.log_softmax(dim=-1)
  44. class SAGE(torch.nn.Module):
  45. def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
  46. dropout):
  47. super(SAGE, self).__init__()
  48. self.convs = torch.nn.ModuleList()
  49. self.convs.append(SAGEConv(in_channels, hidden_channels))
  50. self.bns = torch.nn.ModuleList()
  51. self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
  52. for _ in range(num_layers - 2):
  53. self.convs.append(SAGEConv(hidden_channels, hidden_channels))
  54. self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
  55. self.convs.append(SAGEConv(hidden_channels, out_channels))
  56. self.dropout = dropout
  57. def reset_parameters(self):
  58. for conv in self.convs:
  59. conv.reset_parameters()
  60. for bn in self.bns:
  61. bn.reset_parameters()
  62. def forward(self, x, adj_t):
  63. for i, conv in enumerate(self.convs[:-1]):
  64. x = conv(x, adj_t)
  65. x = self.bns[i](x)
  66. x = F.relu(x)
  67. x = F.dropout(x, p=self.dropout, training=self.training)
  68. x = self.convs[-1](x, adj_t)
  69. return x.log_softmax(dim=-1)
  70. def train(model, x, y, edge_index, train_idx, optimizer):
  71. model.train()
  72. optimizer.zero_grad()
  73. out = model(x, edge_index)[train_idx]
  74. loss = F.nll_loss(out, y[train_idx])
  75. loss.backward()
  76. optimizer.step()
  77. return loss.item()
  78. @torch.no_grad()
  79. def test(model, x, y, edge_index, split_idx, evaluator):
  80. model.eval()
  81. out = model(x, edge_index)
  82. y_pred = out.argmax(dim=-1, keepdim=True)
  83. train_acc = evaluator.eval({
  84. 'y_true': y[split_idx['train']].view(-1, 1),
  85. 'y_pred': y_pred[split_idx['train']],
  86. })['acc']
  87. valid_acc = evaluator.eval({
  88. 'y_true': y[split_idx['valid']].view(-1, 1),
  89. 'y_pred': y_pred[split_idx['valid']],
  90. })['acc']
  91. test_acc = evaluator.eval({
  92. 'y_true': y[split_idx['test']].view(-1, 1),
  93. 'y_pred': y_pred[split_idx['test']],
  94. })['acc']
  95. return train_acc, valid_acc, test_acc
  96. class Node:
  97. def __init__(self, a, b):
  98. self.a = a
  99. self.b = b
  100. def __le__(self, other):
  101. return self.a <= other.a
  102. def __lt__(self, other):
  103. if self.a < other.a:
  104. return True
  105. elif self.a == other.a:
  106. return self.b < other.b
  107. else:
  108. return False
def main():
    """Train a full-batch GCN (or GraphSAGE with --use_sage) on OGBN-Arxiv.

    Loads the dataset through AutoGL, runs ``--runs`` independent training
    runs of ``--epochs`` epochs each, keeps the test accuracy at the
    best-validation epoch of every run, and prints the mean and standard
    deviation of those test accuracies.
    """
    parser = argparse.ArgumentParser(description='OGBN-Arxiv (GNN)')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--use_sage', action='store_true')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
    print(args)

    # Fall back to CPU when CUDA is unavailable.
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # print(oedge_index)
    dataset = build_dataset_from_name('ogbn_arxiv')
    data = dataset[0]
    # `feat`/`label` are the backend-dependent attribute keys chosen at
    # import time (DGL: 'feat'/'label', PyG: 'x'/'y').
    x = data.nodes.data[feat].to(device)
    y = data.nodes.data[label].to(device)
    edge_index = data.edges.connections.to(device)
    # edge_index = data_transfer(edge_index, row, col)
    print(edge_index)
    # print(edge_index.shape)

    # NOTE(review): the '*_mask' names suggest boolean masks, but these
    # tensors are used as index selectors below (train_idx, y[split]);
    # presumably AutoGL stores them in a form valid for indexing — confirm
    # against the dataset implementation.
    train_mask = data.nodes.data['train_mask']
    val_mask = data.nodes.data['val_mask']
    test_mask = data.nodes.data['test_mask']
    split_idx = {
        'train': train_mask,
        'valid': val_mask,
        'test': test_mask
    }
    # split_idx = dataset.get_idx_split()
    train_idx = split_idx['train'].to(device)

    # Infer the class count from the label tensor rather than hard-coding it.
    labels = dataset[0].nodes.data[label]
    num_classes = len(np.unique(labels.numpy()))
    if args.use_sage:
        model = SAGE(dataset[0].nodes.data[feat].size(1), args.hidden_channels,
                     num_classes, args.num_layers,
                     args.dropout).to(device)
    else:
        model = GCN(dataset[0].nodes.data[feat].size(1), args.hidden_channels,
                    num_classes, args.num_layers,
                    args.dropout).to(device)

    evaluator = Evaluator(name='ogbn-arxiv')
    best_accs = []  # best-validation-epoch test accuracy, one per run
    for run in range(args.runs):
        # Fresh parameters and optimizer state for every run.
        model.reset_parameters()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        best_valid = 0.0
        best_test = 0.0
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, x, y, edge_index, train_idx, optimizer)
            result = test(model, x, y, edge_index, split_idx, evaluator)
            # NOTE(review): best-checkpoint tracking only happens on logged
            # epochs, because valid_acc is bound inside this branch; with the
            # default --log_steps=1 every epoch is considered.
            if epoch % args.log_steps == 0:
                train_acc, valid_acc, test_acc = result
                print(f'Run: {run + 1:02d}, '
                      f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {100 * train_acc:.2f}%, '
                      f'Valid: {100 * valid_acc:.2f}% '
                      f'Test: {100 * test_acc:.2f}%')
                # Keep the test accuracy from the best-validation epoch.
                if valid_acc > best_valid:
                    best_valid = valid_acc
                    best_test = test_acc
        best_accs.append(best_test)
    # Report per-run results plus mean/std across runs.
    print(best_accs)
    print(np.mean(best_accs))
    print(np.std(best_accs))
# Script entry point: run the full experiment when executed directly.
if __name__ == "__main__":
    main()