You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nas_bench_graph_example.py 7.2 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
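  """
  Example: run neural architecture search on NAS-Bench-Graph with AutoGL.

  Architectures are scored by looking up their recorded performance in
  NAS-Bench-Graph instead of training them. Usage:
      AUTOGL_BACKEND=pyg python nas_bench_graph_example.py --data cora --algo graphnas
      AUTOGL_BACKEND=dgl python nas_bench_graph_example.py --data cora --algo graphnas
  TODO: make it a unit test file to test all the possible combinations
  """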
  7. import os
  8. import logging
  9. logging.basicConfig(level=logging.INFO)
  10. from autogl.backend import DependentBackend
  11. if DependentBackend.is_dgl():
  12. from autogl.module.model.dgl import BaseAutoModel
  13. from dgl.data import CoraGraphDataset
  14. elif DependentBackend.is_pyg():
  15. from torch_geometric.datasets import Planetoid
  16. from autogl.module.model.pyg import BaseAutoModel
  17. from autogl.datasets import build_dataset_from_name
  18. import torch
  19. from torch import nn
  20. import torch.nn.functional as F
  21. #from autogl.module.nas.algorithm.agnn_rl import AGNNRL
  22. from autogl.module.nas.backend import bk_feat, bk_label
  23. from autogl.module.nas.algorithm import Darts, RL, GraphNasRL, Enas, RandomSearch,Spos
  24. from autogl.module.nas.estimator import BaseEstimator
  25. from autogl.module.train.evaluation import Acc
  26. import numpy as np
  27. from autogl.solver.utils import set_seed
  28. from autogl.module.nas.space import BaseSpace
  29. import typing as _typ
  30. from nas_bench_graph import light_read, gnn_list, gnn_list_proteins, Arch
  31. import pandas as pd
  32. import argparse
  33. import os
  34. import os.path as osp
  35. # Define the search space in NAS-bench-graph
  class StrModule(nn.Module):
      """Parameter-free module whose only job is to carry a label.

      Calling the module ignores every argument and hands back the value it
      was constructed with. That value is also exposed as ``name``. The search
      space uses these as layer-choice candidates, so a selected operation can
      be identified by its name.
      """

      def __init__(self, lambd):
          super().__init__()
          self.name = lambd

      def forward(self, *args, **kwargs):
          # Inputs are deliberately unused; the label itself is the "output".
          label = self.name
          return label

      def __repr__(self):
          cls_name = type(self).__name__
          return f"{cls_name}({self.name})"
  class BenchSpace(BaseSpace):
      """Search space shaped like a NAS-Bench-Graph cell with four nodes.

      For each node ``i`` the space holds:

      - an input choice ``in{i}`` over ``i + 1`` candidates (presumably the
        cell input plus the preceding nodes; TODO confirm);
      - an operation choice ``op{i}`` whose candidates are ``StrModule`` labels
        taken from the NAS-Bench-Graph operation list.

      ``forward`` does not build a network. It hashes the selected architecture
      and returns the performance recorded in the supplied benchmark table.
      """

      # Number of nodes per architecture. ``forward`` passes one link and one
      # operation per node to ``Arch``.
      _NUM_NODES = 4

      def __init__(
          self,
          hidden_dim: _typ.Optional[int] = 64,
          layer_number: _typ.Optional[int] = 2,
          dropout: _typ.Optional[float] = 0.9,
          input_dim: _typ.Optional[int] = None,
          output_dim: _typ.Optional[int] = None,
          ops_type=0,
      ):
          """Store the configuration; choices are created in ``instantiate``.

          Args:
              hidden_dim, layer_number, dropout, input_dim, output_dim: stored on
                  the instance. This class's ``forward`` does not read them.
              ops_type: index into ``[gnn_list, gnn_list_proteins]`` that selects
                  the operation candidates (0 = ``gnn_list``, 1 = proteins list).
          """
          super().__init__()
          self.layer_number = layer_number
          self.hidden_dim = hidden_dim
          self.input_dim = input_dim
          self.output_dim = output_dim
          self.dropout = dropout
          self.ops_type = ops_type

      def instantiate(
          self,
          hidden_dim: _typ.Optional[int] = None,
          layer_number: _typ.Optional[int] = None,
          dropout: _typ.Optional[float] = None,
          input_dim: _typ.Optional[int] = None,
          output_dim: _typ.Optional[int] = None,
          ops_type=None,
      ):
          """Apply overrides and create the input/operation choices.

          Any argument left as ``None`` keeps the current value. Overrides are
          tested with ``is not None`` rather than ``or``, so falsy values such as
          ``dropout=0.0`` or ``ops_type=0`` are honoured instead of being
          silently replaced by the old value.
          """
          super().instantiate()
          if dropout is not None:
              self.dropout = dropout
          if hidden_dim is not None:
              self.hidden_dim = hidden_dim
          if layer_number is not None:
              self.layer_number = layer_number
          if input_dim is not None:
              self.input_dim = input_dim
          if output_dim is not None:
              self.output_dim = output_dim
          if ops_type is not None:
              self.ops_type = ops_type
          # A bool ops_type (e.g. ``data_name == 'proteins'``) indexes as 0/1.
          self.ops = [gnn_list, gnn_list_proteins][self.ops_type]
          for layer in range(self._NUM_NODES):
              setattr(
                  self,
                  f"in{layer}",
                  self.setInputChoice(
                      layer,
                      n_candidates=layer + 1,
                      n_chosen=1,
                      return_mask=False,
                      key=f"in{layer}",
                  ),
              )
              setattr(
                  self,
                  f"op{layer}",
                  self.setLayerChoice(
                      layer, [StrModule(op) for op in self.ops], key=f"op{layer}"
                  ),
              )
          # NOTE(review): presumably gives the space at least one parameter so
          # the search algorithm's optimiser/device handling works; TODO confirm.
          self.dummy = nn.Linear(1, 1)

      def forward(self, bench):
          """Return the benchmark performance of the selected architecture.

          Args:
              bench: benchmark table indexed by architecture hash (in this file
                  it is the object returned by ``light_read``).

          Returns:
              ``bench[hash]['perf']`` for the selected architecture, or ``0``
              when ``valid_hash`` yields ``88888``. That value presumably marks
              an invalid architecture.
          """
          links = [getattr(self, f"in{i}").selected for i in range(self._NUM_NODES)]
          ops = [getattr(self, f"op{i}").name for i in range(self._NUM_NODES)]
          arch_hash = Arch(links, ops).valid_hash()
          # The sentinel is checked both as str and as int.
          if arch_hash in ("88888", 88888):
              return 0
          return bench[arch_hash]["perf"]

      def parse_model(self, selection, device) -> BaseAutoModel:
          """Return a wrapped copy of this space fixed to ``selection``.

          ``device`` is accepted for interface compatibility and is not used.
          """
          return self.wrap().fix(selection)
  # Estimator that reads architecture performance directly from NAS-Bench-Graph instead of training the model
  class BenchEstimator(BaseEstimator):
      """Estimator that scores an architecture by benchmark lookup, not training.

      Args:
          data_name: dataset name passed to ``light_read`` to load the
              NAS-Bench-Graph table.
          loss_f: loss name forwarded to ``BaseEstimator``.
          evaluation: list of evaluation metrics. ``None`` (the default) means
              ``[Acc()]``.
      """

      def __init__(self, data_name, loss_f="nll_loss", evaluation=None):
          # Build the default per instance. A ``[Acc()]`` default would be
          # created once at definition time and shared by all estimators.
          if evaluation is None:
              evaluation = [Acc()]
          super().__init__(loss_f, evaluation)
          self.evaluation = evaluation
          self.bench = light_read(data_name)

      def infer(self, model: BaseSpace, dataset, mask="train"):
          """Look up ``model``'s performance in the loaded benchmark table.

          ``dataset`` and ``mask`` are ignored; nothing is trained or evaluated.

          Returns:
              ``([perf], 0)``, where ``perf`` is ``model(self.bench)``. The
              second element presumably fills the loss slot expected by
              ``BaseEstimator`` users; TODO confirm.
          """
          perf = model(self.bench)
          return [perf], 0
  101. # Run NAS with NAS-bench-graph
  def run(data_name='cora', algo='graphnas', num_epochs=50, ctrl_steps_aggregate=20, log_dir='./logs/tmp'):
      """Search NAS-Bench-Graph with one NAS algorithm and log the outcome.

      Args:
          data_name: NAS-Bench-Graph dataset name. ``'proteins'`` switches to
              the proteins operation list.
          algo: ``'graphnas'`` or ``'agnn'``.
          num_epochs: passed to the search algorithm.
          ctrl_steps_aggregate: passed to the search algorithm.
          log_dir: directory (created if missing) that receives:

              - ``log.txt``: final performance;
              - ``archs.json``: search history;
              - ``idx.txt``: history indices sorted by decreasing score.

      Returns:
          Benchmark performance of the architecture returned by the search.

      Raises:
          NotImplementedError: if ``algo`` is not a supported algorithm.
      """
      import json

      print("Testing backend: {}".format("dgl" if DependentBackend.is_dgl() else "pyg"))
      # Placeholder dims and no dataset: performance comes from the benchmark
      # table, so no real features or labels are needed.
      di = 2
      do = 2
      dataset = None
      ops_type = data_name == 'proteins'
      space = BenchSpace().cuda()
      space.instantiate(input_dim=di, output_dim=do, ops_type=ops_type)
      esti = BenchEstimator(data_name)

      if algo == 'graphnas':
          searcher = GraphNasRL(num_epochs=num_epochs, ctrl_steps_aggregate=ctrl_steps_aggregate)
      elif algo == 'agnn':
          # Imported lazily because the module-level import is commented out.
          # Previously this branch raised NameError on ``AGNNRL``.
          from autogl.module.nas.algorithm.agnn_rl import AGNNRL
          searcher = AGNNRL(guide_type=1, num_epochs=num_epochs, ctrl_steps_aggregate=ctrl_steps_aggregate)
      else:
          # ``raise`` instead of ``assert False`` so the check survives ``python -O``.
          raise NotImplementedError(f'Not implemented algo {algo}')

      model = searcher.search(space, dataset, esti)
      result = esti.infer(model._model, None)[0][0]

      os.makedirs(log_dir, exist_ok=True)
      with open(osp.join(log_dir, 'log.txt'), 'w') as f:
          f.write(str(result))

      archs = searcher.allhist
      with open(osp.join(log_dir, 'archs.json'), 'w') as f:
          json.dump(archs, f)

      arch_strs = [str(x[1]) for x in archs]
      print(f'number of archs: {len(arch_strs)} ; number of unique archs : {len(set(arch_strs))}')
      # x[0] is the score (accuracy); argsort of the negated scores orders
      # indices from best to worst.
      scores = [-x[0] for x in archs]
      idxs = np.argsort(scores)
      with open(osp.join(log_dir, 'idx.txt'), 'w') as f:
          # Same text as str(idxs), but never elided with "..." for long histories.
          f.write(np.array2string(idxs, threshold=idxs.size))
      return result
  136. # Run NAS with NAS-bench-graph for all provided datasets
  def run_all():
      """Run the search for every (dataset, algorithm) pair.

      Proteins uses a shorter search (8 epochs, 5 controller steps) than the
      other datasets (50 epochs, 10 controller steps).

      Returns:
          A list of ``[data_name, algo, metric]`` rows. Datasets form the outer
          loop and algorithms the inner loop.
      """
      datasets = ['arxiv', 'citeseer', 'computers', 'cora', 'cs',
                  'photo', 'physics', 'proteins', 'pubmed']
      algorithms = ['graphnas', 'agnn']
      rows = []
      for name in datasets:
          epochs, steps = (8, 5) if name == 'proteins' else (50, 10)
          for method in algorithms:
              print(f'data {name} algo {method}')
              rows.append([name, method, run(name, method, epochs, steps)])
      return rows
  if __name__ == "__main__":
      # To run every dataset/algorithm pair instead, use ``run_all()``:
      #   results = run_all()
      #   df = pd.DataFrame(results, columns=['data', 'algo', 'v'])
      #   print(df.pivot_table(values='v', index='algo', columns='data').to_string())
      parser = argparse.ArgumentParser(description='Run NAS on NAS-Bench-Graph.')
      parser.add_argument('--data', type=str, default='cora', help='NAS-Bench-Graph dataset name')
      parser.add_argument('--algo', type=str, default='graphnas', help="search algorithm: 'graphnas' or 'agnn'")
      parser.add_argument('--log_dir', type=str, default='./logs/', help='root directory for run logs')
      args = parser.parse_args()
      dname = args.data
      algo = args.algo
      # One sub-directory per run, e.g. ./logs/cora_graphnas. The previous
      # f'{dname,algo}' formatted a tuple and produced "('cora', 'graphnas')".
      log_dir = os.path.join(args.log_dir, f'{dname}_{algo}')
      if dname == 'proteins':
          # 40 archs in total
          num_epochs, ctrl_steps_aggregate = 8, 5
      else:
          # 500 archs in total
          num_epochs, ctrl_steps_aggregate = 50, 10
      result = run(dname, algo, num_epochs, ctrl_steps_aggregate, log_dir)