From 7fed250ffddeba42dd2022f5012acbe5dd1cc0e3 Mon Sep 17 00:00:00 2001 From: yanghaiqie Date: Thu, 26 Aug 2021 02:25:03 +0800 Subject: [PATCH] add model parameter count to reward function for graphnas --- autogl/module/nas/algorithm/rl.py | 22 ++++++++++++++++++---- autogl/module/nas/space/base.py | 7 +++++++ autogl/module/nas/space/graph_nas_macro.py | 5 +++++ 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/autogl/module/nas/algorithm/rl.py b/autogl/module/nas/algorithm/rl.py index 2f48ba4..c143094 100644 --- a/autogl/module/nas/algorithm/rl.py +++ b/autogl/module/nas/algorithm/rl.py @@ -523,6 +523,8 @@ class GraphNasRL(BaseNAS): model_wd=5e-4, topk=5, disable_progress=True, + param_size_weight=None, + param_size_limit=None, ): super().__init__(device) self.device = device @@ -541,6 +543,9 @@ class GraphNasRL(BaseNAS): self.hist = [] self.topk = topk self.disable_progress = disable_progress + # TODO: new a class to describe the hardware-aware method + self.param_size_weight = param_size_weight + self.param_size_limit = param_size_limit def search(self, space: BaseSpace, dset, estimator): self.model = space @@ -628,10 +633,19 @@ class GraphNasRL(BaseNAS): LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}") # diff: not do reward shaping as in graphnas code reward = metric - self.hist.append([-metric, self.selection]) - if len(self.hist) > self.topk: - self.hist.sort(key=lambda x: x[0]) - self.hist.pop() + # TODO: change + model_info = self.arch.model.get_model_info() + print(f"model_info: {model_info}") + if self.param_size_weight is not None: + reward -= self.param_size_weight * model_info["param"] + if ( + self.param_size_limit is None + or model_info["param"] < self.param_size_limit + ): + self.hist.append([-metric, self.selection]) + if len(self.hist) > self.topk: + self.hist.sort(key=lambda x: x[0]) + self.hist.pop() rewards.append(reward) if self.entropy_weight: diff --git a/autogl/module/nas/space/base.py b/autogl/module/nas/space/base.py index 02d4ca6..330b3b4 100644 --- a/autogl/module/nas/space/base.py +++ b/autogl/module/nas/space/base.py @@ -145,6 +145,9 @@ class BoxModel(BaseModel): ret_self.to(self.device) return ret_self + def __repr__(self) -> str: + return str(self.model.get_model_info()) + @property def model(self): return self._model @@ -200,6 +203,10 @@ class BaseSpace(nn.Module): """ raise NotImplementedError() + def get_model_info(self): + # TODO: write zhushi + return {} + def instantiate(self): """ Instantiate the space, reset default key for the mutables here/ diff --git a/autogl/module/nas/space/graph_nas_macro.py b/autogl/module/nas/space/graph_nas_macro.py index 4929399..68917d6 100644 --- a/autogl/module/nas/space/graph_nas_macro.py +++ b/autogl/module/nas/space/graph_nas_macro.py @@ -739,3 +739,8 @@ class GraphNet(BaseSpace): key = f"layer_{i}_fc_{bn.weight.size(0)}" if key in param: self.bns[i] = param[key] + + def get_model_info(self): + param_size = sum(x.numel() for x in self.parameters()) + info = {"param": param_size} + return info