You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

model.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. """
  2. Performance check of AutoGL model + PYG (trainer + dataset)
  3. """
  4. import os
  5. import random
  6. import numpy as np
  7. from tqdm import tqdm
  8. os.environ["AUTOGL_BACKEND"] = "pyg"
  9. import torch
  10. import torch.nn.functional as F
  11. from torch_geometric.datasets import TUDataset
  12. from torch_geometric.data import DataLoader
  13. from autogl.module.model.pyg import AutoGIN, AutoTopkpool
  14. from autogl.datasets import utils
  15. from autogl.solver.utils import set_seed
  16. import logging
  17. logging.basicConfig(level=logging.ERROR)
  18. def test(model, loader, args):
  19. model.eval()
  20. correct = 0
  21. for data in loader:
  22. data = data.to(args.device)
  23. output = model(data)
  24. pred = output.max(dim=1)[1]
  25. correct += pred.eq(data.y).sum().item()
  26. return correct / len(loader.dataset)
  27. def train(model, train_loader, val_loader, args):
  28. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
  29. parameters = model.state_dict()
  30. best_acc = 0.
  31. for epoch in range(args.epoch):
  32. model.train()
  33. for data in train_loader:
  34. data = data.to(args.device)
  35. optimizer.zero_grad()
  36. output = model(data)
  37. loss = F.nll_loss(output, data.y)
  38. loss.backward()
  39. optimizer.step()
  40. val_acc = test(model, val_loader, args)
  41. if val_acc > best_acc:
  42. best_acc = val_acc
  43. parameters = model.state_dict()
  44. model.load_state_dict(parameters)
  45. return model
  46. if __name__ == '__main__':
  47. import argparse
  48. parser = argparse.ArgumentParser('pyg trainer')
  49. parser.add_argument('--device', type=str, default='cuda')
  50. parser.add_argument('--dataset', type=str, choices=['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1', 'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K'], default='MUTAG')
  51. parser.add_argument('--dataset_seed', type=int, default=2021)
  52. parser.add_argument('--batch_size', type=int, default=32)
  53. parser.add_argument('--repeat', type=int, default=50)
  54. parser.add_argument('--model', type=str, choices=['gin', 'topkpool'], default='gin')
  55. parser.add_argument('--lr', type=float, default=0.0001)
  56. parser.add_argument('--epoch', type=int, default=100)
  57. args = parser.parse_args()
  58. # seed = 100
  59. dataset = TUDataset(os.path.expanduser('~/.pyg'), args.dataset)
  60. # 1. split dataset [fix split]
  61. dataids = list(range(len(dataset)))
  62. random.seed(args.dataset_seed)
  63. random.shuffle(dataids)
  64. fold = int(len(dataset) * 0.1)
  65. train_index = dataids[:fold * 8]
  66. val_index = dataids[fold * 8: fold * 9]
  67. test_index = dataids[fold * 9: ]
  68. dataset.train_index = train_index
  69. dataset.val_index = val_index
  70. dataset.test_index = test_index
  71. dataset.train_split = dataset[dataset.train_index]
  72. dataset.val_split = dataset[dataset.val_index]
  73. dataset.test_split = dataset[dataset.test_index]
  74. labels = np.array([data.y.item() for data in dataset.test_split])
  75. train_loader = DataLoader(dataset.train_split, batch_size=args.batch_size)
  76. val_loader = DataLoader(dataset.val_split, batch_size=args.batch_size)
  77. test_loader = DataLoader(dataset.test_split, batch_size=args.batch_size)
  78. accs = []
  79. for seed in tqdm(range(args.repeat)):
  80. set_seed(seed)
  81. if args.model == 'gin':
  82. model = AutoGIN(
  83. num_features=dataset.num_node_features,
  84. num_classes=dataset.num_classes,
  85. num_graph_features=0,
  86. init=False
  87. ).from_hyper_parameter({
  88. # hp from model
  89. "num_layers": 5,
  90. "hidden": [64,64,64,64],
  91. "dropout": 0.5,
  92. "act": "relu",
  93. "eps": "False",
  94. "mlp_layers": 2,
  95. "neighbor_pooling_type": "sum",
  96. "graph_pooling_type": "sum"
  97. }).model
  98. elif args.model == 'topkpool':
  99. model = AutoTopkpool(
  100. num_features=dataset.num_node_features,
  101. num_classes=dataset.num_classes,
  102. num_graph_features=0,
  103. init=False
  104. ).from_hyper_parameter({
  105. "ratio": 0.8,
  106. "dropout": 0.5,
  107. "act": "relu"
  108. }).model
  109. model.to(args.device)
  110. train(model, train_loader, val_loader, args)
  111. acc = test(model, test_loader, args)
  112. accs.append(acc)
  113. print('{:.4f} ~ {:.4f}'.format(np.mean(accs), np.std(accs)))