diff --git a/autogl/module/nas/algorithm/rl.py b/autogl/module/nas/algorithm/rl.py index 9d38480..3f011c0 100644 --- a/autogl/module/nas/algorithm/rl.py +++ b/autogl/module/nas/algorithm/rl.py @@ -11,6 +11,8 @@ from ..space import BaseSpace from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module from nni.nas.pytorch.fixed import apply_fixed_architecture from tqdm import tqdm +from datetime import datetime + _logger = logging.getLogger(__name__) def _get_mask(sampled, total): multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)] @@ -229,7 +231,7 @@ class ReinforceController(nn.Module): class RL(BaseNAS): """ - ENAS trainer. + RL in GraphNas. Parameters ---------- @@ -293,7 +295,7 @@ class RL(BaseNAS): self.n_warmup=n_warmup self.model_lr = model_lr self.model_wd = model_wd - self.log=open('log.txt','w') + self.log=open('../tmp/log.txt','w') def search(self, space: BaseSpace, dset, estimator): self.model = space self.dataset = dset#.to(self.device) @@ -318,16 +320,6 @@ class RL(BaseNAS): with tqdm(range(self.num_epochs)) as bar: for i in bar: l2=self._train_controller(i) - - # try: - # l2=self._train_controller(i) - # except Exception as e: - # print(e) - # nm=self.nas_modules - # for i in range(len(nm)): - # print(nm[i][1].sampled) - # # import pdb - # # pdb.set_trace() bar.set_postfix(reward_controller=l2) selection=self.export() @@ -382,3 +374,150 @@ class RL(BaseNAS): def _infer(self,mask='train'): metric, loss = self.estimator.infer(self.arch, self.dataset,mask=mask) return metric, loss + + +class GraphNasRL(BaseNAS): + """ + RL in GraphNas. + + Parameters + ---------- + model : nn.Module + PyTorch model to be trained. + loss : callable + Receives logits and ground truth label, return a loss tensor. + metrics : callable + Receives logits and ground truth label, return a dict of metrics. 
+ reward_function : callable + Receives logits and ground truth label, returns a tensor, which will be fed to the RL controller as reward. + optimizer : Optimizer + The optimizer used for optimizing the model. + num_epochs : int + Number of epochs planned for training. + dataset : Dataset + Dataset for training. Will be split for training weights and architecture weights. + batch_size : int + Batch size. + workers : int + Workers for data loading. + device : torch.device + ``torch.device("cpu")`` or ``torch.device("cuda")``. + log_frequency : int + Step count per logging. + grad_clip : float + Gradient clipping. Set to 0 to disable. Default: 5. + entropy_weight : float + Weight of sample entropy loss. + skip_weight : float + Weight of skip penalty loss. + baseline_decay : float + Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``. + ctrl_lr : float + Learning rate for RL controller. + ctrl_steps_aggregate : int + Number of steps that will be aggregated into one mini-batch for RL controller. + ctrl_steps : int + Number of mini-batches for each epoch of RL controller learning. + ctrl_kwargs : dict + Optional kwargs that will be passed to :class:`ReinforceController`.
+ """ + + def __init__(self, device='cuda', workers=4,log_frequency=None, + grad_clip=5., entropy_weight=0.0001, skip_weight=0, baseline_decay=0.95, + ctrl_lr=0.00035, ctrl_steps_aggregate=100, ctrl_kwargs=None,n_warmup=100,model_lr=5e-3,model_wd=5e-4,*args,**kwargs): + super().__init__(device) + self.device=device + self.num_epochs = kwargs.get("num_epochs", 10) + self.workers = workers + self.log_frequency = log_frequency + self.entropy_weight = entropy_weight + self.skip_weight = skip_weight + self.baseline_decay = baseline_decay + self.ctrl_steps_aggregate = ctrl_steps_aggregate + self.grad_clip = grad_clip + self.workers = workers + self.ctrl_kwargs=ctrl_kwargs + self.ctrl_lr=ctrl_lr + self.n_warmup=n_warmup + self.model_lr = model_lr + self.model_wd = model_wd + timestamp=datetime.now().strftime('%m%d-%H-%M-%S') + self.log=open(f'../tmp/log-{timestamp}.txt','w') + def search(self, space: BaseSpace, dset, estimator): + self.model = space + self.dataset = dset#.to(self.device) + self.estimator = estimator + # replace choice + self.nas_modules = [] + + k2o = get_module_order(self.model) + replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules) + replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules) + self.nas_modules = sort_replaced_module(k2o, self.nas_modules) + + # to device + self.model = self.model.to(self.device) + # fields + self.nas_fields = [ReinforceField(name, len(module), + isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1) + for name, module in self.nas_modules] + self.controller = ReinforceController(self.nas_fields,lstm_size=100,temperature=5.0,tanh_constant=2.5, **(self.ctrl_kwargs or {})) + self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr) + # train + with tqdm(range(self.num_epochs)) as bar: + for i in bar: + l2=self._train_controller(i) + bar.set_postfix(reward_controller=l2) + + selection=self.export() + arch=space.export(selection,self.device) + 
print(selection,arch) + return arch + + def _train_controller(self, epoch): + self.model.eval() + self.controller.train() + self.ctrl_optim.zero_grad() + rewards=[] + baseline=None + with tqdm(range(self.ctrl_steps_aggregate)) as bar: + for ctrl_step in bar: + self._resample() + metric,loss=self._infer(mask='val') + + bar.set_postfix(acc=metric,loss=loss.item()) + self.log.write(f'{self.arch}\n{self.selection}\n{metric},{loss}\n') + self.log.flush() + reward =metric + rewards.append(reward) + + if self.entropy_weight: + reward += self.entropy_weight * self.controller.sample_entropy.item() + + if not baseline: + baseline= reward + else: + baseline = baseline * self.baseline_decay + reward * (1 - self.baseline_decay) + + loss = self.controller.sample_log_prob * (reward - baseline) + self.ctrl_optim.zero_grad() + loss.backward() + + self.ctrl_optim.step() + + bar.set_postfix(acc=metric,max_acc=max(rewards)) + return sum(rewards)/len(rewards) + + def _resample(self): + result = self.controller.resample() + self.arch=self.model.export(result,device=self.device) + self.selection=result + + def export(self): + self.controller.eval() + with torch.no_grad(): + return self.controller.resample() + + def _infer(self,mask='train'): + metric, loss = self.estimator.infer(self.arch, self.dataset,mask=mask) + return metric, loss \ No newline at end of file diff --git a/autogl/module/nas/estimator/one_shot.py b/autogl/module/nas/estimator/one_shot.py index 4964947..9fc33be 100644 --- a/autogl/module/nas/estimator/one_shot.py +++ b/autogl/module/nas/estimator/one_shot.py @@ -31,9 +31,9 @@ class TrainEstimator(BaseEstimator): self.trainer=NodeClassificationFullTrainer( model=model, optimizer=torch.optim.Adam, - lr=0.01, - max_epoch=200, - early_stopping_round=200, + lr=0.005, + max_epoch=300, + early_stopping_round=30, weight_decay=5e-4, device="auto", init=False, diff --git a/autogl/module/nas/space/graph_nas_macro.py b/autogl/module/nas/space/graph_nas_macro.py index 
a59e8bb..963c13a 100644 --- a/autogl/module/nas/space/graph_nas_macro.py +++ b/autogl/module/nas/space/graph_nas_macro.py @@ -392,7 +392,7 @@ class GraphNasMacroNodeClfSpace(BaseSpace): self, hidden_dim: _typ.Optional[int] = 64, layer_number: _typ.Optional[int] = 2, - dropout: _typ.Optional[float] = 0.9, + dropout: _typ.Optional[float] = 0.6, input_dim: _typ.Optional[int] = None, output_dim: _typ.Optional[int] = None, ops: _typ.Tuple = None, diff --git a/examples/test_graph_nas_rl.py b/examples/test_graph_nas_rl.py index b5d4363..55e47db 100644 --- a/examples/test_graph_nas_rl.py +++ b/examples/test_graph_nas_rl.py @@ -10,7 +10,7 @@ from autogl.module.nas.space.graph_nas import GraphNasNodeClassificationSpace from autogl.module.nas.space.graph_nas_macro import GraphNasMacroNodeClfSpace from autogl.module.train import Acc from autogl.module.nas.algorithm.enas import Enas -from autogl.module.nas.algorithm.rl import RL +from autogl.module.nas.algorithm.rl import RL,GraphNasRL from autogl.module.nas.estimator.one_shot import TrainEstimator from autogl.module.nas.algorithm.random_search import RandomSearch import logging @@ -25,7 +25,7 @@ if __name__ == '__main__': default_trainer=NodeClassificationFullTrainer( optimizer=torch.optim.Adam, lr=0.01, - max_epoch=200, + max_epoch=300, early_stopping_round=200, weight_decay=5e-4, device="auto", @@ -34,7 +34,7 @@ if __name__ == '__main__': loss="nll_loss", lr_scheduler_type=None,), # nas_algorithms=[RL(num_epochs=400)], - nas_algorithms=[RandomSearch(num_epochs=400)], + nas_algorithms=[GraphNasRL(num_epochs=100)], #nas_algorithms=[Darts(num_epochs=200)], nas_spaces=[GraphNasMacroNodeClfSpace(hidden_dim=16,search_act_con=True,layer_number=2)], nas_estimators=[TrainEstimator()]