|
|
@@ -10,6 +10,7 @@ from .base import BaseNAS |
|
|
from ..space import BaseSpace |
|
|
from ..space import BaseSpace |
|
|
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module |
|
|
from ..utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, get_module_order, sort_replaced_module |
|
|
from nni.nas.pytorch.fixed import apply_fixed_architecture |
|
|
from nni.nas.pytorch.fixed import apply_fixed_architecture |
|
|
|
|
|
from tqdm import tqdm |
|
|
_logger = logging.getLogger(__name__) |
|
|
_logger = logging.getLogger(__name__) |
|
|
def _get_mask(sampled, total): |
|
|
def _get_mask(sampled, total): |
|
|
multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)] |
|
|
multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)] |
|
|
@@ -312,9 +313,21 @@ class Enas(BaseNAS): |
|
|
self.controller = ReinforceController(self.nas_fields, **(self.ctrl_kwargs or {})) |
|
|
self.controller = ReinforceController(self.nas_fields, **(self.ctrl_kwargs or {})) |
|
|
self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr) |
|
|
self.ctrl_optim = torch.optim.Adam(self.controller.parameters(), lr=self.ctrl_lr) |
|
|
# train |
|
|
# train |
|
|
for i in range(self.num_epochs): |
|
|
|
|
|
self._train_model(i) |
|
|
|
|
|
self._train_controller(i) |
|
|
|
|
|
|
|
|
with tqdm(range(self.num_epochs)) as bar: |
|
|
|
|
|
for i in bar: |
|
|
|
|
|
try: |
|
|
|
|
|
l1=self._train_model(i) |
|
|
|
|
|
l2=self._train_controller(i) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
print(e) |
|
|
|
|
|
nm=self.nas_modules |
|
|
|
|
|
for i in range(len(nm)): |
|
|
|
|
|
print(nm[i][1].sampled) |
|
|
|
|
|
import pdb |
|
|
|
|
|
pdb.set_trace() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bar.set_postfix(loss_model=l1,reward_controller=l2) |
|
|
|
|
|
|
|
|
selection=self.export() |
|
|
selection=self.export() |
|
|
return space.export(selection,self.device) |
|
|
return space.export(selection,self.device) |
|
|
@@ -329,16 +342,19 @@ class Enas(BaseNAS): |
|
|
if self.grad_clip > 0: |
|
|
if self.grad_clip > 0: |
|
|
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) |
|
|
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) |
|
|
self.model_optim.step() |
|
|
self.model_optim.step() |
|
|
|
|
|
return loss.item() |
|
|
|
|
|
|
|
|
def _train_controller(self, epoch): |
|
|
def _train_controller(self, epoch): |
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
self.controller.train() |
|
|
self.controller.train() |
|
|
self.ctrl_optim.zero_grad() |
|
|
self.ctrl_optim.zero_grad() |
|
|
|
|
|
rewards=[] |
|
|
for ctrl_step in range(self.ctrl_steps_aggregate): |
|
|
for ctrl_step in range(self.ctrl_steps_aggregate): |
|
|
self._resample() |
|
|
self._resample() |
|
|
with torch.no_grad(): |
|
|
with torch.no_grad(): |
|
|
metric,loss=self._infer() |
|
|
metric,loss=self._infer() |
|
|
reward =-metric # todo : now metric is loss |
|
|
reward =-metric # todo : now metric is loss |
|
|
|
|
|
rewards.append(reward) |
|
|
if self.entropy_weight: |
|
|
if self.entropy_weight: |
|
|
reward += self.entropy_weight * self.controller.sample_entropy.item() |
|
|
reward += self.entropy_weight * self.controller.sample_entropy.item() |
|
|
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) |
|
|
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) |
|
|
@@ -357,6 +373,7 @@ class Enas(BaseNAS): |
|
|
if self.log_frequency is not None and ctrl_step % self.log_frequency == 0: |
|
|
if self.log_frequency is not None and ctrl_step % self.log_frequency == 0: |
|
|
_logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs, |
|
|
_logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs, |
|
|
ctrl_step + 1, self.ctrl_steps_aggregate) |
|
|
ctrl_step + 1, self.ctrl_steps_aggregate) |
|
|
|
|
|
return (sum(rewards)/len(rewards)).item() |
|
|
|
|
|
|
|
|
def _resample(self): |
|
|
def _resample(self): |
|
|
result = self.controller.resample() |
|
|
result = self.controller.resample() |
|
|
|