Browse Source

add graphnasrl

tags/v0.3.1
wondergo2017 4 years ago
parent
commit
aeef16fcca
4 changed files with 158 additions and 19 deletions
  1. +151
    -12
      autogl/module/nas/algorithm/rl.py
  2. +3
    -3
      autogl/module/nas/estimator/one_shot.py
  3. +1
    -1
      autogl/module/nas/space/graph_nas_macro.py
  4. +3
    -3
      examples/test_graph_nas_rl.py

+ 151
- 12
autogl/module/nas/algorithm/rl.py View File

@@ -11,6 +11,8 @@ from ..space import BaseSpace
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module
from nni.nas.pytorch.fixed import apply_fixed_architecture
from tqdm import tqdm
from datetime import datetime

_logger = logging.getLogger(__name__)
def _get_mask(sampled, total):
multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
@@ -229,7 +231,7 @@ class ReinforceController(nn.Module):

class RL(BaseNAS):
"""
ENAS trainer.
RL in GraphNas.

Parameters
----------
@@ -293,7 +295,7 @@ class RL(BaseNAS):
self.n_warmup=n_warmup
self.model_lr = model_lr
self.model_wd = model_wd
self.log=open('log.txt','w')
self.log=open('../tmp/log.txt','w')
def search(self, space: BaseSpace, dset, estimator):
self.model = space
self.dataset = dset#.to(self.device)
@@ -318,16 +320,6 @@ class RL(BaseNAS):
with tqdm(range(self.num_epochs)) as bar:
for i in bar:
l2=self._train_controller(i)

# try:
# l2=self._train_controller(i)
# except Exception as e:
# print(e)
# nm=self.nas_modules
# for i in range(len(nm)):
# print(nm[i][1].sampled)
# # import pdb
# # pdb.set_trace()
bar.set_postfix(reward_controller=l2)
selection=self.export()
@@ -382,3 +374,150 @@ class RL(BaseNAS):
def _infer(self,mask='train'):
metric, loss = self.estimator.infer(self.arch, self.dataset,mask=mask)
return metric, loss


class GraphNasRL(BaseNAS):
    """
    RL in GraphNas.

    A controller samples an architecture selection from the search space,
    the estimator scores the exported architecture on the validation mask,
    and the controller is updated from
    ``sample_log_prob * (reward - baseline)`` after every sample.

    Parameters
    ----------
    device : str or torch.device
        Forwarded to ``BaseNAS.__init__`` and used to place the search space
        and the exported architectures. Default: ``'cuda'``.
    workers : int
        Stored on the instance; not read by the methods of this class.
    log_frequency : int
        Stored on the instance; not read by the methods of this class.
    grad_clip : float
        Stored on the instance. NOTE(review): no gradient clipping is applied
        anywhere in this class - confirm whether that is intended.
    entropy_weight : float
        Weight of the controller's sample entropy that is added to the reward.
        A falsy value disables the entropy bonus.
    skip_weight : float
        Stored on the instance; not read by the methods of this class.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to
        ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_lr : float
        Learning rate of the Adam optimizer used for the RL controller.
    ctrl_steps_aggregate : int
        Number of sample / evaluate / update steps performed per epoch.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
        They take precedence over the built-in defaults
        (``lstm_size=100``, ``temperature=5.0``, ``tanh_constant=2.5``).
    n_warmup : int
        Stored on the instance; not read by the methods of this class.
    model_lr : float
        Stored on the instance; not read by the methods of this class.
    model_wd : float
        Stored on the instance; not read by the methods of this class.
    *args
        Accepted and ignored.
    **kwargs
        Only ``num_epochs`` (int, default 10) is read: the number of
        controller training epochs run by :meth:`search`.
    """

    def __init__(self, device='cuda', workers=4, log_frequency=None,
                 grad_clip=5., entropy_weight=0.0001, skip_weight=0, baseline_decay=0.95,
                 ctrl_lr=0.00035, ctrl_steps_aggregate=100, ctrl_kwargs=None, n_warmup=100,
                 model_lr=5e-3, model_wd=5e-4, *args, **kwargs):
        # Local import: only part of the file's import section is visible here.
        import os

        super().__init__(device)
        self.device = device
        self.num_epochs = kwargs.get("num_epochs", 10)
        self.workers = workers
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip
        self.ctrl_kwargs = ctrl_kwargs
        self.ctrl_lr = ctrl_lr
        self.n_warmup = n_warmup
        self.model_lr = model_lr
        self.model_wd = model_wd

        # One log file per instance, named after the creation time.
        timestamp = datetime.now().strftime('%m%d-%H-%M-%S')
        log_path = f'../tmp/log-{timestamp}.txt'
        # Create the target directory first so a missing ../tmp does not make
        # construction fail with FileNotFoundError.
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        # The handle stays open for the lifetime of the instance because
        # _train_controller writes to it on every step.
        self.log = open(log_path, 'w')

    def search(self, space: BaseSpace, dset, estimator):
        """
        Train the controller on ``space`` and return the final architecture.

        Parameters
        ----------
        space : BaseSpace
            Search space; its layer/input choices are replaced in place by
            path-sampling modules and it is moved to ``self.device``.
        dset
            Dataset handed unchanged to ``estimator.infer``.
        estimator
            Object whose ``infer(arch, dataset, mask=...)`` returns
            ``(metric, loss)``.

        Returns
        -------
        The result of ``space.export(selection, self.device)`` for the
        selection sampled by the controller after training.
        """
        self.model = space
        self.dataset = dset  # left on whatever device the caller put it on
        self.estimator = estimator

        # Replace the choice modules, then restore the original module order.
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

        self.model = self.model.to(self.device)

        # One controller field per replaced module.
        self.nas_fields = [
            ReinforceField(
                name,
                len(module),
                isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
            )
            for name, module in self.nas_modules
        ]

        # Merge user kwargs over the defaults; passing them side by side would
        # raise TypeError if ctrl_kwargs repeats one of the default keys.
        controller_kwargs = {'lstm_size': 100, 'temperature': 5.0, 'tanh_constant': 2.5}
        controller_kwargs.update(self.ctrl_kwargs or {})
        self.controller = ReinforceController(self.nas_fields, **controller_kwargs)
        self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr)

        with tqdm(range(self.num_epochs)) as bar:
            for epoch in bar:
                mean_reward = self._train_controller(epoch)
                bar.set_postfix(reward_controller=mean_reward)

        selection = self.export()
        arch = space.export(selection, self.device)
        print(selection, arch)
        return arch

    def _train_controller(self, epoch):
        """
        Run ``ctrl_steps_aggregate`` sample / evaluate / update steps.

        ``epoch`` is accepted for interface symmetry and is not used.
        Returns the mean validation metric over the steps (``0.0`` when no
        step was run).
        """
        self.model.eval()
        self.controller.train()
        self.ctrl_optim.zero_grad()
        rewards = []
        # The moving-average baseline is restarted on every call.
        baseline = None
        with tqdm(range(self.ctrl_steps_aggregate)) as bar:
            for ctrl_step in bar:
                self._resample()
                metric, loss = self._infer(mask='val')

                bar.set_postfix(acc=metric, loss=loss.item())
                self.log.write(f'{self.arch}\n{self.selection}\n{metric},{loss}\n')
                self.log.flush()

                # The raw metric is what gets averaged and reported; the
                # entropy bonus only affects the controller update.
                reward = metric
                rewards.append(reward)
                if self.entropy_weight:
                    reward += self.entropy_weight * self.controller.sample_entropy.item()

                # `is None` rather than truthiness: a baseline of exactly 0.0
                # is a valid value and must not be re-seeded.
                if baseline is None:
                    baseline = reward
                else:
                    baseline = baseline * self.baseline_decay + reward * (1 - self.baseline_decay)

                ctrl_loss = self.controller.sample_log_prob * (reward - baseline)
                self.ctrl_optim.zero_grad()
                ctrl_loss.backward()
                # NOTE(review): self.grad_clip is never applied before the step.
                self.ctrl_optim.step()
                bar.set_postfix(acc=metric, max_acc=max(rewards))

        if not rewards:
            # ctrl_steps_aggregate == 0: nothing was sampled.
            return 0.0
        return sum(rewards) / len(rewards)

    def _resample(self):
        """Sample a selection from the controller and export its architecture."""
        result = self.controller.resample()
        self.arch = self.model.export(result, device=self.device)
        self.selection = result

    def export(self):
        """Return a selection sampled from the controller in eval mode, without gradients."""
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()

    def _infer(self, mask='train'):
        """Evaluate ``self.arch`` with the estimator on ``mask``; returns ``(metric, loss)``."""
        metric, loss = self.estimator.infer(self.arch, self.dataset, mask=mask)
        return metric, loss

+ 3
- 3
autogl/module/nas/estimator/one_shot.py View File

@@ -31,9 +31,9 @@ class TrainEstimator(BaseEstimator):
self.trainer=NodeClassificationFullTrainer(
model=model,
optimizer=torch.optim.Adam,
lr=0.01,
max_epoch=200,
early_stopping_round=200,
lr=0.005,
max_epoch=300,
early_stopping_round=30,
weight_decay=5e-4,
device="auto",
init=False,


+ 1
- 1
autogl/module/nas/space/graph_nas_macro.py View File

@@ -392,7 +392,7 @@ class GraphNasMacroNodeClfSpace(BaseSpace):
self,
hidden_dim: _typ.Optional[int] = 64,
layer_number: _typ.Optional[int] = 2,
dropout: _typ.Optional[float] = 0.9,
dropout: _typ.Optional[float] = 0.6,
input_dim: _typ.Optional[int] = None,
output_dim: _typ.Optional[int] = None,
ops: _typ.Tuple = None,


+ 3
- 3
examples/test_graph_nas_rl.py View File

@@ -10,7 +10,7 @@ from autogl.module.nas.space.graph_nas import GraphNasNodeClassificationSpace
from autogl.module.nas.space.graph_nas_macro import GraphNasMacroNodeClfSpace
from autogl.module.train import Acc
from autogl.module.nas.algorithm.enas import Enas
from autogl.module.nas.algorithm.rl import RL
from autogl.module.nas.algorithm.rl import RL,GraphNasRL
from autogl.module.nas.estimator.one_shot import TrainEstimator
from autogl.module.nas.algorithm.random_search import RandomSearch
import logging
@@ -25,7 +25,7 @@ if __name__ == '__main__':
default_trainer=NodeClassificationFullTrainer(
optimizer=torch.optim.Adam,
lr=0.01,
max_epoch=200,
max_epoch=300,
early_stopping_round=200,
weight_decay=5e-4,
device="auto",
@@ -34,7 +34,7 @@ if __name__ == '__main__':
loss="nll_loss",
lr_scheduler_type=None,),
# nas_algorithms=[RL(num_epochs=400)],
nas_algorithms=[RandomSearch(num_epochs=400)],
nas_algorithms=[GraphNasRL(num_epochs=100)],
#nas_algorithms=[Darts(num_epochs=200)],
nas_spaces=[GraphNasMacroNodeClfSpace(hidden_dim=16,search_act_con=True,layer_number=2)],
nas_estimators=[TrainEstimator()]


Loading…
Cancel
Save