from .rl import *
import numpy as np


class AGNNReinforceController(ReinforceController):
    """REINFORCE controller that re-samples only one group of fields.

    Fields outside the searched group are replayed through the LSTM with their
    values from an existing selection. The searched group (e.g. all activation
    fields, the "action group" of the AGNN paper) is then sampled conditioned
    on that history.
    """

    def resample(self, search_fields, selection):
        """Re-sample the fields in ``search_fields`` and keep all others.

        Args:
            search_fields: the fields of one group to sample anew.
            selection: dict mapping every field name to its current choice.
                It is copied, not modified.

        Returns:
            A new dict equal to ``selection`` except for the re-sampled fields.
        """
        self._initialize()
        result = selection.copy()
        searched_names = {field.name for field in search_fields}
        # 1. Feed the fixed (unsearched) fields to the LSTM as history.
        for field in self.fields:
            if field.name not in searched_names:
                self._update_state(field, selection[field.name])
        # 2. Sample the searched fields conditioned on that history.
        for field in search_fields:
            result[field.name] = self._sample_single(field)
        return result

    def _update_state(self, field, sampled):
        """Advance the LSTM one step using an already-chosen value of ``field``.

        The next LSTM input becomes the embedding of ``sampled``. Nothing is
        sampled, and no log-probability or entropy is touched here.

        NOTE(review): this assumes ``sampled`` is a single index (a single-choice
        field). A multi-choice selection given as a list would produce an
        embedding of a different shape. Confirm before using it with such
        fields.
        """
        self._lstm_next_step()
        index = torch.LongTensor([sampled]).to(self._inputs.device)
        self._inputs = self.embedding[field.name](index)


class AGNNActionGuider(nn.Module):
    """Choose which field group to re-sample at each step (action guider).

    One ``AGNNReinforceController`` is kept per group. On every ``resample``
    call, each controller proposes a new selection for its own group. The
    winning proposal is chosen as follows:

    - ``guide_type == 0``: the group whose proposal has the highest sampling
      entropy (the most uncertain group).
    - ``guide_type == 1``: a group drawn with probability
      ``softmax(entropies)``.
    """

    _GUIDE_TYPES = (0, 1)

    def __init__(self, fields, groups, guide_type, **controllargs):
        super(AGNNActionGuider, self).__init__()
        # Fail at construction time instead of at the first resample.
        if guide_type not in self._GUIDE_TYPES:
            raise NotImplementedError(f"Not implemented guide type {guide_type}")
        # Independent controllers, one per group. Each sees all fields so it
        # can condition on the fixed ones.
        self.controllers = nn.ModuleList(
            [AGNNReinforceController(fields, **controllargs) for _ in groups]
        )
        self.fields = fields
        self.groups = groups
        self.guide_type = guide_type

    def dummy_selection(self):
        """Return a selection that picks choice 0 for every field."""
        return {field.name: 0 for field in self.fields}

    def resample(self, selection):
        """Propose one re-sampled group around ``selection`` and return it.

        Sets ``self.sample_log_prob`` and ``self.sample_entropy`` to the values
        of the chosen group's controller, for use in the REINFORCE update.
        """
        proposals = []
        entropies = []
        log_probs = []
        for controller, group in zip(self.controllers, self.groups):
            proposals.append(controller.resample(group, selection))
            entropies.append(controller.sample_entropy)
            log_probs.append(controller.sample_log_prob)

        # Use plain floats to pick the group. This works for CPU/GPU tensors
        # (and ints) alike. The raw values are kept for sample_entropy, since
        # callers call .item() on it.
        scores = [float(entropy) for entropy in entropies]
        LOGGER.debug("AGNN group entropies: %s", scores)
        if self.guide_type == 0:
            # Use the most uncertain group.
            idx = int(np.argmax(scores))
        elif self.guide_type == 1:
            # Sample a group with probability softmax(entropy).
            probs = F.softmax(torch.tensor(scores), dim=0)
            idx = torch.multinomial(probs, 1).item()
        else:
            # Explicit raise: an assert would be stripped under -O and let a
            # stale loop index through.
            raise NotImplementedError(f"Not implemented guide type {self.guide_type}")

        LOGGER.debug(
            "AGNN selected group %s", [field.name for field in self.groups[idx]]
        )
        self.sample_log_prob = log_probs[idx]
        self.sample_entropy = entropies[idx]
        new_selection = proposals[idx]
        LOGGER.debug("AGNN new selection %s", new_selection)
        return new_selection


@register_nas_algo("agnn")
class AGNNRL(GraphNasRL):
    """AGNN search: GraphNAS-style REINFORCE with an action guider.

    Each controller step re-samples a single field group around the best
    selection found so far (conservative explorer). That best selection is
    what ``export`` returns.

    Args:
        guide_type: group-selection rule for ``AGNNActionGuider``:
            0 = highest entropy, 1 = sample by softmax(entropy).

    NOTE(review): ``guide_type`` is the first positional parameter, so any
    positional arguments meant for ``GraphNasRL`` shift by one. Pass those by
    keyword.
    """

    # Field-name substrings that define the AGNN action groups.
    _GROUP_TAGS = ("op", "in", "act", "concat")

    def __init__(self, guide_type=1, *args, **kwargs):
        super(AGNNRL, self).__init__(*args, **kwargs)
        self.guide_type = guide_type

    @classmethod
    def _group_fields(cls, fields):
        """Partition ``fields`` into action groups by substring tags in names.

        A field whose name contains several tags joins each matching group.
        Fields matching no tag are collected into one extra group. Otherwise
        they would never be re-sampled and would stay at choice 0 forever.
        Empty groups are dropped.
        """
        groups = {tag: [] for tag in cls._GROUP_TAGS}
        unmatched = []
        for field in fields:
            matched = False
            for tag in cls._GROUP_TAGS:
                if tag in field.name:
                    groups[tag].append(field)
                    matched = True
            if not matched:
                unmatched.append(field)
        result = [group for group in groups.values() if group]
        if unmatched:
            LOGGER.warning(
                "AGNN: fields %s match no group tag %s; searching them as an extra group",
                [field.name for field in unmatched],
                cls._GROUP_TAGS,
            )
            result.append(unmatched)
        return result

    def search(self, space: BaseSpace, dset, estimator):
        """Search ``space`` on ``dset`` and return the parsed best architecture."""
        self.model = space
        self.dataset = dset
        self.estimator = estimator

        # Replace choice modules with path-sampling versions, in module order.
        self.nas_modules = []
        k2o = get_module_order(self.model)
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        self.nas_modules = sort_replaced_module(k2o, self.nas_modules)

        self.model = self.model.to(self.device)

        # One controller field per choice module.
        self.nas_fields = [
            ReinforceField(
                name,
                len(module),
                isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1,
            )
            for name, module in self.nas_modules
        ]
        groups = self._group_fields(self.nas_fields)

        self.controller = AGNNActionGuider(
            self.nas_fields,
            groups,
            self.guide_type,
            lstm_size=100,
            temperature=5.0,
            tanh_constant=2.5,
            **(self.ctrl_kwargs or {}),
        )
        self.ctrl_optim = torch.optim.Adam(
            self.controller.parameters(), lr=self.ctrl_lr
        )
        # [reward, selection] of the best feasible sample so far.
        self.best_selection = [0, self.controller.dummy_selection()]

        with tqdm(range(self.num_epochs), disable=self.disable_progress) as bar:
            for epoch in bar:
                mean_reward = self._train_controller(epoch)
                bar.set_postfix(reward_controller=mean_reward)

        selection = self.export()
        arch = space.parse_model(selection, self.device)
        LOGGER.info("AGNN final selection %s, arch %s", selection, arch)
        return arch

    def _train_controller(self, epoch):
        """Run one epoch of ``ctrl_steps_aggregate`` REINFORCE steps.

        Each step does the following:

        - re-samples one field group around the current best selection;
        - evaluates the result on the validation mask;
        - updates the controller with ``log_prob * (reward - baseline)``.

        At the end of the epoch, the best sample that satisfies
        ``hardware_metric_limit`` may replace ``self.best_selection``.

        Returns:
            The mean validation metric over the epoch's samples.
        """
        self.model.eval()
        self.controller.train()
        rewards = []
        # (reward, selection) pairs that satisfy the hardware limit.
        feasible = []
        # The baseline starts from the best reward found so far.
        baseline = self.best_selection[0]
        # diff: GraphNAS trains 100 and derives 100 samples per epoch (10 epochs);
        # we only train 100 (20 epochs). The total number of samples is the same (2000).
        with tqdm(
            range(self.ctrl_steps_aggregate), disable=self.disable_progress
        ) as bar:
            for _ in bar:
                self._resample()
                metric, loss, hardware_metric = self._infer(mask="val")
                reward = metric
                LOGGER.debug("%s\n%s\n%s,%s", self.arch, self.selection, metric, loss)
                # diff: no reward shaping, unlike the GraphNAS code.
                if (
                    self.hardware_metric_limit is None
                    or hardware_metric[0] < self.hardware_metric_limit
                ):
                    self.hist.append([-metric, self.selection])
                    self.allhist.append([-metric, self.selection])
                    if len(self.hist) > self.topk:
                        self.hist.sort(key=lambda x: x[0])
                        self.hist.pop()
                    feasible.append((reward, self.selection.copy()))
                rewards.append(reward)

                if self.entropy_weight:
                    reward += self.entropy_weight * float(
                        self.controller.sample_entropy
                    )
                if not baseline:
                    baseline = reward
                else:
                    baseline = baseline * self.baseline_decay + reward * (
                        1 - self.baseline_decay
                    )

                loss = self.controller.sample_log_prob * (reward - baseline)
                self.ctrl_optim.zero_grad()
                loss.backward()
                self.ctrl_optim.step()

                bar.set_postfix(acc=metric, max_acc=max(rewards))

        self._update_best_selection(feasible)
        mean_reward = sum(rewards) / len(rewards)
        LOGGER.info("epoch:%s, mean rewards:%s", epoch, mean_reward)
        return mean_reward

    def _update_best_selection(self, candidates):
        """Conservative explorer: keep the best feasible selection seen so far.

        ``candidates`` holds (reward, selection) pairs. Only samples that
        satisfy the hardware limit are passed in, so ``export`` never returns
        an infeasible architecture.
        """
        if candidates:
            best_reward, best_selection = max(candidates, key=lambda c: c[0])
            if best_reward > self.best_selection[0]:
                self.best_selection = [best_reward, best_selection]
        LOGGER.info("AGNN best selection: %s", self.best_selection)

    def _resample(self):
        """Sample a new selection around the best one and parse its model."""
        result = self.controller.resample(self.best_selection[1])
        self.arch = self.model.parse_model(result, device=self.device)
        self.selection = result

    def export(self):
        """Return the best selection found by the conservative explorer."""
        return self.best_selection[1]