@@ -11,6 +11,8 @@ from ..space import BaseSpace
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module
from nni.nas.pytorch.fixed import apply_fixed_architecture
from tqdm import tqdm
from datetime import datetime

_logger = logging.getLogger(__name__)


def _get_mask(sampled, total):
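    # Multi-hot mask over `total` positions: entry i is True when i equals the sampled
    # choice (an int) or is contained in it (a list of sampled indices).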
    multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
@@ -229,7 +231,7 @@ class ReinforceController(nn.Module):

class RL(BaseNAS):
    """
    ENAS trainer using the RL search strategy from GraphNAS.

    Parameters
    ----------
@@ -293,7 +295,7 @@ class RL(BaseNAS):
        self.n_warmup = n_warmup
        self.model_lr = model_lr
        self.model_wd = model_wd
        self.log = open('log.txt', 'w')
        self.log = open('../tmp/log.txt', 'w')

    def search(self, space: BaseSpace, dset, estimator):
        self.model = space
        self.dataset = dset  # .to(self.device)
@@ -318,16 +320,6 @@
        with tqdm(range(self.num_epochs)) as bar:
            for i in bar:
                l2 = self._train_controller(i)

                # try:
                #     l2=self._train_controller(i)
                # except Exception as e:
                #     print(e)
                #     nm=self.nas_modules
                #     for i in range(len(nm)):
                #         print(nm[i][1].sampled)
                #     # import pdb
                #     # pdb.set_trace()
                bar.set_postfix(reward_controller=l2)

        selection = self.export()
@@ -382,3 +374,150 @@ class RL(BaseNAS):
    def _infer(self, mask='train'):
        metric, loss = self.estimator.infer(self.arch, self.dataset, mask=mask)
        return metric, loss


class GraphNasRL(BaseNAS):
    """
    RL-based architecture search as used in GraphNAS.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground-truth labels and returns a loss tensor.
    metrics : callable
        Receives logits and ground-truth labels and returns a dict of metrics.
    reward_function : callable
        Receives logits and ground-truth labels and returns a tensor that is fed to the RL controller as the reward.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset : Dataset
        Dataset for training. It will be split into parts for training model weights and architecture weights.
    batch_size : int
        Batch size.
    workers : int
        Number of workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Number of steps between logging.
    grad_clip : float
        Gradient clipping. Set to 0 to disable. Default: 5.
    entropy_weight : float
        Weight of the sample-entropy bonus added to the reward.
    skip_weight : float
        Weight of the skip penalty loss.
    baseline_decay : float
        Decay factor of the baseline. The new baseline equals ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_lr : float
        Learning rate of the RL controller.
    ctrl_steps_aggregate : int
        Number of steps aggregated into one mini-batch for the RL controller.
    ctrl_steps : int
        Number of mini-batches per epoch of RL controller learning.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
""" |
|
|
|
|
|
|
|
    def __init__(self, device='cuda', workers=4, log_frequency=None,
                 grad_clip=5., entropy_weight=0.0001, skip_weight=0, baseline_decay=0.95,
                 ctrl_lr=0.00035, ctrl_steps_aggregate=100, ctrl_kwargs=None,
                 n_warmup=100, model_lr=5e-3, model_wd=5e-4, *args, **kwargs):
        super().__init__(device)
        self.device = device
        self.num_epochs = kwargs.get("num_epochs", 10)
        self.workers = workers
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip
        self.ctrl_kwargs = ctrl_kwargs
        self.ctrl_lr = ctrl_lr
        self.n_warmup = n_warmup
        self.model_lr = model_lr
        self.model_wd = model_wd
        timestamp = datetime.now().strftime('%m%d-%H-%M-%S')
        self.log = open(f'../tmp/log-{timestamp}.txt', 'w')

    def search(self, space: BaseSpace, dset, estimator):
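        # Pipeline: swap the space's layer/input choices for path-sampling modules,
        # wrap each choice in a ReinforceField, train the LSTM controller with
        # REINFORCE, then export an architecture sampled from the trained controller.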
        self.model = space
        self.dataset = dset  # .to(self.device)
        self.estimator = estimator
        # replace layer/input choices with path-sampling modules, kept in original definition order
        self.nas_modules = []

        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

        # to device
        self.model = self.model.to(self.device)

        # one ReinforceField per mutable choice: (name, number of options, whether exactly one option is chosen)
        self.nas_fields = [ReinforceField(name, len(module),
                                          isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
                           for name, module in self.nas_modules]
        self.controller = ReinforceController(self.nas_fields, lstm_size=100, temperature=5.0, tanh_constant=2.5,
                                              **(self.ctrl_kwargs or {}))
        self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr)
        # train
        with tqdm(range(self.num_epochs)) as bar:
            for i in bar:
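                # one controller epoch; _train_controller returns the mean reward over
                # the aggregated controller steps, shown in the progress bar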
                l2 = self._train_controller(i)
                bar.set_postfix(reward_controller=l2)

        selection = self.export()
        arch = space.export(selection, self.device)
        print(selection, arch)
        return arch

    def _train_controller(self, epoch):
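        # One REINFORCE epoch: repeatedly sample an architecture, evaluate it on the
        # validation split, and update the controller with the baseline-corrected reward.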
        self.model.eval()
        self.controller.train()
        self.ctrl_optim.zero_grad()
        rewards = []
        baseline = None
        with tqdm(range(self.ctrl_steps_aggregate)) as bar:
            for ctrl_step in bar:
                self._resample()
                metric, loss = self._infer(mask='val')

                bar.set_postfix(acc=metric, loss=loss.item())
                self.log.write(f'{self.arch}\n{self.selection}\n{metric},{loss}\n')
                self.log.flush()
                reward = metric
                rewards.append(reward)

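                # optional exploration bonus: add the controller's sample entropy to the reward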
                if self.entropy_weight:
                    reward += self.entropy_weight * self.controller.sample_entropy.item()

                if baseline is None:
                    baseline = reward
                else:
                    baseline = baseline * self.baseline_decay + reward * (1 - self.baseline_decay)

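                # REINFORCE surrogate loss: log-probability of the sampled architecture
                # scaled by the advantage (reward minus the moving-average baseline)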
                loss = self.controller.sample_log_prob * (reward - baseline)
                self.ctrl_optim.zero_grad()
                loss.backward()
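                # optional gradient clipping on the controller, following the grad_clip
                # description in the docstring (set to 0 to disable)
                if self.grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)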

                self.ctrl_optim.step()

                bar.set_postfix(acc=metric, max_acc=max(rewards))
        return sum(rewards) / len(rewards)

    def _resample(self):
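        # sample a fresh architecture from the controller and instantiate it via the space's export()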
        result = self.controller.resample()
        self.arch = self.model.export(result, device=self.device)
        self.selection = result

    def export(self):
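        # draw a final architecture selection from the trained controller without tracking gradients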
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()

    def _infer(self, mask='train'):
        metric, loss = self.estimator.infer(self.arch, self.dataset, mask=mask)
        return metric, loss
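

# Usage sketch: the `space`, `dataset`, and `estimator` names below are placeholders
# (assumed to be a concrete BaseSpace subclass, a graph dataset, and an estimator
# exposing `infer(arch, dataset, mask=...)`), not objects defined in this module.
#
#     algo = GraphNasRL(device='cuda', num_epochs=10, ctrl_steps_aggregate=100)
#     best_arch = algo.search(space, dataset, estimator)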