From 8585b55a65dd6d2da513cf33ff643d10a2a5a81b Mon Sep 17 00:00:00 2001
From: guohongzilong <2713219276@qq.com>
Date: Tue, 26 May 2020 22:15:23 +0800
Subject: [PATCH] add check for group parameters

---
 mindspore/nn/optim/optimizer.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 788a7d2754..988fa45680 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -122,7 +122,7 @@ class Optimizer(Cell):
             learning_rate = self._get_single_lr(learning_rate)
         if isinstance(parameters[0], dict):
             self.is_group = True
-            self.params = []
+            self.group_params = []
             self.group_lr = []
             self.group_weight_decay = []
             self._init_group_params(parameters, learning_rate, weight_decay)
@@ -133,7 +133,7 @@ class Optimizer(Cell):
             self.learning_rate = Parameter(learning_rate, name="learning_rate")
 
         if self.is_group:
-            self.parameters = ParameterTuple(self.params)
+            self.parameters = ParameterTuple(self.group_params)
             self.weight_decay = tuple(self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
@@ -240,7 +240,10 @@ class Optimizer(Cell):
         params_store = []
         for group_param in parameters:
-            self.params += group_param['params']
+            if not group_param['params']:
+                raise ValueError("Optimizer got an empty parameter list.")
+
+            self.group_params += group_param['params']
             if 'lr' in group_param.keys():
                 params_dynamic_lr = isinstance(group_param['lr'], (Iterable, Tensor))
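
For context, a minimal sketch of how the new check surfaces to callers, assuming the group-parameter optimizer API of MindSpore at this point; the network and group layout below are hypothetical illustrations, not taken from the patch:

    import mindspore.nn as nn

    net = nn.Dense(16, 8)  # hypothetical network with trainable parameters

    # Group parameters with per-group overrides; the second group ends up
    # empty, e.g. because a name filter matched nothing.
    weights = [p for p in net.trainable_params() if 'weight' in p.name]
    biases = [p for p in net.trainable_params() if 'bais' in p.name]  # deliberate typo: matches nothing
    group_params = [{'params': weights, 'lr': 0.05},
                    {'params': biases, 'weight_decay': 0.01}]

    # Before this patch, the empty group would silently contribute nothing
    # (self.params += [] is a no-op); with the added check, construction
    # fails fast instead:
    #   ValueError: Optimizer got an empty parameter list.
    opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)

The rename from self.params to self.group_params in the same patch keeps the group-accumulation list clearly distinct from the final self.parameters ParameterTuple built from it.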