You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. """
  2. Performance check of DGL original dataset, model, trainer setting
  3. Borrowed from DGL official examples: https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin
  4. TopkPool is not supported currently
  5. """
  6. # from dgl.dataloading.pytorch.dataloader import GraphDataLoader
  7. import pickle
  8. from dgl.dataloading import GraphDataLoader
  9. import numpy as np
  10. from tqdm import tqdm
  11. import random
  12. import torch
  13. import torch.nn as nn
  14. import torch.optim as optim
  15. from dgl.data import GINDataset
  16. import torch
  17. import torch.nn as nn
  18. import torch.nn.functional as F
  19. from dgl.nn.pytorch.conv import GINConv
  20. from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
  21. def set_seed(seed=None):
  22. """
  23. Set seed of whole process
  24. """
  25. if seed is None:
  26. seed = random.randint(0, 5000)
  27. random.seed(seed)
  28. np.random.seed(seed)
  29. torch.manual_seed(seed)
  30. if torch.cuda.is_available():
  31. torch.cuda.manual_seed_all(seed)
  32. torch.backends.cudnn.deterministic = True
  33. torch.backends.cudnn.benchmark = False
  34. class DatasetAbstraction():
  35. def __init__(self, graphs, labels):
  36. for g in graphs:
  37. g.ndata['feat'] = g.ndata['attr']
  38. self.graphs, self.labels = [], []
  39. for g, l in zip(graphs, labels):
  40. self.graphs.append(g)
  41. self.labels.append(l)
  42. self.gclasses = max(self.labels).item() + 1
  43. self.graph = self.graphs
  44. def __len__(self):
  45. return len(self.graphs)
  46. def __getitem__(self, idx):
  47. if isinstance(idx, int):
  48. return self.graphs[idx], self.labels[idx]
  49. elif isinstance(idx, torch.BoolTensor):
  50. idx = [i for i in range(len(idx)) if idx[i]]
  51. elif isinstance(idx, torch.Tensor) and idx.unique()[0].sum().item() == 1:
  52. idx = [i for i in range(len(idx)) if idx[i]]
  53. return DatasetAbstraction([self.graphs[i] for i in idx], [self.labels[i] for i in idx])
  54. class ApplyNodeFunc(nn.Module):
  55. """Update the node feature hv with MLP, BN and ReLU."""
  56. def __init__(self, mlp):
  57. super(ApplyNodeFunc, self).__init__()
  58. self.mlp = mlp
  59. self.bn = nn.BatchNorm1d(self.mlp.output_dim)
  60. def forward(self, h):
  61. h = self.mlp(h)
  62. h = self.bn(h)
  63. h = F.relu(h)
  64. return h
  65. class MLP(nn.Module):
  66. """MLP with linear output"""
  67. def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
  68. """MLP layers construction
  69. Paramters
  70. ---------
  71. num_layers: int
  72. The number of linear layers
  73. input_dim: int
  74. The dimensionality of input features
  75. hidden_dim: int
  76. The dimensionality of hidden units at ALL layers
  77. output_dim: int
  78. The number of classes for prediction
  79. """
  80. super(MLP, self).__init__()
  81. self.linear_or_not = True # default is linear model
  82. self.num_layers = num_layers
  83. self.output_dim = output_dim
  84. if num_layers < 1:
  85. raise ValueError("number of layers should be positive!")
  86. elif num_layers == 1:
  87. # Linear model
  88. self.linear = nn.Linear(input_dim, output_dim)
  89. else:
  90. # Multi-layer model
  91. self.linear_or_not = False
  92. self.linears = torch.nn.ModuleList()
  93. self.batch_norms = torch.nn.ModuleList()
  94. self.linears.append(nn.Linear(input_dim, hidden_dim))
  95. for layer in range(num_layers - 2):
  96. self.linears.append(nn.Linear(hidden_dim, hidden_dim))
  97. self.linears.append(nn.Linear(hidden_dim, output_dim))
  98. for layer in range(num_layers - 1):
  99. self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))
  100. def forward(self, x):
  101. if self.linear_or_not:
  102. # If linear model
  103. return self.linear(x)
  104. else:
  105. # If MLP
  106. h = x
  107. for i in range(self.num_layers - 1):
  108. h = F.relu(self.batch_norms[i](self.linears[i](h)))
  109. return self.linears[-1](h)
  110. class GIN(nn.Module):
  111. """GIN model"""
  112. def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim,
  113. output_dim, final_dropout, learn_eps, graph_pooling_type,
  114. neighbor_pooling_type):
  115. """model parameters setting
  116. Paramters
  117. ---------
  118. num_layers: int
  119. The number of linear layers in the neural network
  120. num_mlp_layers: int
  121. The number of linear layers in mlps
  122. input_dim: int
  123. The dimensionality of input features
  124. hidden_dim: int
  125. The dimensionality of hidden units at ALL layers
  126. output_dim: int
  127. The number of classes for prediction
  128. final_dropout: float
  129. dropout ratio on the final linear layer
  130. learn_eps: boolean
  131. If True, learn epsilon to distinguish center nodes from neighbors
  132. If False, aggregate neighbors and center nodes altogether.
  133. neighbor_pooling_type: str
  134. how to aggregate neighbors (sum, mean, or max)
  135. graph_pooling_type: str
  136. how to aggregate entire nodes in a graph (sum, mean or max)
  137. """
  138. super(GIN, self).__init__()
  139. self.num_layers = num_layers
  140. self.learn_eps = learn_eps
  141. # List of MLPs
  142. self.ginlayers = torch.nn.ModuleList()
  143. self.batch_norms = torch.nn.ModuleList()
  144. for layer in range(self.num_layers - 1):
  145. if layer == 0:
  146. mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)
  147. else:
  148. mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)
  149. self.ginlayers.append(
  150. GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps))
  151. self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
  152. # Linear function for graph poolings of output of each layer
  153. # which maps the output of different layers into a prediction score
  154. self.linears_prediction = torch.nn.ModuleList()
  155. for layer in range(num_layers):
  156. if layer == 0:
  157. self.linears_prediction.append(
  158. nn.Linear(input_dim, output_dim))
  159. else:
  160. self.linears_prediction.append(
  161. nn.Linear(hidden_dim, output_dim))
  162. self.drop = nn.Dropout(final_dropout)
  163. if graph_pooling_type == 'sum':
  164. self.pool = SumPooling()
  165. elif graph_pooling_type == 'mean':
  166. self.pool = AvgPooling()
  167. elif graph_pooling_type == 'max':
  168. self.pool = MaxPooling()
  169. else:
  170. raise NotImplementedError
  171. def forward(self, g, h):
  172. # list of hidden representation at each layer (including input)
  173. hidden_rep = [h]
  174. for i in range(self.num_layers - 1):
  175. h = self.ginlayers[i](g, h)
  176. h = self.batch_norms[i](h)
  177. h = F.relu(h)
  178. hidden_rep.append(h)
  179. score_over_layer = 0
  180. # perform pooling over all nodes in each graph in every layer
  181. for i, h in enumerate(hidden_rep):
  182. pooled_h = self.pool(g, h)
  183. score_over_layer += self.drop(self.linears_prediction[i](pooled_h))
  184. return score_over_layer
  185. def train(net, trainloader, validloader, optimizer, criterion, epoch, device):
  186. best_model = pickle.dumps(net.state_dict())
  187. best_acc = 0.
  188. for e in range(epoch):
  189. net.train()
  190. for graphs, labels in trainloader:
  191. labels = labels.to(device)
  192. graphs = graphs.to(device)
  193. feat = graphs.ndata.pop('attr')
  194. outputs = net(graphs, feat)
  195. loss = criterion(outputs, labels)
  196. # backprop
  197. optimizer.zero_grad()
  198. loss.backward()
  199. optimizer.step()
  200. gt = []
  201. pr = []
  202. net.eval()
  203. for graphs, labels in validloader:
  204. labels = labels.to(device)
  205. graphs = graphs.to(device)
  206. gt.append(labels)
  207. feat = graphs.ndata.pop('attr')
  208. outputs = net(graphs, feat)
  209. pr.append(outputs.argmax(1))
  210. gt = torch.cat(gt, dim=0)
  211. pr = torch.cat(pr, dim=0)
  212. acc = (gt == pr).float().mean().item()
  213. if acc > best_acc:
  214. best_acc = acc
  215. best_model = pickle.dumps(net.state_dict())
  216. net.load_state_dict(pickle.loads(best_model))
  217. return net
  218. def eval_net(net, dataloader, device):
  219. net.eval()
  220. total = 0
  221. total_correct = 0
  222. for data in dataloader:
  223. graphs, labels = data
  224. graphs = graphs.to(device)
  225. labels = labels.to(device)
  226. feat = graphs.ndata.pop('attr')
  227. total += len(labels)
  228. outputs = net(graphs, feat)
  229. _, predicted = torch.max(outputs.data, 1)
  230. total_correct += (predicted == labels.data).sum().item()
  231. acc = 1.0 * total_correct / total
  232. net.train()
  233. return acc
  234. def main():
  235. import argparse
  236. parser = argparse.ArgumentParser()
  237. parser.add_argument("--repeat", type=int, default=10)
  238. parser.add_argument('--dataset', type=str, choices=['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1', 'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K'], default='MUTAG')
  239. args = parser.parse_args()
  240. device = torch.device('cuda')
  241. dataset_ = GINDataset(args.dataset, False)
  242. dataset = DatasetAbstraction([g[0] for g in dataset_], [g[1] for g in dataset_])
  243. # 1. split dataset [fix split]
  244. dataids = list(range(len(dataset)))
  245. random.seed(2021)
  246. random.shuffle(dataids)
  247. fold = int(len(dataset) * 0.1)
  248. train_dataset = dataset[dataids[:fold * 8]]
  249. val_dataset = dataset[dataids[fold * 8: fold * 9]]
  250. test_dataset = dataset[dataids[fold * 9: ]]
  251. trainloader = GraphDataLoader(train_dataset, batch_size=32, shuffle=True)
  252. valloader = GraphDataLoader(val_dataset, batch_size=32, shuffle=False)
  253. testloader = GraphDataLoader(test_dataset, batch_size=32, shuffle=False)
  254. accs = []
  255. for seed in tqdm(range(args.repeat)):
  256. # set up seeds, args.seed supported
  257. set_seed(seed)
  258. model = GIN(
  259. 5, 2, dataset_.dim_nfeats, 64, dataset_.gclasses, 0.5, False,
  260. "sum", "sum").to(device)
  261. criterion = nn.CrossEntropyLoss() # defaul reduce is true
  262. optimizer = optim.Adam(model.parameters(), lr=0.0001)
  263. model = train(model, trainloader, valloader, optimizer, criterion, 100, device)
  264. acc = eval_net(model, testloader, device)
  265. accs.append(acc)
  266. print('{:.2f} ~ {:.2f}'.format(np.mean(accs) * 100, np.std(accs) * 100))
  267. if __name__ == '__main__':
  268. main()