|
|
|
@@ -122,7 +122,7 @@ class Optimizer(Cell): |
|
|
|
learning_rate = self._get_single_lr(learning_rate) |
|
|
|
if isinstance(parameters[0], dict): |
|
|
|
self.is_group = True |
|
|
|
self.params = [] |
|
|
|
self.group_params = [] |
|
|
|
self.group_lr = [] |
|
|
|
self.group_weight_decay = [] |
|
|
|
self._init_group_params(parameters, learning_rate, weight_decay) |
|
|
|
@@ -133,7 +133,7 @@ class Optimizer(Cell): |
|
|
|
self.learning_rate = Parameter(learning_rate, name="learning_rate") |
|
|
|
|
|
|
|
if self.is_group: |
|
|
|
self.parameters = ParameterTuple(self.params) |
|
|
|
self.parameters = ParameterTuple(self.group_params) |
|
|
|
self.weight_decay = tuple(self.group_weight_decay) |
|
|
|
decay_filter = lambda x: x > 0 |
|
|
|
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) |
|
|
|
@@ -240,7 +240,10 @@ class Optimizer(Cell): |
|
|
|
|
|
|
|
params_store = [] |
|
|
|
for group_param in parameters: |
|
|
|
self.params += group_param['params'] |
|
|
|
if not group_param['params']: |
|
|
|
raise ValueError("Optimizer got an empty parameter list.") |
|
|
|
|
|
|
|
self.group_params += group_param['params'] |
|
|
|
if 'lr' in group_param.keys(): |
|
|
|
params_dynamic_lr = isinstance(group_param['lr'], (Iterable, Tensor)) |
|
|
|
|
|
|
|
|