@@ -243,7 +243,7 @@ class Adam(Optimizer):
         self.beta1_power = beta1_power
         beta2_power = self.beta2_power * self.beta2
         self.beta2_power = beta2_power
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1,
                                                self.beta2, self.eps),
                                      lr, gradients, params, moment1, moment2)
@@ -111,7 +111,7 @@ class Momentum(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments)
         else:
             success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
@@ -94,6 +94,7 @@ class Optimizer(Cell):
         validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)

         self.is_group = False
+        self.is_group_lr = False
         self.loss_scale = loss_scale
         if isinstance(learning_rate, float):
             self.dynamic_lr = False
@@ -116,14 +117,17 @@ class Optimizer(Cell):
             self.group_weight_decay = []
             self._init_group_params(parameters, learning_rate, weight_decay)

-        if self.is_group:
+        if self.is_group_lr:
             self.learning_rate = ParameterTuple(self.group_lr)
+        else:
+            self.learning_rate = Parameter(learning_rate, name="learning_rate")
+
+        if self.is_group:
             self.parameters = ParameterTuple(self.params)
             self.weight_decay = tuple(self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
         else:
-            self.learning_rate = Parameter(learning_rate, name="learning_rate")
             self.parameters = ParameterTuple(parameters)
             self.weight_decay = weight_decay * loss_scale
             decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
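For context, a minimal usage sketch of the two flags this hunk separates (the network and parameter names are illustrative, not part of this PR): `is_group` is set whenever parameter groups are passed, while `is_group_lr` is only set when at least one group overrides `'lr'`.

```python
import mindspore.nn as nn

# Assumed setup: any small trainable cell works; the split into two groups is illustrative.
net = nn.Dense(3, 4)
conv_params = [p for p in net.trainable_params() if 'weight' in p.name]
no_conv_params = [p for p in net.trainable_params() if 'weight' not in p.name]

group_params = [{'params': conv_params, 'lr': 0.01},   # this group overrides the learning rate
                {'params': no_conv_params}]            # this group uses the default learning_rate
opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)

assert opt.is_group is True      # parameter groups were used
assert opt.is_group_lr is True   # at least one group set its own 'lr'
```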
@@ -207,6 +211,7 @@ class Optimizer(Cell):
         for group_param in parameters:
             lr_length = dynamic_lr_length
             if 'lr' in group_param.keys():
+                self.is_group_lr = True
                 self._get_single_lr(group_param['lr'])
                 if isinstance(group_param['lr'], Iterable):
                     lr_length = len(group_param['lr'])
@@ -247,6 +252,10 @@ class Optimizer(Cell):
             else:
                 weight_decay_ = weight_decay * self.loss_scale

+            for key in group_param.keys():
+                if key not in ('params', 'lr', 'weight_decay'):
+                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
+
             for param in group_param['params']:
                 if param in params_store:
                     raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
@@ -261,7 +270,7 @@ class Optimizer(Cell):
         Returns:
             float, the learning rate of current step.
         """
-        if self.is_group:
+        if self.is_group_lr:
             lr = self.learning_rate
             if self.dynamic_lr:
                 lr = ()
@@ -176,7 +176,7 @@ class RMSProp(Optimizer):
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.centered:
-            if self.is_group:
+            if self.is_group_lr:
                 success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon,
                                                    self.momentum), lr, params, self.mg, self.ms, self.moment, gradients)
             else:
@@ -184,7 +184,7 @@ class RMSProp(Optimizer):
                                                    self.momentum, lr), params, self.mg, self.ms, self.moment, gradients)
         else:
-            if self.is_group:
+            if self.is_group_lr:
                 success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon,
                                                    self.momentum), lr, params, self.ms, self.moment, gradients)
             else:
@@ -139,7 +139,7 @@ class SGD(Optimizer):
         gradients = self.decay_weight(gradients)
         gradients = self.scale_grad(gradients)
         lr = self.get_lr()
-        if self.is_group:
+        if self.is_group_lr:
             success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
         else:
             success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat)
@@ -65,12 +65,13 @@ def test_group_lr():
     opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9)
     assert opt.is_group is True
+    assert opt.is_group_lr is True
     assert opt.dynamic_lr is False
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(conv_lr, mstype.float32)
+            assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy())
         else:
-            assert lr.data == Tensor(default_lr, mstype.float32)
+            assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy())

     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
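The assertions here and in the two tests below now reduce an element-wise comparison with `np.all` instead of asserting on the result of comparing tensors directly. A minimal, self-contained sketch of that pattern (the values are illustrative):

```python
import numpy as np

# Reduce the element-wise comparison to a single bool before asserting.
actual = np.array([0.01, 0.01, 0.01], dtype=np.float32)    # stands in for lr.data.asnumpy()
expected = np.array([0.01] * 3, dtype=np.float32)          # stands in for Tensor(...).asnumpy()
assert np.all(actual == expected)
```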
@@ -96,9 +97,9 @@ def test_group_dynamic_1():
     assert opt.dynamic_lr is True
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32))
+            assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
         else:
-            assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32))
+            assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy())

     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -124,9 +125,9 @@ def test_group_dynamic_2():
     assert opt.dynamic_lr is True
     for lr, param in zip(opt.learning_rate, opt.parameters):
         if param in conv_params:
-            assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32))
+            assert np.all(lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)))
         else:
-            assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32))
+            assert np.all(lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)))

     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, opt)
@@ -184,6 +185,7 @@ def test_weight_decay():
     opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay)
     assert opt.is_group is True
+    assert opt.is_group_lr is False
     for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters):
         if param in conv_params:
             assert weight_decay == conv_weight_decay