You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

base.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. """
  2. Performance check of AutoGL model + PYG (trainer + dataset)
  3. """
  4. import os
  5. import random
  6. import numpy as np
  7. from tqdm import tqdm
  8. import pickle
  9. import torch
  10. import torch.nn.functional as F
  11. from torch_geometric.datasets import Planetoid
  12. import torch_geometric.transforms as T
  13. from torch_geometric.nn import GCNConv, GATConv, SAGEConv
  14. import logging
  15. logging.basicConfig(level=logging.ERROR)
  16. class GCN(torch.nn.Module):
  17. def __init__(self, num_features, num_classes):
  18. super(GCN, self).__init__()
  19. self.conv1 = GCNConv(num_features, 16)
  20. self.conv2 = GCNConv(16, num_classes)
  21. def forward(self, data):
  22. x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
  23. x = F.relu(self.conv1(x, edge_index, edge_weight))
  24. x = F.dropout(x, training=self.training)
  25. x = self.conv2(x, edge_index, edge_weight)
  26. return F.log_softmax(x, dim=1)
  27. class GAT(torch.nn.Module):
  28. def __init__(self, num_features, num_classes):
  29. super(GAT, self).__init__()
  30. self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
  31. self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False,
  32. dropout=0.6)
  33. def forward(self, data):
  34. x, edge_index = data.x, data.edge_index
  35. x = F.dropout(x, p=0.6, training=self.training)
  36. x = F.elu(self.conv1(x, edge_index))
  37. x = F.dropout(x, p=0.6, training=self.training)
  38. x = self.conv2(x, edge_index)
  39. return F.log_softmax(x, dim=-1)
  40. class SAGE(torch.nn.Module):
  41. def __init__(self, num_features, hidden_channels, num_layers, num_classes):
  42. super(SAGE, self).__init__()
  43. self.num_layers = num_layers
  44. self.convs = torch.nn.ModuleList()
  45. for i in range(num_layers):
  46. inc = outc = hidden_channels
  47. if i == 0:
  48. inc = num_features
  49. if i == num_layers - 1:
  50. outc = num_classes
  51. self.convs.append(SAGEConv(inc, outc))
  52. def forward(self, data):
  53. x, edge_index = data.x, data.edge_index
  54. for i, conv in enumerate(self.convs):
  55. x = conv(x, edge_index)
  56. if i != self.num_layers - 1:
  57. x = x.relu()
  58. x = F.dropout(x, p=0.5, training=self.training)
  59. return F.log_softmax(x, dim=-1)
  60. def test(model, data, mask):
  61. model.eval()
  62. pred = model(data)[mask].max(1)[1]
  63. acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
  64. return acc
  65. def train(model, data, args):
  66. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
  67. parameters = model.state_dict()
  68. best_acc = 0.
  69. for epoch in range(args.epoch):
  70. model.train()
  71. optimizer.zero_grad()
  72. output = model(data)
  73. loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
  74. loss.backward()
  75. optimizer.step()
  76. val_acc = test(model, data, data.val_mask)
  77. if val_acc > best_acc:
  78. best_acc = val_acc
  79. parameters = pickle.dumps(model.state_dict())
  80. model.load_state_dict(pickle.loads(parameters))
  81. return model
  82. if __name__ == '__main__':
  83. import argparse
  84. parser = argparse.ArgumentParser('pyg model')
  85. parser.add_argument('--device', type=str, default='cuda')
  86. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  87. parser.add_argument('--repeat', type=int, default=50)
  88. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  89. parser.add_argument('--lr', type=float, default=0.01)
  90. parser.add_argument('--weight_decay', type=float, default=0.0)
  91. parser.add_argument('--epoch', type=int, default=200)
  92. args = parser.parse_args()
  93. # seed = 100
  94. dataset = Planetoid(os.path.expanduser('~/.cache-autogl'), args.dataset, transform=T.NormalizeFeatures())
  95. data = dataset[0].to(args.device)
  96. accs = []
  97. for seed in tqdm(range(args.repeat)):
  98. np.random.seed(seed)
  99. torch.manual_seed(seed)
  100. if args.model == 'gat':
  101. model = GAT(dataset.num_node_features, dataset.num_classes)
  102. elif args.model == 'gcn':
  103. model = GCN(dataset.num_node_features, dataset.num_classes)
  104. elif args.model == 'sage':
  105. model = SAGE(dataset.num_node_features, 64, 2, dataset.num_classes)
  106. model.to(args.device)
  107. train(model, data, args)
  108. acc = test(model, data, data.test_mask)
  109. accs.append(acc)
  110. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))