
add check for group parameters

tags/v0.5.0-beta
guohongzilong 5 years ago
commit 8585b55a65
1 changed file with 6 additions and 3 deletions:
  mindspore/nn/optim/optimizer.py (+6, -3)

@@ -122,7 +122,7 @@ class Optimizer(Cell):
         learning_rate = self._get_single_lr(learning_rate)
         if isinstance(parameters[0], dict):
             self.is_group = True
-            self.params = []
+            self.group_params = []
             self.group_lr = []
             self.group_weight_decay = []
             self._init_group_params(parameters, learning_rate, weight_decay)
@@ -133,7 +133,7 @@ class Optimizer(Cell):
         self.learning_rate = Parameter(learning_rate, name="learning_rate")

         if self.is_group:
-            self.parameters = ParameterTuple(self.params)
+            self.parameters = ParameterTuple(self.group_params)
             self.weight_decay = tuple(self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
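For context, the decay_flags tuple built here marks which parameter groups actually apply weight decay. A minimal standalone sketch of the same logic (the values are illustrative, not from the commit):

    # one weight-decay value per parameter group
    group_weight_decay = (0.01, 0.0)
    decay_flags = tuple(wd > 0 for wd in group_weight_decay)
    print(decay_flags)  # (True, False): decay applies only to the first group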
@@ -240,7 +240,10 @@

         params_store = []
         for group_param in parameters:
-            self.params += group_param['params']
+            if not group_param['params']:
+                raise ValueError("Optimizer got an empty parameter list.")
+
+            self.group_params += group_param['params']
             if 'lr' in group_param.keys():
                 params_dynamic_lr = isinstance(group_param['lr'], (Iterable, Tensor))
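With this change, a parameter group whose 'params' list is empty now fails fast with a ValueError at optimizer construction instead of being silently accepted. A minimal usage sketch, assuming MindSpore's group-parameter convention (the network and the name-based split are illustrative):

    import mindspore.nn as nn

    net = nn.Dense(10, 2)  # illustrative network

    # Split trainable parameters into two groups by name.
    weight_params = [p for p in net.trainable_params() if 'weight' in p.name]
    bias_params = [p for p in net.trainable_params() if 'bias' in p.name]

    group_params = [{'params': weight_params, 'weight_decay': 0.01},
                    {'params': bias_params, 'lr': 0.01}]
    opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)

    # After this commit, an empty group raises immediately:
    # nn.Momentum([{'params': []}], learning_rate=0.1, momentum=0.9)
    # -> ValueError: Optimizer got an empty parameter list.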


