You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

nas_bench_graph_example.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. """
  2. Test file for nas on node classification
  3. AUTOGL_BACKEND=pyg python test/nas/node_classification.py
  4. AUTOGL_BACKEND=dgl python test/nas/node_classification.py
  5. TODO: make it a unit test file to test all the possible combinations
  6. """
  7. import os
  8. import logging
  9. logging.basicConfig(level=logging.INFO)
  10. from autogl.backend import DependentBackend
  11. if DependentBackend.is_dgl():
  12. from autogl.module.model.dgl import BaseAutoModel
  13. from dgl.data import CoraGraphDataset
  14. elif DependentBackend.is_pyg():
  15. from torch_geometric.datasets import Planetoid
  16. from autogl.module.model.pyg import BaseAutoModel
  17. from autogl.datasets import build_dataset_from_name
  18. import torch
  19. import torch.nn.functional as F
  20. #from autogl.module.nas.algorithm.agnn_rl import AGNNRL
  21. from autogl.module.nas.backend import bk_feat, bk_label
  22. from autogl.module.nas.algorithm import Darts, RL, GraphNasRL, Enas, RandomSearch,Spos
  23. from autogl.module.nas.estimator import BaseEstimator
  24. from autogl.module.train.evaluation import Acc
  25. import numpy as np
  26. from autogl.solver.utils import set_seed
  27. from autogl.module.nas.space import BaseSpace
  28. import typing as _typ
  29. from torch import nn
  30. from nas_bench_graph import light_read, gnn_list, gnn_list_proteins, Arch
  31. import pandas as pd
  32. import argparse
  33. import torch
  34. import os
  35. import os.path as osp
  36. # Define the search space in NAS-bench-graph
  37. class StrModule(nn.Module):
  38. def __init__(self, lambd):
  39. super().__init__()
  40. self.name = lambd
  41. def forward(self, *args, **kwargs):
  42. return self.name
  43. def __repr__(self):
  44. return "{}({})".format(self.__class__.__name__, self.name)
class BenchSpace(BaseSpace):
    """AutoGL search space backed by NAS-bench-graph.

    The space has 4 layers; for each layer ``i`` there is an input-link
    choice ``in{i}`` (which earlier output feeds it, ``i + 1`` candidates)
    and an operation choice ``op{i}`` drawn from the NAS-bench-graph
    operation list. ``forward`` does not run a GNN: it looks the selected
    architecture up in a benchmark table and returns the stored
    performance value.
    """

    def __init__(
        self,
        hidden_dim: _typ.Optional[int] = 64,
        layer_number: _typ.Optional[int] = 2,
        dropout: _typ.Optional[float] = 0.9,
        input_dim: _typ.Optional[int] = None,
        output_dim: _typ.Optional[int] = None,
        ops_type = 0
    ):
        # NOTE(review): hidden_dim/layer_number/dropout are stored but not
        # used in this file — presumably consumed by BaseSpace machinery;
        # confirm against the autogl implementation.
        super().__init__()
        self.layer_number = layer_number
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dropout = dropout
        # 0 -> gnn_list, 1 -> gnn_list_proteins (see instantiate()).
        self.ops_type=ops_type

    def instantiate(
        self,
        hidden_dim: _typ.Optional[int] = None,
        layer_number: _typ.Optional[int] = None,
        dropout: _typ.Optional[float] = None,
        input_dim: _typ.Optional[int] = None,
        output_dim: _typ.Optional[int] = None,
        ops_type=None
    ):
        """Build the mutable choices. Each ``None``/falsy argument keeps
        the value stored at construction time."""
        super().instantiate()
        self.dropout = dropout or self.dropout
        self.hidden_dim = hidden_dim or self.hidden_dim
        self.layer_number = layer_number or self.layer_number
        self.input_dim = input_dim or self.input_dim
        self.output_dim = output_dim or self.output_dim
        self.ops_type = ops_type or self.ops_type
        # 'proteins' uses its own operation list in NAS-bench-graph.
        self.ops = [gnn_list,gnn_list_proteins][self.ops_type]
        # One input-link choice and one operation choice per layer; ops are
        # wrapped in StrModule so plain strings fit the layer-choice API.
        for layer in range(4):
            setattr(self,f"in{layer}",self.setInputChoice(layer,n_candidates=layer+1,n_chosen=1,return_mask=False,key=f"in{layer}"))
            setattr(self,f"op{layer}",self.setLayerChoice(layer,list(map(lambda x:StrModule(x),self.ops)),key=f"op{layer}"))
        # Dummy parameter so the module has something for optimizers/device
        # moves — TODO confirm this is its only purpose.
        self.dummy=nn.Linear(1,1)

    def forward(self, bench):
        """Look the currently selected architecture up in ``bench``.

        Returns the recorded 'perf' value, or 0 when the selection hashes
        to the invalid-architecture sentinel "88888".
        """
        lks = [getattr(self, "in" + str(i)).selected for i in range(4)]
        ops = [getattr(self, "op" + str(i)).name for i in range(4)]
        arch = Arch(lks, ops)
        h = arch.valid_hash()
        if h == "88888" or h==88888:
            return 0
        return bench[h]['perf']

    def parse_model(self, selection, device) -> BaseAutoModel:
        """Fix the space to ``selection`` and wrap it as an auto model."""
        return self.wrap().fix(selection)
  93. # Define a new estimator which directly get performance from NAS-bench-graph instead of training the model
  94. class BenchEstimator(BaseEstimator):
  95. def __init__(self, data_name, loss_f="nll_loss", evaluation=[Acc()]):
  96. super().__init__(loss_f, evaluation)
  97. self.evaluation = evaluation
  98. self.bench=light_read(data_name)
  99. def infer(self, model: BaseSpace, dataset, mask="train"):
  100. perf=model(self.bench)
  101. return [perf],0
  102. # Run NAS with NAS-bench-graph
  103. def run(data_name='cora',algo='graphnas',num_epochs=50,ctrl_steps_aggregate=20,log_dir='./logs/tmp'):
  104. print("Testing backend: {}".format("dgl" if DependentBackend.is_dgl() else "pyg"))
  105. if DependentBackend.is_dgl():
  106. from autogl.datasets.utils.conversion._to_dgl_dataset import to_dgl_dataset as convert_dataset
  107. else:
  108. from autogl.datasets.utils.conversion._to_pyg_dataset import to_pyg_dataset as convert_dataset
  109. di=2
  110. do=2
  111. dataset=None
  112. ops_type=data_name=='proteins'
  113. space = BenchSpace().cuda()
  114. space.instantiate(input_dim=di, output_dim=do,ops_type=ops_type)
  115. esti = BenchEstimator(data_name)
  116. if algo=='graphnas':
  117. algo = GraphNasRL(num_epochs=num_epochs,ctrl_steps_aggregate=ctrl_steps_aggregate)
  118. elif algo=='agnn':
  119. algo = AGNNRL(guide_type=1,num_epochs=num_epochs,ctrl_steps_aggregate=ctrl_steps_aggregate)
  120. else:
  121. assert False,f'Not implemented algo {algo}'
  122. model = algo.search(space, dataset, esti)
  123. result=esti.infer(model._model,None)[0][0]
  124. os.makedirs(log_dir,exist_ok=True)
  125. with open(osp.join(log_dir,f'log.txt'),'w') as f:
  126. f.write(str(result))
  127. import json
  128. archs=algo.allhist
  129. json.dump(archs,open(osp.join(log_dir,f'archs.json'),'w'))
  130. arch_strs=[str(x[1]) for x in archs]
  131. print(f'number of archs: {len(arch_strs)} ; number of unique archs : {len(set(arch_strs))}')
  132. scores=[-x[0] for x in archs] # accs
  133. idxs=np.argsort(scores) # increasing order
  134. with open(osp.join(log_dir,f'idx.txt'),'w') as f:
  135. f.write(str(idxs))
  136. return result
  137. # Run NAS with NAS-bench-graph for all provided datasets
  138. def run_all():
  139. data_names='arxiv citeseer computers cora cs photo physics proteins pubmed'.split()
  140. algos='graphnas agnn'.split()
  141. results=[]
  142. for data_name in data_names:
  143. for algo in algos:
  144. print(f'data {data_name} algo {algo}')
  145. # metric=run(data_name,algo,2,2)
  146. if data_name=='proteins':
  147. metric=run(data_name,algo,8,5)
  148. else:
  149. metric=run(data_name,algo,50,10)
  150. results.append([data_name,algo,metric])
  151. return results
  152. if __name__ == "__main__":
  153. # results=run_all()
  154. # df=pd.DataFrame(results,columns='data algo v'.split()).pivot_table(values='v',index='algo',columns='data')
  155. # print(df.to_string())
  156. parser = argparse.ArgumentParser()
  157. parser.add_argument('--data', type=str, default='cora', help='datasets')
  158. parser.add_argument('--algo', type=str, default='graphnas')
  159. parser.add_argument('--log_dir', type=str, default='./logs/')
  160. args = parser.parse_args()
  161. dname=args.data
  162. algo=args.algo
  163. log_dir= os.path.join(args.log_dir,f'{dname,algo}')
  164. if dname=='proteins':
  165. # 40 archs in total
  166. num_epochs=8
  167. ctrl_steps_aggregate=5
  168. else:
  169. # 500 archs in total
  170. num_epochs=50
  171. ctrl_steps_aggregate=10
  172. result=run(dname,algo,num_epochs,ctrl_steps_aggregate,log_dir)