- """
- Test file for NAS on node classification
-
- AUTOGL_BACKEND=pyg python test/nas/node_classification.py
- AUTOGL_BACKEND=dgl python test/nas/node_classification.py
-
- TODO: make it a unit test file to test all the possible combinations
- """
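- # The __main__ block below also exposes CLI arguments for the same runs, e.g.:
- #   AUTOGL_BACKEND=pyg python test/nas/node_classification.py --data cora --algo agnn --log_dir ./logs/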
-
- import os
- import logging
-
- logging.basicConfig(level=logging.INFO)
-
- from autogl.backend import DependentBackend
-
- if DependentBackend.is_dgl():
- from autogl.module.model.dgl import BaseAutoModel
- from dgl.data import CoraGraphDataset
- elif DependentBackend.is_pyg():
- from torch_geometric.datasets import Planetoid
- from autogl.module.model.pyg import BaseAutoModel
- from autogl.datasets import build_dataset_from_name
- import torch
- import torch.nn.functional as F
- from autogl.module.nas.space.single_path import SinglePathNodeClassificationSpace
- from autogl.module.nas.space.graph_nas import GraphNasNodeClassificationSpace
- from autogl.module.nas.space.graph_nas_macro import GraphNasMacroNodeClassificationSpace
- from autogl.module.nas.estimator.one_shot import OneShotEstimator
- from autogl.module.nas.estimator.train_scratch import TrainEstimator
- from autogl.module.nas.algorithm.agnn_rl import AGNNRL
- from autogl.module.nas.space.autoattend import AutoAttendNodeClassificationSpace
- from autogl.module.nas.backend import bk_feat, bk_label
- from autogl.module.nas.algorithm import Darts, RL, GraphNasRL, Enas, RandomSearch, Spos
- import numpy as np
- from autogl.solver.utils import set_seed
-
- set_seed(202106)
- from autogl.module.nas.space import BaseSpace
- import typing as _typ
- from torch import nn
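- # Wraps an operator-name string in an nn.Module so it can be registered as a
- # LayerChoice candidate (see BenchSpace.instantiate below).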
- class StrModule(nn.Module):
- def __init__(self, lambd):
- super().__init__()
- self.name = lambd
-
- def forward(self, *args, **kwargs):
- return self.name
-
- def __repr__(self):
- return "{}({})".format(self.__class__.__name__, self.name)
-
-
- gnn_list_proteins = [
- "gcn", # GCN
- "cheb", # chebnet
- "arma",
-     "fc",  # fully-connected (linear) layer
- "skip" # skip connection
- ]
-
- gnn_list = [
- "gat", # GAT with 2 heads
- "gcn", # GCN
- "gin", # GIN
- "cheb", # chebnet
- "sage", # sage
- "arma",
- "graph",
-     "fc",  # fully-connected (linear) layer
- "skip" # skip connection
- ]
-
-
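- # An architecture is encoded as a 4-operation cell: `ops` holds four names from
- # the lists above and `link` holds four input indices, where link[i] selects
- # which earlier node feeds operation i (0 is the raw input; see the InputChoice
- # construction in BenchSpace.instantiate below).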
- class Arch:
- def __init__(self, lk=None, op=None):
- self.link = lk
- self.ops = op
-
- # def random_arch(self):
- # self.ops = []
- # self.link = random.choice(link_list)
- # for i in self.link:
- # self.ops.append(random.choice(gnn_list))
-
- def hash_arch(self, use_proteins = False):
- lk = self.link
- op = self.ops
- if use_proteins:
- gnn_g = {name: i for i, name in enumerate(gnn_list_proteins)}
- b = len(gnn_list_proteins) + 1
- else:
- gnn_g = {name: i for i, name in enumerate(gnn_list)}
- b = len(gnn_list) + 1
- if lk == [0,0,0,0]:
- lk_hash = 0
- elif lk == [0,0,0,1]:
- lk_hash = 1
- elif lk == [0,0,1,1]:
- lk_hash = 2
- elif lk == [0,0,1,2]:
- lk_hash = 3
- elif lk == [0,0,1,3]:
- lk_hash = 4
- elif lk == [0,1,1,1]:
- lk_hash = 5
- elif lk == [0,1,1,2]:
- lk_hash = 6
- elif lk == [0,1,2,2]:
- lk_hash = 7
- elif lk == [0,1,2,3]:
- lk_hash = 8
-
- for i in op:
- lk_hash = lk_hash * b + gnn_g[i]
- return lk_hash
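-         # Worked example with the non-proteins list (b = 10): lk = [0, 0, 1, 2]
-         # gives lk_hash = 3, and ops = ["gat", "gcn", "gin", "cheb"] (indices
-         # 0..3) accumulate 3 -> 30 -> 301 -> 3012 -> 30123.
-         # Note: lk is assumed to already be one of the canonical patterns above
-         # (see regularize / move_skip_op); otherwise lk_hash would be unbound.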
-
- def regularize(self):
- lk = self.link[:]
- ops = self.ops[:]
- if lk == [0,0,0,2]:
- lk = [0,0,0,1]
- ops = [ops[1], ops[0], ops[2], ops[3]]
- elif lk == [0,0,0,3]:
- lk = [0,0,0,1]
- ops = [ops[2], ops[0], ops[1], ops[3]]
- elif lk == [0,0,1,0]:
- lk = [0,0,0,1]
- ops = [ops[0], ops[1], ops[3], ops[2]]
- elif lk == [0,0,2,0]:
- lk = [0,0,0,1]
- ops = [ops[1], ops[0], ops[3], ops[2]]
- elif lk == [0,0,2,1]:
- lk = [0,0,1,2]
- ops = [ops[1], ops[0], ops[2], ops[3]]
- elif lk == [0,0,2,2]:
- lk = [0,0,1,1]
- ops = [ops[1], ops[0], ops[2], ops[3]]
- elif lk == [0,0,2,3]:
- lk = [0,0,1,3]
- ops = [ops[1], ops[0], ops[2], ops[3]]
- elif lk == [0,1,0,0]:
- lk = [0,0,0,1]
- ops = [ops[0], ops[2], ops[3], ops[1]]
- elif lk == [0,1,0,1]:
- lk = [0,0,1,1]
- ops = [ops[0], ops[2], ops[1], ops[3]]
- elif lk == [0,1,0,2]:
- lk = [0,0,1,3]
- ops = [ops[0], ops[2], ops[1], ops[3]]
- elif lk == [0,1,0,3]:
- lk = [0,0,1,2]
- ops = [ops[0], ops[2], ops[1], ops[3]]
- elif lk == [0,1,1,0]:
- lk = [0,0,1,1]
- ops = [ops[0], ops[3], ops[1], ops[2]]
- elif lk == [0,1,1,3]:
- lk = [0,1,1,2]
- ops = [ops[0], ops[2], ops[1], ops[3]]
- elif lk == [0,1,2,0]:
- lk = [0,0,1,3]
- ops = [ops[0], ops[3], ops[1], ops[2]]
- elif lk == [0,1,2,1]:
- lk = [0,1,1,2]
- ops = [ops[0], ops[1], ops[3], ops[2]]
- return Arch(lk, ops)
-
- def equalpart_sort(self):
- lk = self.link
- op = self.ops
- ops = op[:]
- def part_sort(ids, ops):
- gnn_g = {name: i for i, name in enumerate(gnn_list)}
- opli = [gnn_g[ops[i]] for i in ids]
- opli.sort()
- for posid, opid in zip(ids, opli):
- ops[posid] = gnn_list[opid]
- return ops
-
- def sort0012(ops):
- gnn_g = {name: i for i, name in enumerate(gnn_list)}
-             if gnn_g[op[0]] > gnn_g[op[1]] or (op[0] == op[1] and gnn_g[op[2]] > gnn_g[op[3]]):
- ops = [ops[1], ops[0], ops[3], ops[2]]
- return ops
-
- if lk == [0,0,0,0]:
- ids = [0,1,2,3]
- elif lk == [0,0,0,1]:
- ids = [1,2]
- elif lk == [0,0,1,1]:
- ids = [2,3]
- elif lk == [0,0,1,2]:
- ids = None
- ops = sort0012(ops)
- elif lk == [0,1,1,1]:
- ids = [1,2,3]
- elif lk == [0,1,2,2]:
- ids = [2,3]
- else:
- ids = None
-
- if ids:
- ops = part_sort(ids, ops)
-
- self.ops = ops
-
- def move_skip_op(self):
- link = self.link[:]
- ops = self.ops[:]
- def move_one(k, link, ops):
- ops = [ops[k]] + ops[:k] + ops[k + 1:]
- for i, father in enumerate(link):
- if father == k + 1:
- link[i] = link[k]
- if father <= k:
- link[i] = link[i] + 1
- link = [0] + link[:k] + link[k + 1:]
- return link, ops
-
- def check_dim(k, link, ops):
-             # check whether node k still carries the original input dimension,
-             # i.e. only 'skip' ops lie on the path back to the raw input
- while k > -1:
- if ops[k] != 'skip':
- return False
- k = link[k] - 1
- return True
-
- for i in range(len(link)):
- if ops[i] != 'skip':
- continue
- son = False
- brother = False
- for j, fa in enumerate(link):
- if fa == i + 1:
- son = True
- elif j != i and fa == link[i]:
- brother = True
-             if son or (not brother) or (check_dim(i, link, ops) and not son):
- link, ops = move_one(i, link, ops)
-
- if link == [0,1,2,1]:
- link = [0,1,1,2]
- ops = ops[:2] + [ops[3], ops[2]]
- elif link == [0,1,1,3]:
- link = [0,1,1,2]
- ops = [ops[0], ops[2], ops[1], ops[3]]
-
- #if link not in link_list:
- # print(lk, link)
-
- self.link = link
- self.ops = ops
-
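-     # Canonical hash: regularize the link pattern, relocate skip connections and
-     # sort order-independent branches so that isomorphic architectures share a hash.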
- def valid_hash(self):
- b = self.regularize()
- b.move_skip_op()
- b.equalpart_sort()
- return b.hash_arch()
-
- def check_isomorph(self):
- link, ops = self.link, self.ops
- linkm = link[:]
- opsm = ops[:]
- self.move_skip_op()
- self.equalpart_sort()
- #print(self.link, self.ops)
- return linkm == self.link and opsm == self.ops
-
- import nni
- def map_value(l, label):
-     # currently unused helper kept for nni ValueChoice-based variants of the space
-     return nni.retiarii.nn.pytorch.ValueChoice(l, label=label)
- class BenchSpace(BaseSpace):
- def __init__(
- self,
- hidden_dim: _typ.Optional[int] = 64,
- layer_number: _typ.Optional[int] = 2,
- dropout: _typ.Optional[float] = 0.9,
- input_dim: _typ.Optional[int] = None,
- output_dim: _typ.Optional[int] = None,
- ops_type = 0
- ):
- super().__init__()
- self.layer_number = layer_number
- self.hidden_dim = hidden_dim
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.dropout = dropout
- self.ops_type=ops_type
-
- def instantiate(
- self,
- hidden_dim: _typ.Optional[int] = None,
- layer_number: _typ.Optional[int] = None,
- dropout: _typ.Optional[float] = None,
- input_dim: _typ.Optional[int] = None,
- output_dim: _typ.Optional[int] = None,
- ops_type=None
- ):
- super().instantiate()
- self.dropout = dropout or self.dropout
- self.hidden_dim = hidden_dim or self.hidden_dim
- self.layer_number = layer_number or self.layer_number
- self.input_dim = input_dim or self.input_dim
- self.output_dim = output_dim or self.output_dim
- self.ops_type = ops_type or self.ops_type
-         self.ops = [gnn_list, gnn_list_proteins][self.ops_type]  # 0: full list, 1: proteins list
-         for layer in range(4):
-             setattr(self, f"in{layer}", self.setInputChoice(
-                 layer, n_candidates=layer + 1, n_chosen=1, return_mask=False, key=f"in{layer}"))
-             setattr(self, f"op{layer}", self.setLayerChoice(
-                 layer, [StrModule(op) for op in self.ops], key=f"op{layer}"))
-         self.dummy = nn.Linear(1, 1)
-
- def forward(self, bench):
- lks = [getattr(self, "in" + str(i)).selected for i in range(4)]
- ops = [getattr(self, "op" + str(i)).name for i in range(4)]
- arch = Arch(lks, ops)
- h = arch.valid_hash()
-         if h == "88888" or h == 88888:
-             # sentinel hash: treat the architecture as invalid and give zero reward
-             return 0
- return bench[h]['perf']
-
- def parse_model(self, selection, device) -> BaseAutoModel:
- return self.wrap().fix(selection)
-
- import os.path as osp
- bench_path = '/DATA/DATANAS1/zzy/bench/light'  # adjust to the local directory holding the *.bench files
- import pickle
-
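- # Each *.bench file is assumed to be a pickled dict mapping architecture hashes
- # (as produced by Arch.valid_hash) to records with at least a 'perf' entry,
- # e.g. bench[30123]['perf'] -- see BenchSpace.forward above.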
- def light_read(dname):
-     with open(osp.join(bench_path, f"{dname}.bench"), "rb") as f:
-         bench = pickle.load(f)
-     return bench
-
- from autogl.module.nas.estimator import BaseEstimator
- from autogl.module.train.evaluation import Acc
- class BenchEstimator(BaseEstimator):
-     def __init__(self, data_name, loss_f="nll_loss", evaluation=[Acc()]):
-         super().__init__(loss_f, evaluation)
-         self.evaluation = evaluation
-         self.bench = light_read(data_name)
-
-     def infer(self, model: BaseSpace, dataset, mask="train"):
-         perf = model(self.bench)
-         return [perf], 0
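-         # Note: `dataset` and `mask` are ignored here; evaluation is a pure lookup
-         # in the pre-computed benchmark, so no training happens during the search.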
-
- def run(data_name='cora', algo='graphnas', num_epochs=50, ctrl_steps_aggregate=20, log_dir='./logs/tmp'):
- print("Testing backend: {}".format("dgl" if DependentBackend.is_dgl() else "pyg"))
- if DependentBackend.is_dgl():
- from autogl.datasets.utils.conversion._to_dgl_dataset import to_dgl_dataset as convert_dataset
- else:
- from autogl.datasets.utils.conversion._to_pyg_dataset import to_pyg_dataset as convert_dataset
-
- # dataset = build_dataset_from_name('cora')
- # dataset = convert_dataset(dataset)
- # data = dataset[0]
-
- # di = bk_feat(data).shape[1]
- # do = len(np.unique(bk_label(data)))
-
-     # dummy input/output dimensions and dataset: BenchEstimator only performs a
-     # benchmark lookup, so no real graph data is needed here
-     di = 2
-     do = 2
-     dataset = None
-
-     ops_type = data_name == 'proteins'  # proteins uses the reduced operator list
-
- space = BenchSpace().cuda()
- space.instantiate(input_dim=di, output_dim=do,ops_type=ops_type)
- esti = BenchEstimator(data_name)
-     if algo == 'graphnas':
-         algo = GraphNasRL(num_epochs=num_epochs, ctrl_steps_aggregate=ctrl_steps_aggregate)
-     elif algo == 'agnn':
-         algo = AGNNRL(guide_type=1, num_epochs=num_epochs, ctrl_steps_aggregate=ctrl_steps_aggregate)
-     else:
-         raise NotImplementedError(f'Unsupported algo {algo}')
- model = algo.search(space, dataset, esti)
- result=esti.infer(model._model,None)[0][0]
-
-     os.makedirs(log_dir, exist_ok=True)
-     with open(osp.join(log_dir, 'log.txt'), 'w') as f:
-         f.write(str(result))
-
-     import json
-     archs = algo.allhist
-     with open(osp.join(log_dir, 'archs.json'), 'w') as f:
-         json.dump(archs, f)
-
-     arch_strs = [str(x[1]) for x in archs]
-     print(f'number of archs: {len(arch_strs)}; number of unique archs: {len(set(arch_strs))}')
-
-     scores = [-x[0] for x in archs]  # accuracies (allhist stores negated scores)
-     idxs = np.argsort(scores)  # indices in increasing order of accuracy
-     with open(osp.join(log_dir, 'idx.txt'), 'w') as f:
-         f.write(str(idxs))
- return result
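-     # Example call mirroring the __main__ defaults (illustrative log_dir):
-     #   run('cora', 'agnn', num_epochs=50, ctrl_steps_aggregate=10, log_dir='./logs/cora_agnn')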
-
- def run_all():
- data_names='arxiv citeseer computers cora cs photo physics proteins pubmed'.split()
- algos='graphnas agnn'.split()
- results=[]
- for data_name in data_names:
- for algo in algos:
- print(f'data {data_name} algo {algo}')
- # metric=run(data_name,algo,2,2)
- if data_name=='proteins':
- metric=run(data_name,algo,8,5)
- else:
- metric=run(data_name,algo,50,10)
- results.append([data_name,algo,metric])
- return results
- import pandas as pd
- import argparse
-
-
-
- if __name__ == "__main__":
- # results=run_all()
- # df=pd.DataFrame(results,columns='data algo v'.split()).pivot_table(values='v',index='algo',columns='data')
- # print(df.to_string())
-
- parser = argparse.ArgumentParser()
- parser.add_argument('--data', type=str, default='cora', help='datasets')
- parser.add_argument('--algo', type=str, default='agnn')
- parser.add_argument('--log_dir', type=str, default='./logs/')
-
- args = parser.parse_args()
-     dname = args.data
-     algo = args.algo
-     log_dir = os.path.join(args.log_dir, f'{dname}_{algo}')
- if dname=='proteins':
- # 40 archs in total
- num_epochs=8
- ctrl_steps_aggregate=5
- else:
- # 500 archs in total
- num_epochs=50
- ctrl_steps_aggregate=10
- result=run(dname,algo,num_epochs,ctrl_steps_aggregate,log_dir)
-
-
-