|
|
|
@@ -219,8 +219,28 @@ class Optimizer(Cell): |
|
|
|
raise TypeError("Learning rate should be float, Tensor or Iterable.") |
|
|
|
return lr |
|
|
|
|
|
|
|
def _check_group_params(self, parameters): |
|
|
|
"""Check group params.""" |
|
|
|
parse_keys = ['params', 'lr', 'weight_decay', 'order_params'] |
|
|
|
for group_param in parameters: |
|
|
|
invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys())) |
|
|
|
if invalid_key: |
|
|
|
raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.') |
|
|
|
|
|
|
|
if 'order_params' in group_param.keys(): |
|
|
|
if len(group_param.keys()) > 1: |
|
|
|
raise ValueError("The order params dict in group parameters should " |
|
|
|
"only include the 'order_params' key.") |
|
|
|
if not isinstance(group_param['order_params'], Iterable): |
|
|
|
raise TypeError("The value of 'order_params' should be an Iterable type.") |
|
|
|
continue |
|
|
|
|
|
|
|
if not group_param['params']: |
|
|
|
raise ValueError("Optimizer got an empty group parameter list.") |
|
|
|
|
|
|
|
def _parse_group_params(self, parameters, learning_rate): |
|
|
|
"""Parse group params.""" |
|
|
|
self._check_group_params(parameters) |
|
|
|
if self.dynamic_lr: |
|
|
|
dynamic_lr_length = learning_rate.size() |
|
|
|
else: |
|
|
|
@@ -250,9 +270,6 @@ class Optimizer(Cell): |
|
|
|
if dynamic_lr_length not in (lr_length, 0): |
|
|
|
raise ValueError("The dynamic learning rate in group should be the same size.") |
|
|
|
|
|
|
|
if not group_param['params']: |
|
|
|
raise ValueError("Optimizer got an empty group parameter list.") |
|
|
|
|
|
|
|
dynamic_lr_length = lr_length |
|
|
|
self.dynamic_lr_length = dynamic_lr_length |
|
|
|
|
|
|
|
|