You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 4.2 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
"""
Performance check of AutoGL model + DGL (trainer + dataset).

Trains an AutoGL GAT, GCN or GraphSAGE model (``--model``) on the Cora,
CiteSeer or PubMed citation dataset (``--dataset``) once per seed in
``range(--repeat)`` and prints the test accuracy as ``mean ~ std``.
"""
import os
import numpy as np
from tqdm import tqdm
import dgl
# Select the DGL backend for AutoGL. This must stay above the ``autogl``
# imports below (presumably the backend is read at import time — confirm).
os.environ["AUTOGL_BACKEND"] = "dgl"
import sys
# Relative entries in sys.path resolve against the current working directory,
# so this only finds the AutoGL checkout (presumably the repository root) when
# the script is launched from its own directory — verify the depth.
sys.path.append("../../../../")
import torch
import torch.nn.functional as F
from dgl.data import CoraGraphDataset, PubmedGraphDataset, CiteseerGraphDataset
from autogl.module.model.dgl import AutoGCN, AutoGAT, AutoSAGE
from autogl.solver.utils import set_seed
import logging
# Only show log records at ERROR level and above.
logging.basicConfig(level=logging.ERROR)
  18. def test(model, graph, mask, label):
  19. model.eval()
  20. pred = model(graph)[mask].max(1)[1]
  21. acc = pred.eq(label[mask]).sum().item() / mask.sum().item()
  22. return acc
  23. def train(model, graph, args, label, train_mask, val_mask):
  24. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
  25. parameters = model.state_dict()
  26. best_acc = 0.
  27. for epoch in range(args.epoch):
  28. model.train()
  29. optimizer.zero_grad()
  30. output = model(graph)
  31. loss = F.nll_loss(output[train_mask], label[train_mask])
  32. loss.backward()
  33. optimizer.step()
  34. val_acc = test(model, graph, val_mask, label)
  35. if val_acc > best_acc:
  36. best_acc = val_acc
  37. parameters = model.state_dict()
  38. model.load_state_dict(parameters)
  39. return model
  40. if __name__ == '__main__':
  41. import argparse
  42. parser = argparse.ArgumentParser('dgl model')
  43. parser.add_argument('--device', type=str, default='cuda')
  44. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  45. parser.add_argument('--repeat', type=int, default=50)
  46. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  47. parser.add_argument('--lr', type=float, default=0.01)
  48. parser.add_argument('--weight_decay', type=float, default=0.0)
  49. parser.add_argument('--epoch', type=int, default=200)
  50. args = parser.parse_args()
  51. # seed = 100
  52. if args.dataset == 'Cora':
  53. dataset = CoraGraphDataset()
  54. elif args.dataset == 'CiteSeer':
  55. dataset = CiteseerGraphDataset()
  56. elif args.dataset == 'PubMed':
  57. dataset = PubmedGraphDataset()
  58. graph = dataset[0].to(args.device)
  59. # graph = dgl.remove_self_loop(graph)
  60. # graph = dgl.add_self_loop(graph)
  61. label = graph.ndata['label']
  62. train_mask = graph.ndata['train_mask']
  63. val_mask = graph.ndata['val_mask']
  64. test_mask = graph.ndata['test_mask']
  65. num_features = graph.ndata['feat'].size(1)
  66. num_classes = dataset.num_classes
  67. accs = []
  68. for seed in tqdm(range(args.repeat)):
  69. set_seed(seed)
  70. if args.model == 'gat':
  71. model = AutoGAT(
  72. num_features=num_features,
  73. num_classes=num_classes,
  74. device=args.device,
  75. init=False
  76. ).from_hyper_parameter({
  77. # hp from model
  78. "num_layers": 2,
  79. "hidden": [8],
  80. "heads": 8,
  81. "feat_drop": 0.6,
  82. "dropout": 0.6,
  83. "act": "relu",
  84. }).model
  85. elif args.model == 'gcn':
  86. model = AutoGCN(
  87. num_features=num_features,
  88. num_classes=num_classes,
  89. device=args.device,
  90. init=False
  91. ).from_hyper_parameter({
  92. "num_layers": 2,
  93. "hidden": [16],
  94. "dropout": 0.5,
  95. "act": "relu"
  96. }).model
  97. elif args.model == 'sage':
  98. model = AutoSAGE(
  99. num_features=num_features,
  100. num_classes=num_classes,
  101. device=args.device,
  102. init=False
  103. ).from_hyper_parameter({
  104. "num_layers": 2,
  105. "hidden": [64],
  106. "dropout": 0.5,
  107. "act": "relu",
  108. "agg": "gcn",
  109. }).model
  110. model.to(args.device)
  111. train(model, graph, args, label, train_mask, val_mask)
  112. acc = test(model, graph, test_mask, label)
  113. accs.append(acc)
  114. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))