import copy
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseNAS
from ..space import BaseSpace
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module
from nni.nas.pytorch.fixed import apply_fixed_architecture
from tqdm import tqdm

_logger = logging.getLogger(__name__)
from .rl import PathSamplingLayerChoice, PathSamplingInputChoice
import numpy as np


class RSBox:
    """Derive the discrete selection space of a model ``space``.

    Replaces every layer/input choice in ``space`` with a path-sampling
    module and records, for each named choice point, how many options it
    offers, so that architectures can be sampled uniformly.
    """

    def __init__(self, space):
        self.model = space
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
        # map each choice name to its number of options (dict comprehension
        # instead of the original manual accumulation loop)
        self.selection_dict = {k: len(v) for k, v in self.nas_modules}

        space_size = np.prod(list(self.selection_dict.values()))
        print(f'Using random search Box. Total space size: {space_size}')
        print('Searching Space:', self.selection_dict)

    def export(self):
        """Return ``{name: n}``; the action for ``name`` ranges over 0..n-1."""
        return self.selection_dict

    def sample(self):
        """Uniformly sample one selection (one action per choice point)."""
        # np.random.randint(v) draws directly in [0, v) instead of
        # materializing range(v) for np.random.choice on every draw
        return {k: np.random.randint(v) for k, v in self.export().items()}


class RandomSearch(BaseNAS):
    """Architecture search by uniform random sampling.

    Parameters
    ----------
    device : str or torch.device
        Device used for evaluating sampled architectures.
    num_epochs : int
        Number of architectures to sample.
    disable_progress : bool
        If True, suppress the tqdm progress bar.
    """

    def __init__(self, device='cuda', num_epochs=400, disable_progress=False, *args, **kwargs):
        super().__init__(device)
        self.num_epochs = num_epochs
        self.disable_progress = disable_progress

    def search(self, space: BaseSpace, dset, estimator):
        """Sample ``num_epochs`` architectures uniformly from ``space`` and
        return the exported architecture with the best validation metric.

        Duplicate samples are detected via a cache and evaluated only once.
        """
        self.estimator = estimator
        self.dataset = dset
        self.space = space
        self.box = RSBox(self.space)
        arch_perfs = []
        cache = {}  # selection tuple -> validation metric (skip re-evaluation)
        with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
            for i in bar:
                selection = self.export()
                # dict values are already iterable; no intermediate list needed
                vec = tuple(selection.values())
                if vec not in cache:
                    self.arch = space.export(selection, self.device)
                    metric, loss = self._infer(mask='val')
                    arch_perfs.append([metric, selection])
                    cache[vec] = metric
                    bar.set_postfix(acc=metric, max_acc=max(cache.values()))
        # pick the selection with the highest validation metric
        selection = arch_perfs[np.argmax([x[0] for x in arch_perfs])][1]
        arch = space.export(selection, self.device)
        return arch

    def export(self):
        """Draw one uniform random selection from the search box."""
        return self.box.sample()

    def _infer(self, mask='train'):
        # delegate scoring to the estimator on the requested data split
        metric, loss = self.estimator.infer(self.arch, self.dataset, mask=mask)
        return metric, loss
class GraphNasRL(BaseNAS):
    """
    RL-based searcher following GraphNAS: a Reinforce controller samples
    architectures, each sample is scored on the validation split, and the
    validation metric is used directly as the controller reward (no reward
    shaping, unlike the original GraphNAS code). The ``topk`` best
    architectures seen during search are re-evaluated to pick the final one.

    Parameters
    ----------
    device : str or torch.device
        ``'cpu'`` or ``'cuda'``.
    workers : int
        Workers for data loading.
    log_frequency : int or None
        Step count per logging.
    grad_clip : float
        Gradient clipping for the controller. Set to 0 to disable. Default: 5.
    entropy_weight : float
        Weight of the sample-entropy bonus added to the reward.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to
        ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_lr : float
        Learning rate for the RL controller.
    ctrl_steps_aggregate : int
        Number of architecture samples drawn per controller epoch.
    ctrl_kwargs : dict or None
        Optional kwargs that will be passed to :class:`ReinforceController`.
    n_warmup : int
        Kept for interface compatibility with :class:`RL`.
    model_lr, model_wd : float
        Kept for interface compatibility with :class:`RL`.
    topk : int
        Number of best architectures kept for the final re-evaluation.
    num_epochs : int
        Number of controller epochs; passed via keyword (default 10).
    """

    def __init__(self, device='cuda', workers=4, log_frequency=None,
                 grad_clip=5., entropy_weight=0.0001, skip_weight=0, baseline_decay=0.95,
                 ctrl_lr=0.00035, ctrl_steps_aggregate=100, ctrl_kwargs=None, n_warmup=100,
                 model_lr=5e-3, model_wd=5e-4, topk=5, *args, **kwargs):
        super().__init__(device)
        self.device = device
        self.num_epochs = kwargs.get("num_epochs", 10)
        self.workers = workers
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip
        self.ctrl_kwargs = ctrl_kwargs
        self.ctrl_lr = ctrl_lr
        self.n_warmup = n_warmup
        self.model_lr = model_lr
        self.model_wd = model_wd
        # fix: opening '../tmp/log-*.txt' crashed when ../tmp did not exist
        import os
        os.makedirs('../tmp', exist_ok=True)
        timestamp = datetime.now().strftime('%m%d-%H-%M-%S')
        self.log = open(f'../tmp/log-{timestamp}.txt', 'w')
        # list of [-metric, selection]; ascending sort keeps the topk best
        self.hist = []
        self.topk = topk

    def search(self, space: BaseSpace, dset, estimator):
        """Search ``space`` on ``dset`` using ``estimator`` for scoring and
        return the best exported architecture."""
        self.model = space
        self.dataset = dset  # .to(self.device)
        self.estimator = estimator
        # replace layer/input choices with path-sampling modules, preserving
        # the original module order
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

        self.model = self.model.to(self.device)
        # one ReinforceField per choice module
        self.nas_fields = [ReinforceField(name, len(module),
                                          isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
                           for name, module in self.nas_modules]
        self.controller = ReinforceController(self.nas_fields, lstm_size=100, temperature=5.0, tanh_constant=2.5,
                                              **(self.ctrl_kwargs or {}))
        self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr)
        # train the controller
        with tqdm(range(self.num_epochs)) as bar:
            for i in bar:
                l2 = self._train_controller(i)
                bar.set_postfix(reward_controller=l2)

        # re-evaluate the topk architectures seen during search and keep the best
        selections = [x[1] for x in self.hist]
        candidate_accs = [-x[0] for x in self.hist]
        print('candidate accuracies', candidate_accs)
        selection = self._choose_best(selections)
        arch = space.export(selection, self.device)
        print(selection, arch)
        return arch

    def _choose_best(self, selections, n_eval=20):
        """Evaluate each candidate ``n_eval`` times on the validation split
        (GraphNAS uses the top 5 models and 20 evaluations) and return the
        selection with the best mean metric."""
        results = []
        for selection in selections:
            accs = []
            for i in tqdm(range(n_eval)):
                self.arch = self.model.export(selection, device=self.device)
                metric, loss = self._infer(mask='val')
                accs.append(metric)
            result = np.mean(accs)
            # fix: standard error now uses len(accs) instead of hard-coded 20
            print('selection {} \n acc {:.4f} +- {:.4f}'.format(selection, np.mean(accs), np.std(accs) / np.sqrt(len(accs))))
            results.append(result)
        best_selection = selections[np.argmax(results)]
        return best_selection

    def _train_controller(self, epoch):
        """Run one epoch of REINFORCE updates on the controller and return
        the mean (un-shaped) reward over the epoch.

        NOTE: GraphNAS trains 100 and derives 100 per epoch (10 epochs); here
        we only train 100 per epoch (20 epochs) — the total number of samples
        is the same (2000).
        """
        self.model.eval()
        self.controller.train()
        self.ctrl_optim.zero_grad()
        rewards = []
        baseline = None
        with tqdm(range(self.ctrl_steps_aggregate)) as bar:
            for ctrl_step in bar:
                self._resample()
                metric, loss = self._infer(mask='val')

                self.log.write(f'{self.arch}\n{self.selection}\n{metric},{loss}\n')
                self.log.flush()
                # no reward shaping, unlike the original GraphNAS code
                reward = metric
                # keep only the topk best candidates in self.hist
                self.hist.append([-metric, self.selection])
                if len(self.hist) > self.topk:
                    self.hist.sort(key=lambda x: x[0])
                    self.hist.pop()
                rewards.append(reward)

                if self.entropy_weight:
                    reward += self.entropy_weight * self.controller.sample_entropy.item()

                # fix: `if not baseline:` also triggered when the moving
                # baseline was exactly 0.0, resetting it; test None explicitly
                if baseline is None:
                    baseline = reward
                else:
                    baseline = baseline * self.baseline_decay + reward * (1 - self.baseline_decay)

                loss = self.controller.sample_log_prob * (reward - baseline)
                self.ctrl_optim.zero_grad()
                loss.backward()
                # fix: grad_clip was documented and stored but never applied
                if self.grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
                self.ctrl_optim.step()

                bar.set_postfix(acc=metric, max_acc=max(rewards))
        return sum(rewards) / len(rewards)

    def _resample(self):
        """Sample one architecture from the controller and export it."""
        result = self.controller.resample()
        self.arch = self.model.export(result, device=self.device)
        self.selection = result

    def export(self):
        """Deterministically sample a selection from the trained controller."""
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()

    def _infer(self, mask='train'):
        # delegate scoring to the estimator on the requested data split
        metric, loss = self.estimator.infer(self.arch, self.dataset, mask=mask)
        return metric, loss
b/autogl/module/nas/space/graph_nas_macro.py index a59e8bb..963c13a 100644 --- a/autogl/module/nas/space/graph_nas_macro.py +++ b/autogl/module/nas/space/graph_nas_macro.py @@ -392,7 +392,7 @@ class GraphNasMacroNodeClfSpace(BaseSpace): self, hidden_dim: _typ.Optional[int] = 64, layer_number: _typ.Optional[int] = 2, - dropout: _typ.Optional[float] = 0.9, + dropout: _typ.Optional[float] = 0.6, input_dim: _typ.Optional[int] = None, output_dim: _typ.Optional[int] = None, ops: _typ.Tuple = None, diff --git a/examples/test_graph_nas_rl.py b/examples/test_graph_nas_rl.py index 439888a..fbe1052 100644 --- a/examples/test_graph_nas_rl.py +++ b/examples/test_graph_nas_rl.py @@ -10,8 +10,9 @@ from autogl.module.nas.space.graph_nas import GraphNasNodeClassificationSpace from autogl.module.nas.space.graph_nas_macro import GraphNasMacroNodeClfSpace from autogl.module.train import Acc from autogl.module.nas.algorithm.enas import Enas -from autogl.module.nas.algorithm.rl import RL +from autogl.module.nas.algorithm.rl import RL,GraphNasRL from autogl.module.nas.estimator.one_shot import TrainEstimator +from autogl.module.nas.algorithm.random_search import RandomSearch import logging if __name__ == '__main__': logging.getLogger().setLevel(logging.WARNING) @@ -23,16 +24,17 @@ if __name__ == '__main__': ensemble_module=None, default_trainer=NodeClassificationFullTrainer( optimizer=torch.optim.Adam, - lr=0.01, - max_epoch=200, - early_stopping_round=200, + lr=0.005, + max_epoch=300, + early_stopping_round=20, weight_decay=5e-4, device="auto", init=False, feval=['acc'], loss="nll_loss", lr_scheduler_type=None,), - nas_algorithms=[RL(num_epochs=400)], + # nas_algorithms=[RL(num_epochs=400)], + nas_algorithms=[GraphNasRL(num_epochs=20)], #nas_algorithms=[Darts(num_epochs=200)], nas_spaces=[GraphNasMacroNodeClfSpace(hidden_dim=16,search_act_con=True,layer_number=2)], nas_estimators=[TrainEstimator()]