You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. """
  2. Performance check of AutoGL model + PYG (trainer + dataset)
  3. """
  4. import os
  5. import random
  6. import numpy as np
  7. from tqdm import tqdm
  8. os.environ["AUTOGL_BACKEND"] = "pyg"
  9. import torch
  10. import torch.nn.functional as F
  11. from torch_geometric.datasets import Planetoid
  12. import torch_geometric.transforms as T
  13. from autogl.module.model.pyg import AutoGCN, AutoGAT, AutoSAGE
  14. from autogl.datasets import utils
  15. from autogl.solver.utils import set_seed
  16. import logging
  17. logging.basicConfig(level=logging.ERROR)
  18. def test(model, data, mask):
  19. model.eval()
  20. if hasattr(model, 'cls_forward'):
  21. out = model.cls_forward(data)[mask]
  22. else:
  23. out = model(data)[mask]
  24. pred = out.max(1)[1]
  25. acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
  26. return acc
  27. def train(model, data, args):
  28. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
  29. parameters = model.state_dict()
  30. best_acc = 0.
  31. for epoch in range(args.epoch):
  32. model.train()
  33. optimizer.zero_grad()
  34. if hasattr(model, 'cls_forward'):
  35. output = model.cls_forward(data)
  36. else:
  37. output = model(data)
  38. loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
  39. loss.backward()
  40. optimizer.step()
  41. val_acc = test(model, data, data.val_mask)
  42. if val_acc > best_acc:
  43. best_acc = val_acc
  44. parameters = model.state_dict()
  45. model.load_state_dict(parameters)
  46. return model
  47. if __name__ == '__main__':
  48. import argparse
  49. parser = argparse.ArgumentParser('pyg model')
  50. parser.add_argument('--device', type=str, default='cuda')
  51. parser.add_argument('--dataset', type=str, choices=['Cora', 'CiteSeer', 'PubMed'], default='Cora')
  52. parser.add_argument('--repeat', type=int, default=50)
  53. parser.add_argument('--model', type=str, choices=['gat', 'gcn', 'sage'], default='gat')
  54. parser.add_argument('--lr', type=float, default=0.01)
  55. parser.add_argument('--weight_decay', type=float, default=0.0)
  56. parser.add_argument('--epoch', type=int, default=200)
  57. args = parser.parse_args()
  58. # seed = 100
  59. dataset = Planetoid(os.path.expanduser('~/.cache-autogl'), args.dataset, transform=T.NormalizeFeatures())
  60. data = dataset[0].to(args.device)
  61. accs = []
  62. for seed in tqdm(range(args.repeat)):
  63. set_seed(seed)
  64. if args.model == 'gat':
  65. model = AutoGAT(
  66. num_features=dataset.num_node_features,
  67. num_classes=dataset.num_classes,
  68. device=args.device,
  69. init=False
  70. ).from_hyper_parameter({
  71. # hp from model
  72. "num_layers": 2,
  73. "hidden": [8],
  74. "heads": 8,
  75. "dropout": 0.6,
  76. "act": "elu",
  77. }).model
  78. elif args.model == 'gcn':
  79. model = AutoGCN(
  80. num_features=dataset.num_node_features,
  81. num_classes=dataset.num_classes,
  82. device=args.device,
  83. init=False
  84. ).from_hyper_parameter({
  85. "num_layers": 2,
  86. "hidden": [16],
  87. "dropout": 0.5,
  88. "act": "relu"
  89. }).model
  90. elif args.model == 'sage':
  91. model = AutoSAGE(
  92. num_features=dataset.num_node_features,
  93. num_classes=dataset.num_classes,
  94. device=args.device,
  95. init=False
  96. ).from_hyper_parameter({
  97. "num_layers": 2,
  98. "hidden": [64],
  99. "dropout": 0.5,
  100. "act": "relu",
  101. "agg": "mean",
  102. }).model
  103. model.to(args.device)
  104. train(model, data, args)
  105. acc = test(model, data, data.test_mask)
  106. accs.append(acc)
  107. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))