diff --git a/examples/test_graph_nas_space.py b/examples/test_graph_nas_space.py new file mode 100644 index 0000000..0d60f3c --- /dev/null +++ b/examples/test_graph_nas_space.py @@ -0,0 +1,73 @@ +import sys +from nni.nas.pytorch.mutables import Mutable +sys.path.append('../') +from torch_geometric.nn import GCNConv +import torch +from autogl.datasets import build_dataset_from_name +from autogl.solver import AutoNodeClassifier +from autogl.module.train import NodeClassificationFullTrainer +from autogl.module.nas import Darts, OneShotEstimator +from autogl.module.nas.space.graph_nas import * +from autogl.module.train import Acc +from autogl.module.nas.algorithm.enas import Enas +from autogl.module.nas.algorithm.rl import * +from autogl.module.nas.estimator.one_shot import TrainEstimator +import logging +import numpy as np +from tqdm import tqdm +if __name__ == '__main__': + logging.getLogger().setLevel(logging.WARNING) + dataset = build_dataset_from_name('cora') + space=GraphNasNodeClassificationSpace(hidden_dim=16,search_act_con=True,layer_number=2) + space.instantiate(input_dim=dataset[0].x.shape[1], + output_dim=dataset.num_classes,) + estim=TrainEstimator() + # solver.fit(dataset) + # solver.get_leaderboard().show() + # out = solver.predict_proba() + + # print('acc on cora', Acc.evaluate(out, dataset[0].y[dataset[0].test_mask].detach().numpy())) + class Tmp: + def __init__(self,space): + self.model = space + self.nas_modules = [] + k2o = get_module_order(self.model) + replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules) + replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules) + self.nas_modules = sort_replaced_module(k2o, self.nas_modules) + + t=Tmp(space) + print(t.nas_modules) + nm=t.nas_modules + selection_range={} + for k,v in nm: + selection_range[k]=len(v) + ks=list(selection_range.keys()) + selections=[] + def dfs(selection,d): + if d>=len(ks): + selections.append(selection.copy()) + return + k=ks[d] + r=selection_range[k] + for i in range(r): + selection[k]=i + dfs(selection,d+1) + dfs({},0) + print(f'#selections {len(selections)}') + device=torch.device('cuda:0') + accs=[] + from datetime import datetime + timestamp=datetime.now().strftime('%m%d-%H-%M-%S') + log=open(f'acclog{timestamp}.txt','w') + with tqdm(selections) as bar: + for selection in bar: + arch=space.export(selection,device) + m,l=estim.infer(arch,dataset,'test') + bar.set_postfix(m=m,l=l.item()) + log.write(f'{arch}\n{selection}\n{m},{l}\n') + log.flush() + accs.append(m) + + np.save(f'space_acc{timestamp}',np.array(accs)) + print(f'max acc {np.max(accs)}') \ No newline at end of file