
Merge pull request #29 from THUMNLab/random_search

Random search
tags/v0.3.1
秦一鉴 (GitHub), 4 years ago
commit fd8e5ffa2d
5 changed files with 269 additions and 21 deletions
  1. autogl/module/nas/algorithm/random_search.py (+81, -0)
  2. autogl/module/nas/algorithm/rl.py (+177, -12)
  3. autogl/module/nas/estimator/one_shot.py (+3, -3)
  4. autogl/module/nas/space/graph_nas_macro.py (+1, -1)
  5. examples/test_graph_nas_rl.py (+7, -5)

autogl/module/nas/algorithm/random_search.py (+81, -0)

@@ -0,0 +1,81 @@
import copy
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseNAS
from ..space import BaseSpace
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module
from .rl import PathSamplingLayerChoice, PathSamplingInputChoice
from nni.nas.pytorch.fixed import apply_fixed_architecture
from tqdm import tqdm

_logger = logging.getLogger(__name__)


class RSBox:
    """Build the selection space for the model `space`."""

    def __init__(self, space):
        self.model = space
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)
        selection_range = {}
        for k, v in self.nas_modules:
            selection_range[k] = len(v)
        self.selection_dict = selection_range
        space_size = np.prod(list(selection_range.values()))
        print(f'Using random search box. Total space size: {space_size}')
        print('Search space:', selection_range)

    def export(self):
        # {k: v} means the action for the layer named k ranges over 0 to v - 1
        return self.selection_dict

    def sample(self):
        # sample uniformly at random
        selection = {}
        for k, v in self.export().items():
            selection[k] = np.random.choice(range(v))
        return selection


class RandomSearch(BaseNAS):
    """Uniform random search over the architecture space."""

    def __init__(self, device='cuda', num_epochs=400, disable_progress=False, *args, **kwargs):
        super().__init__(device)
        self.num_epochs = num_epochs
        self.disable_progress = disable_progress

    def search(self, space: BaseSpace, dset, estimator):
        self.estimator = estimator
        self.dataset = dset
        self.space = space
        self.box = RSBox(self.space)
        arch_perfs = []
        cache = {}
        with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
            for i in bar:
                selection = self.export()
                vec = tuple(selection.values())
                if vec not in cache:  # evaluate each distinct architecture only once
                    self.arch = space.export(selection, self.device)
                    metric, loss = self._infer(mask='val')
                    arch_perfs.append([metric, selection])
                    cache[vec] = metric
                bar.set_postfix(acc=metric, max_acc=max(cache.values()))
        selection = arch_perfs[np.argmax([x[0] for x in arch_perfs])][1]
        arch = space.export(selection, self.device)
        return arch

    def export(self):
        return self.box.sample()

    def _infer(self, mask='train'):
        metric, loss = self.estimator.infer(self.arch, self.dataset, mask=mask)
        return metric, loss
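
For context, a minimal self-contained sketch of the cached uniform-sampling loop that RandomSearch.search implements; the selection_dict and evaluate below are hypothetical stand-ins for RSBox.export() and the estimator call:

import numpy as np

# Hypothetical stand-ins: layer name -> number of choices, and a fake validation metric.
selection_dict = {'op_0': 4, 'op_1': 4, 'act_0': 3}
def evaluate(selection):
    return np.random.rand()

cache, arch_perfs = {}, []
for _ in range(100):
    # sample one choice uniformly at random for every searchable layer
    selection = {k: np.random.choice(range(v)) for k, v in selection_dict.items()}
    vec = tuple(selection.values())
    if vec not in cache:  # evaluate each distinct architecture only once
        metric = evaluate(selection)
        arch_perfs.append([metric, selection])
        cache[vec] = metric

# keep the selection with the best validation metric
best = arch_perfs[np.argmax([x[0] for x in arch_perfs])][1]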

autogl/module/nas/algorithm/rl.py (+177, -12)

@@ -11,6 +11,8 @@ from ..space import BaseSpace
 from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module
 from nni.nas.pytorch.fixed import apply_fixed_architecture
 from tqdm import tqdm
+from datetime import datetime
+import numpy as np
 _logger = logging.getLogger(__name__)
 def _get_mask(sampled, total):
     multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
@@ -229,7 +231,7 @@ class ReinforceController(nn.Module):

 class RL(BaseNAS):
     """
-    ENAS trainer.
+    RL in GraphNAS.

     Parameters
     ----------
@@ -293,7 +295,7 @@ class RL(BaseNAS):
         self.n_warmup=n_warmup
         self.model_lr = model_lr
         self.model_wd = model_wd
-        self.log=open('log.txt','w')
+        self.log=open('../tmp/log.txt','w')
     def search(self, space: BaseSpace, dset, estimator):
         self.model = space
         self.dataset = dset#.to(self.device)
@@ -318,16 +320,6 @@ class RL(BaseNAS):
         with tqdm(range(self.num_epochs)) as bar:
             for i in bar:
                 l2=self._train_controller(i)
-
-                # try:
-                #     l2=self._train_controller(i)
-                # except Exception as e:
-                #     print(e)
-                #     nm=self.nas_modules
-                #     for i in range(len(nm)):
-                #         print(nm[i][1].sampled)
-                #     # import pdb
-                #     # pdb.set_trace()
                 bar.set_postfix(reward_controller=l2)
         selection=self.export()
@@ -382,3 +374,176 @@ class RL(BaseNAS):
    def _infer(self, mask='train'):
        metric, loss = self.estimator.infer(self.arch, self.dataset, mask=mask)
        return metric, loss

class GraphNasRL(BaseNAS):
    """
    RL in GraphNAS.

    Parameters
    ----------
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    workers : int
        Workers for data loading.
    log_frequency : int
        Number of steps between logging.
    grad_clip : float
        Gradient clipping. Set to 0 to disable. Default: 5.
    entropy_weight : float
        Weight of the sample entropy loss.
    skip_weight : float
        Weight of the skip penalty loss.
    baseline_decay : float
        Decay factor of the baseline. The new baseline is ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_lr : float
        Learning rate for the RL controller.
    ctrl_steps_aggregate : int
        Number of steps aggregated into one mini-batch for the RL controller.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
    n_warmup : int
        Number of warm-up steps.
    model_lr : float
        Learning rate for the model weights.
    model_wd : float
        Weight decay for the model weights.
    topk : int
        Number of best-performing architectures kept during search; they are re-evaluated at the end and the best one is returned.
    num_epochs : int
        Number of epochs planned for training (passed via ``kwargs``; default 10).
    """

    def __init__(self, device='cuda', workers=4, log_frequency=None,
                 grad_clip=5., entropy_weight=0.0001, skip_weight=0, baseline_decay=0.95,
                 ctrl_lr=0.00035, ctrl_steps_aggregate=100, ctrl_kwargs=None,
                 n_warmup=100, model_lr=5e-3, model_wd=5e-4, topk=5, *args, **kwargs):
        super().__init__(device)
        self.device = device
        self.num_epochs = kwargs.get("num_epochs", 10)
        self.workers = workers
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip
        self.ctrl_kwargs = ctrl_kwargs
        self.ctrl_lr = ctrl_lr
        self.n_warmup = n_warmup
        self.model_lr = model_lr
        self.model_wd = model_wd
        timestamp = datetime.now().strftime('%m%d-%H-%M-%S')
        self.log = open(f'../tmp/log-{timestamp}.txt', 'w')
        self.hist = []
        self.topk = topk

    def search(self, space: BaseSpace, dset, estimator):
        self.model = space
        self.dataset = dset
        self.estimator = estimator
        # replace each choice module with a path-sampling version
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

        # to device
        self.model = self.model.to(self.device)
        # fields
        self.nas_fields = [ReinforceField(name, len(module),
                                          isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
                           for name, module in self.nas_modules]
        self.controller = ReinforceController(self.nas_fields, lstm_size=100, temperature=5.0,
                                              tanh_constant=2.5, **(self.ctrl_kwargs or {}))
        self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr)
        # train
        with tqdm(range(self.num_epochs)) as bar:
            for i in bar:
                l2 = self._train_controller(i)
                bar.set_postfix(reward_controller=l2)
        selections = [x[1] for x in self.hist]
        candidate_accs = [-x[0] for x in self.hist]
        print('candidate accuracies', candidate_accs)
        selection = self._choose_best(selections)
        arch = space.export(selection, self.device)
        print(selection, arch)
        return arch

    def _choose_best(self, selections):
        # GraphNAS keeps the top 5 models, evaluates each 20 times, and chooses the best.
        results = []
        for selection in selections:
            accs = []
            for i in tqdm(range(20)):
                self.arch = self.model.export(selection, device=self.device)
                metric, loss = self._infer(mask='val')
                accs.append(metric)
            result = np.mean(accs)
            print('selection {} \n acc {:.4f} +- {:.4f}'.format(selection, np.mean(accs), np.std(accs) / np.sqrt(20)))
            results.append(result)
        best_selection = selections[np.argmax(results)]
        return best_selection

    def _train_controller(self, epoch):
        self.model.eval()
        self.controller.train()
        self.ctrl_optim.zero_grad()
        rewards = []
        baseline = None
        # diff: GraphNAS trains for 100 steps and derives for 100 steps per epoch over 10 epochs;
        # we just train for 100 steps per epoch over 20 epochs, so the total number of samples is the same (2000).
        with tqdm(range(self.ctrl_steps_aggregate)) as bar:
            for ctrl_step in bar:
                self._resample()
                metric, loss = self._infer(mask='val')
                self.log.write(f'{self.arch}\n{self.selection}\n{metric},{loss}\n')
                self.log.flush()
                # diff: no reward shaping, unlike the GraphNAS code
                reward = metric
                # hist stores [-metric, selection]; sorting ascending and popping keeps the topk best
                self.hist.append([-metric, self.selection])
                if len(self.hist) > self.topk:
                    self.hist.sort(key=lambda x: x[0])
                    self.hist.pop()
                rewards.append(reward)
                if self.entropy_weight:
                    reward += self.entropy_weight * self.controller.sample_entropy.item()

                if baseline is None:
                    baseline = reward
                else:
                    baseline = baseline * self.baseline_decay + reward * (1 - self.baseline_decay)

                loss = self.controller.sample_log_prob * (reward - baseline)
                self.ctrl_optim.zero_grad()
                loss.backward()
                self.ctrl_optim.step()
                bar.set_postfix(acc=metric, max_acc=max(rewards))
        return sum(rewards) / len(rewards)

    def _resample(self):
        result = self.controller.resample()
        self.arch = self.model.export(result, device=self.device)
        self.selection = result

    def export(self):
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()

    def _infer(self, mask='train'):
        metric, loss = self.estimator.infer(self.arch, self.dataset, mask=mask)
        return metric, loss
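
For reference, a minimal sketch of the moving-average baseline and REINFORCE step used in _train_controller, with toy rewards and a stand-in log-probability (baseline_decay=0.95 matches the class default):

import torch

baseline_decay = 0.95
baseline = None
for reward in [0.70, 0.74, 0.71]:  # toy validation accuracies
    # stand-in for controller.sample_log_prob of the sampled architecture
    log_prob = torch.tensor(-1.2, requires_grad=True)
    if baseline is None:
        baseline = reward
    else:
        # new_baseline = decay * old_baseline + (1 - decay) * reward
        baseline = baseline * baseline_decay + reward * (1 - baseline_decay)
    # REINFORCE-style loss: scale the sample log-probability by the advantage
    loss = log_prob * (reward - baseline)
    loss.backward()  # the resulting gradients would drive the controller optimizer step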

autogl/module/nas/estimator/one_shot.py (+3, -3)

@@ -31,9 +31,9 @@ class TrainEstimator(BaseEstimator):
         self.trainer=NodeClassificationFullTrainer(
             model=model,
             optimizer=torch.optim.Adam,
-            lr=0.01,
-            max_epoch=200,
-            early_stopping_round=200,
+            lr=0.005,
+            max_epoch=300,
+            early_stopping_round=30,
             weight_decay=5e-4,
             device="auto",
             init=False,


autogl/module/nas/space/graph_nas_macro.py (+1, -1)

@@ -392,7 +392,7 @@ class GraphNasMacroNodeClfSpace(BaseSpace):
         self,
         hidden_dim: _typ.Optional[int] = 64,
         layer_number: _typ.Optional[int] = 2,
-        dropout: _typ.Optional[float] = 0.9,
+        dropout: _typ.Optional[float] = 0.6,
         input_dim: _typ.Optional[int] = None,
         output_dim: _typ.Optional[int] = None,
         ops: _typ.Tuple = None,


examples/test_graph_nas_rl.py (+7, -5)

@@ -10,8 +10,9 @@ from autogl.module.nas.space.graph_nas import GraphNasNodeClassificationSpace
 from autogl.module.nas.space.graph_nas_macro import GraphNasMacroNodeClfSpace
 from autogl.module.train import Acc
 from autogl.module.nas.algorithm.enas import Enas
-from autogl.module.nas.algorithm.rl import RL
+from autogl.module.nas.algorithm.rl import RL,GraphNasRL
 from autogl.module.nas.estimator.one_shot import TrainEstimator
+from autogl.module.nas.algorithm.random_search import RandomSearch
 import logging
 if __name__ == '__main__':
     logging.getLogger().setLevel(logging.WARNING)
@@ -23,16 +24,17 @@ if __name__ == '__main__':
         ensemble_module=None,
         default_trainer=NodeClassificationFullTrainer(
             optimizer=torch.optim.Adam,
-            lr=0.01,
-            max_epoch=200,
-            early_stopping_round=200,
+            lr=0.005,
+            max_epoch=300,
+            early_stopping_round=20,
             weight_decay=5e-4,
             device="auto",
             init=False,
             feval=['acc'],
             loss="nll_loss",
             lr_scheduler_type=None,),
-        nas_algorithms=[RL(num_epochs=400)],
+        # nas_algorithms=[RL(num_epochs=400)],
+        nas_algorithms=[GraphNasRL(num_epochs=20)],
         #nas_algorithms=[Darts(num_epochs=200)],
         nas_spaces=[GraphNasMacroNodeClfSpace(hidden_dim=16,search_act_con=True,layer_number=2)],
         nas_estimators=[TrainEstimator()]
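
The example now also imports RandomSearch, so swapping in the new algorithm should only require changing the nas_algorithms line; a hedged one-line sketch (num_epochs=400 mirrors the RandomSearch default):

        nas_algorithms=[RandomSearch(num_epochs=400)],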

