|
|
|
@@ -523,6 +523,8 @@ class GraphNasRL(BaseNAS): |
|
|
|
model_wd=5e-4, |
|
|
|
topk=5, |
|
|
|
disable_progress=False, |
|
|
|
param_size_weight=None, |
|
|
|
param_size_limit=None, |
|
|
|
): |
|
|
|
super().__init__(device) |
|
|
|
self.device = device |
|
|
|
@@ -541,6 +543,9 @@ class GraphNasRL(BaseNAS): |
|
|
|
self.hist = [] |
|
|
|
self.topk = topk |
|
|
|
self.disable_progress = disable_progress |
|
|
|
# TODO: new a class to describe the hardware-aware method |
|
|
|
self.param_size_weight = param_size_weight |
|
|
|
self.param_size_limit = param_size_limit |
|
|
|
|
|
|
|
def search(self, space: BaseSpace, dset, estimator): |
|
|
|
self.model = space |
|
|
|
@@ -628,10 +633,19 @@ class GraphNasRL(BaseNAS): |
|
|
|
LOGGER.debug(f"{self.arch}\n{self.selection}\n{metric},{loss}") |
|
|
|
# diff: not do reward shaping as in graphnas code |
|
|
|
reward = metric |
|
|
|
self.hist.append([-metric, self.selection]) |
|
|
|
if len(self.hist) > self.topk: |
|
|
|
self.hist.sort(key=lambda x: x[0]) |
|
|
|
self.hist.pop() |
|
|
|
# TODO: change |
|
|
|
model_info = self.arch.model.get_model_info() |
|
|
|
print(f"model_info: {model_info}") |
|
|
|
if self.param_size_weight is not None: |
|
|
|
reward -= self.param_size_weight * model_info["param"] |
|
|
|
if ( |
|
|
|
self.param_size_limit is None |
|
|
|
or model_info["param"] < self.param_size_limit |
|
|
|
): |
|
|
|
self.hist.append([-metric, self.selection]) |
|
|
|
if len(self.hist) > self.topk: |
|
|
|
self.hist.sort(key=lambda x: x[0]) |
|
|
|
self.hist.pop() |
|
|
|
rewards.append(reward) |
|
|
|
|
|
|
|
if self.entropy_weight: |
|
|
|
|