You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 4.8 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. """
  2. Performance check of DGL model + trainer + dataset
  3. """
  4. import numpy as np
  5. from tqdm import tqdm
  6. import pickle
  7. import torch
  8. import torch.nn.functional as F
  9. from dgl.data import CoraGraphDataset, PubmedGraphDataset, CiteseerGraphDataset
  10. from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv
  11. import logging
  12. logging.basicConfig(level=logging.ERROR)
  13. class GCN(torch.nn.Module):
  14. def __init__(self, num_features, num_classes):
  15. super(GCN, self).__init__()
  16. self.conv1 = GraphConv(num_features, 16)
  17. self.conv2 = GraphConv(16, num_classes)
  18. def forward(self, graph):
  19. features = graph.ndata['feat']
  20. features = F.relu(self.conv1(graph, features))
  21. features = F.dropout(features, training=self.training)
  22. features = self.conv2(graph, features)
  23. return F.log_softmax(features, dim=-1)
  24. class GAT(torch.nn.Module):
  25. def __init__(self, num_features, num_classes):
  26. super(GAT, self).__init__()
  27. self.conv1 = GATConv(num_features, 8, 8, feat_drop=.6, attn_drop=.6, activation=F.relu)
  28. self.conv2 = GATConv(8 * 8, num_classes, 1, feat_drop=.6, attn_drop=.6)
  29. def forward(self, graph):
  30. features = graph.ndata['feat']
  31. features = self.conv1(graph, features).flatten(1)
  32. features = self.conv2(graph, features).mean(1)
  33. return F.log_softmax(features, dim=-1)
  34. class SAGE(torch.nn.Module):
  35. def __init__(self, num_features, hidden_channels, num_layers, num_classes):
  36. super(SAGE, self).__init__()
  37. self.num_layers = num_layers
  38. self.convs = torch.nn.ModuleList()
  39. for i in range(num_layers):
  40. inc = outc = hidden_channels
  41. if i == 0:
  42. inc = num_features
  43. if i == num_layers - 1:
  44. outc = num_classes
  45. self.convs.append(SAGEConv(inc, outc, "gcn"))
  46. self.dropout = torch.nn.Dropout()
  47. def forward(self, graph):
  48. h = graph.ndata['feat']
  49. h = self.dropout(h)
  50. for i, conv in enumerate(self.convs):
  51. h = conv(graph, h)
  52. if i != self.num_layers - 1:
  53. h = h.relu()
  54. h = self.dropout(h)
  55. return F.log_softmax(h, dim=-1)
  56. def test(model, graph, mask, label):
  57. model.eval()
  58. pred = model(graph)[mask].max(1)[1]
  59. acc = pred.eq(label[mask]).sum().item() / mask.sum().item()
  60. return acc
  61. def train(model, graph, args, label, train_mask, val_mask):
  62. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
  63. parameters = model.state_dict()
  64. best_acc = 0.
  65. for epoch in range(args.epoch):
  66. model.train()
  67. optimizer.zero_grad()
  68. output = model(graph)
  69. loss = F.nll_loss(output[train_mask], label[train_mask])
  70. loss.backward()
  71. optimizer.step()
  72. val_acc = test(model, graph, val_mask, label)
  73. if val_acc > best_acc:
  74. best_acc = val_acc
  75. parameters = pickle.dumps(model.state_dict())
  76. model.load_state_dict(pickle.loads(parameters))
  77. return model
  78. if __name__ == '__main__':
  79. import argparse
  80. parser = argparse.ArgumentParser('dgl')
  81. parser.add_argument('--device', type=str, default='cuda')
  82. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  83. parser.add_argument('--repeat', type=int, default=50)
  84. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  85. parser.add_argument('--lr', type=float, default=0.01)
  86. parser.add_argument('--weight_decay', type=float, default=0.0)
  87. parser.add_argument('--epoch', type=int, default=200)
  88. args = parser.parse_args()
  89. # seed = 100
  90. if args.dataset == 'Cora':
  91. dataset = CoraGraphDataset()
  92. elif args.dataset == 'CiteSeer':
  93. dataset = CiteseerGraphDataset()
  94. elif args.dataset == 'PubMed':
  95. dataset = PubmedGraphDataset()
  96. graph = dataset[0].to(args.device)
  97. label = graph.ndata['label']
  98. train_mask = graph.ndata['train_mask']
  99. val_mask = graph.ndata['val_mask']
  100. test_mask = graph.ndata['test_mask']
  101. accs = []
  102. for seed in tqdm(range(args.repeat)):
  103. np.random.seed(seed)
  104. torch.manual_seed(seed)
  105. if args.model == 'gat':
  106. model = GAT(graph.ndata['feat'].size(1), dataset.num_classes)
  107. elif args.model == 'gcn':
  108. model = GCN(graph.ndata['feat'].size(1), dataset.num_classes)
  109. elif args.model == 'sage':
  110. model = SAGE(graph.ndata['feat'].size(1), 64, 2, dataset.num_classes)
  111. model.to(args.device)
  112. train(model, graph, args, label, train_mask, val_mask)
  113. acc = test(model, graph, test_mask, label)
  114. accs.append(acc)
  115. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))