From 85a06b00c6f9a1fafca2c3f183a48983628f39a1 Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Wed, 3 Jun 2020 13:17:48 +0800 Subject: [PATCH] add order function in group params --- mindspore/nn/optim/adam.py | 24 ++- mindspore/nn/optim/momentum.py | 24 ++- mindspore/nn/optim/optimizer.py | 67 ++++++- mindspore/nn/optim/rmsprop.py | 24 ++- mindspore/nn/optim/sgd.py | 24 ++- mindspore/train/dataset_helper.py | 6 +- ...e.py => test_optimizer_with_loss_scale.py} | 0 ...> test_optimizer_with_parameter_groups.py} | 170 ++++++++++++++++-- 8 files changed, 289 insertions(+), 50 deletions(-) rename tests/ut/python/optimizer/{test_optimize_with_loss_scale.py => test_optimizer_with_loss_scale.py} (100%) rename tests/ut/python/optimizer/{test_optimize_with_parameter_groups.py => test_optimizer_with_parameter_groups.py} (56%) diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index a256f0e0d8..b66bb8b3b9 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -144,10 +144,12 @@ class Adam(Optimizer): value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + To improve performance when using parameter groups, a customized order of parameters is supported. + Args: params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", - "lr" and "weight_decay" are the keys can be parsed. + "lr", "weight_decay" and "order_params" are the keys that can be parsed. - params: Required. The value should be a list of `Parameter`. @@ -157,6 +159,11 @@ class Adam(Optimizer): - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay will be used. If not, the `weight_decay` in the API will be used. + - order_params: Optional. If "order_params" is in the keys, the value should be the ordered parameters and + the optimizer will follow that order. The `dict` should contain no other keys, and the parameters that are + in the value of 'order_params' but not in any group will use the default learning rate and the default + weight decay. + learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step will @@ -193,13 +200,16 @@ class Adam(Optimizer): >>> >>> #2) Use parameter groups and set different values >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) - >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) - >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, - >>> {'params': no_conv_params}] + >>> bias_params = list(filter(lambda x: 'bias' in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': bias_params, 'lr': 0.01}, + >>> {'order_params': net.trainable_params()}] >>> opt = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0) - >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01 - >>> # the no_cov_params's parameters don't set learning and weight decay. So they will use a - >>> # learning rate of 0.1 and a weight decay of 0.0. 
+ >>> # The parameters in conv_params will use the default learning rate of 0.1 and a weight decay of 0.01. + >>> # The parameters in bias_params will use a learning rate of 0.01 and the default weight decay of 0.0. + >>> # The order in which the optimizer updates the parameters follows the value of 'order_params'. + >>> # Parameters that are in the value of 'order_params' but not in any group will use the default + >>> # learning rate of 0.1 and the default weight decay of 0.0. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> model = Model(net, loss_fn=loss, optimizer=optim) diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index 080377b71d..8ff241f4a3 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -45,10 +45,12 @@ class Momentum(Optimizer): value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + To improve performance when using parameter groups, a customized order of parameters is supported. + Args: params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", - "lr" and "weight_decay" are the keys can be parsed. + "lr", "weight_decay" and "order_params" are the keys that can be parsed. - params: Required. The value should be a list of `Parameter`. @@ -58,6 +60,11 @@ class Momentum(Optimizer): - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay will be used. If not, the `weight_decay` in the API will be used. + - order_params: Optional. If "order_params" is in the keys, the value should be the ordered parameters and + the optimizer will follow that order. The `dict` should contain no other keys, and the parameters that are + in the value of 'order_params' but not in any group will use the default learning rate and the default + weight decay. + learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step will @@ -86,13 +93,16 @@ class Momentum(Optimizer): >>> >>> #2) Use parameter groups and set different values >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) - >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) - >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, - >>> {'params': no_conv_params}] + >>> bias_params = list(filter(lambda x: 'bias' in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': bias_params, 'lr': 0.01}, + >>> {'order_params': net.trainable_params()}] >>> opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0) - >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01 - >>> # the no_cov_params's parameters don't set learning and weight decay. So they will use a - >>> # learning rate of 0.1 and a weight decay of 0.0. + >>> # The parameters in conv_params will use the default learning rate of 0.1 and a weight decay of 0.01. + >>> # The parameters in bias_params will use a learning rate of 0.01 and the default weight decay of 0.0. 
+ >>> # The order in which the optimizer updates the parameters follows the value of 'order_params'. + >>> # Parameters that are in the value of 'order_params' but not in any group will use the default + >>> # learning rate of 0.1 and the default weight decay of 0.0. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 658ffb7b46..a860b4eb2c 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -48,6 +48,8 @@ class Optimizer(Cell): value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + To improve performance when using parameter groups, a customized order of parameters is supported. + Args: learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is Iterable or a Tensor and the dims of the Tensor is 1, @@ -60,7 +62,7 @@ class Optimizer(Cell): converted to float. parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be updated, the element in `parameters` should be class `Parameter`. When the `parameters` is a list of `dict`, - the "params", "lr" and "weight_decay" are the keys can be parsed. + the "params", "lr", "weight_decay" and "order_params" are the keys that can be parsed. - params: Required. The value should be a list of `Parameter`. @@ -70,6 +72,11 @@ class Optimizer(Cell): - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay will be used. If not, the `weight_decay` in the API will be used. + - order_params: Optional. If "order_params" is in the keys, the value should be the ordered parameters and + the optimizer will follow that order. The `dict` should contain no other keys, and the parameters that are + in the value of 'order_params' but not in any group will use the default learning rate and the default + weight decay. + weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0. If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0. loss_scale (float): A floating point value for the loss scale. It should be greater than 0. 
If the @@ -103,6 +110,7 @@ class Optimizer(Cell): self.is_group = False self.is_group_lr = False + self.is_group_params_ordered = False self.loss_scale = loss_scale if isinstance(learning_rate, int): learning_rate = float(learning_rate) @@ -210,9 +218,8 @@ class Optimizer(Cell): raise TypeError("Learning rate should be float, Tensor or Iterable.") return lr - def _init_group_params(self, parameters, learning_rate, weight_decay): - """Init learning rate or weight decay in group params.""" - origin_dynamic_lr = self.dynamic_lr + def _parse_group_params(self, parameters, learning_rate): + """Parse group params.""" if self.dynamic_lr: dynamic_lr_length = learning_rate.size() else: @@ -220,6 +227,15 @@ class Optimizer(Cell): for group_param in parameters: lr_length = dynamic_lr_length + if 'order_params' in group_param.keys(): + if len(group_param.keys()) > 1: + raise ValueError("The order params dict in group parameters should " + "only include the 'order_params' key.") + if not isinstance(group_param['order_params'], Iterable): + raise TypeError("The value of 'order_params' should be an Iterable type.") + self.is_group_params_ordered = True + continue + if 'lr' in group_param.keys(): self.is_group_lr = True self._get_single_lr(group_param['lr']) @@ -229,10 +245,20 @@ class Optimizer(Cell): elif isinstance(group_param['lr'], Tensor): lr_length = group_param['lr'].size() self.dynamic_lr = True + if dynamic_lr_length not in (lr_length, 0): raise ValueError("The dynamic learning rate in group should be the same size.") + + if not group_param['params']: + raise ValueError("Optimizer got an empty group parameter list.") + dynamic_lr_length = lr_length + self.dynamic_lr_length = dynamic_lr_length + def _init_group_params(self, parameters, learning_rate, weight_decay): + """Init learning rate or weight decay in group params.""" + origin_dynamic_lr = self.dynamic_lr + self._parse_group_params(parameters, learning_rate) if self.dynamic_lr and not origin_dynamic_lr: self.gather = P.GatherV2() self.assignadd = P.AssignAdd() @@ -240,20 +266,20 @@ class Optimizer(Cell): params_store = [] for group_param in parameters: - if not group_param['params']: - raise ValueError("Optimizer got an empty parameter list.") + if 'order_params' in group_param.keys(): + ordered_parameters = group_param['order_params'] + continue self.group_params += group_param['params'] if 'lr' in group_param.keys(): params_dynamic_lr = isinstance(group_param['lr'], (Iterable, Tensor)) - if self.dynamic_lr and not params_dynamic_lr: - lr = Tensor(np.array([group_param['lr']] * dynamic_lr_length).astype(np.float32)) + lr = Tensor(np.array([group_param['lr']] * self.dynamic_lr_length).astype(np.float32)) else: lr = self._get_single_lr(group_param['lr']) else: if self.dynamic_lr and not origin_dynamic_lr: - lr = Tensor(np.array([self.scalar_lr] * dynamic_lr_length).astype(np.float32)) + lr = Tensor(np.array([self.scalar_lr] * self.dynamic_lr_length).astype(np.float32)) else: lr = learning_rate @@ -273,10 +299,33 @@ class Optimizer(Cell): validator.check_value_type("parameter", param, [Parameter], self.cls_name) if param.name in params_store: raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.") + params_store.append(param.name) self.group_lr.append(Parameter(lr, name="lr_" + param.name)) self.group_weight_decay.append(weight_decay_) + if self.is_group_params_ordered: + self._order_and_adjust_group_params(ordered_parameters, learning_rate, weight_decay) + + def _order_and_adjust_group_params(self, 
ordered_parameters, learning_rate, weight_decay): + """ + Order the group parameters, learning rates and weight decays to follow 'order_params', and assign + the default values to the parameters that are in the value of 'order_params' but not in any group. + """ + params_length = len(ordered_parameters) + ordered_learning_rate = [Parameter(learning_rate, name="lr_" + param.name) for param in ordered_parameters] + ordered_weight_decay = [weight_decay * self.loss_scale] * params_length + params_name = [param.name for param in ordered_parameters] + + for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay): + index = params_name.index(param.name) + ordered_learning_rate[index] = lr + ordered_weight_decay[index] = wd + + self.group_params = list(ordered_parameters) + self.group_lr = ordered_learning_rate + self.group_weight_decay = ordered_weight_decay + def get_lr(self): """ Get the learning rate of current step. diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index 4d572574ae..7fc15868da 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -51,6 +51,8 @@ class RMSProp(Optimizer): value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + To improve performance when using parameter groups, a customized order of parameters is supported. + Update `params` according to the RMSProp algorithm. The equation is as follows: @@ -93,7 +95,7 @@ class RMSProp(Optimizer): Args: params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", - "lr" and "weight_decay" are the keys can be parsed. + "lr", "weight_decay" and "order_params" are the keys that can be parsed. - params: Required. The value should be a list of `Parameter`. @@ -103,6 +105,11 @@ class RMSProp(Optimizer): - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay will be used. If not, the `weight_decay` in the API will be used. + - order_params: Optional. If "order_params" is in the keys, the value should be the ordered parameters and + the optimizer will follow that order. The `dict` should contain no other keys, and the parameters that are + in the value of 'order_params' but not in any group will use the default learning rate and the default + weight decay. + learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. 
When the learning_rate is Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step will @@ -133,13 +140,16 @@ class RMSProp(Optimizer): >>> >>> #2) Use parameter groups and set different values >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) - >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) - >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, - >>> {'params': no_conv_params}] + >>> bias_params = list(filter(lambda x: 'bias' in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': bias_params, 'lr': 0.01}, + >>> {'order_params': net.trainable_params()}] >>> opt = nn.RMSProp(group_params, learning_rate=0.1, weight_decay=0.0) - >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01 - >>> # the no_cov_params's parameters don't set learning and weight decay. So they will use a - >>> # learning rate of 0.1 and a weight decay of 0.0. + >>> # The parameters in conv_params will use the default learning rate of 0.1 and a weight decay of 0.01. + >>> # The parameters in bias_params will use a learning rate of 0.01 and the default weight decay of 0.0. + >>> # The order in which the optimizer updates the parameters follows the value of 'order_params'. + >>> # Parameters that are in the value of 'order_params' but not in any group will use the default + >>> # learning rate of 0.1 and the default weight decay of 0.0. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> model = Model(net, loss_fn=loss, optimizer=optim) diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index bf49244550..a7493400f8 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -47,10 +47,12 @@ class SGD(Optimizer): value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. + To improve performance when using parameter groups, a customized order of parameters is supported. + Args: params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, the element in `params` should be class `Parameter`. When the `params` is a list of `dict`, the "params", - "lr" and "weight_decay" are the keys can be parsed. + "lr", "weight_decay" and "order_params" are the keys that can be parsed. - params: Required. The value should be a list of `Parameter`. @@ -60,6 +62,11 @@ class SGD(Optimizer): - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay will be used. If not, the `weight_decay` in the API will be used. + - order_params: Optional. If "order_params" is in the keys, the value should be the ordered parameters and + the optimizer will follow that order. The `dict` should contain no other keys, and the parameters that are + in the value of 'order_params' but not in any group will use the default learning rate and the default + weight decay. + learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. 
When the learning_rate is Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step will @@ -90,13 +97,16 @@ class SGD(Optimizer): >>> >>> #2) Use parameter groups and set different values >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) - >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) - >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, - >>> {'params': no_conv_params}] + >>> bias_params = list(filter(lambda x: 'bias' in x.name, net.trainable_params())) + >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, + >>> {'params': bias_params, 'lr': 0.01}, + >>> {'order_params': net.trainable_params()}] >>> opt = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0) - >>> # the conv_params's parameters will use a learning rate of 0.01 and a weight decay of 0.01 - >>> # the no_cov_params's parameters don't set learning and weight decay. So they will use a - >>> # learning rate of 0.1 and a weight decay of 0.0. + >>> # The parameters in conv_params will use the default learning rate of 0.1 and a weight decay of 0.01. + >>> # The parameters in bias_params will use a learning rate of 0.01 and the default weight decay of 0.0. + >>> # The order in which the optimizer updates the parameters follows the value of 'order_params'. + >>> # Parameters that are in the value of 'order_params' but not in any group will use the default + >>> # learning rate of 0.1 and the default weight decay of 0.0. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> model = Model(net, loss_fn=loss, optimizer=optim) diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 52797b631c..083349e5a1 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================ """Dataset help for minddata dataset""" +import math + from mindspore._checkparam import check_bool from .. 
import context from .parallel_utils import ParallelMode @@ -104,10 +106,10 @@ class _DatasetIter: loop_count = 1 if hasattr(dataset, '__loop_size__'): loop_size = dataset.__loop_size__ - if dataset.get_dataset_size() % loop_size != 0: + if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0: raise ValueError(f'Dataset size {dataset.get_dataset_size()} and ' f'loop_size {loop_size} are not matched.') - loop_count = int(dataset.get_dataset_size() / loop_size) + loop_count = math.ceil(dataset.get_dataset_size() / loop_size) return loop_count diff --git a/tests/ut/python/optimizer/test_optimize_with_loss_scale.py b/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py similarity index 100% rename from tests/ut/python/optimizer/test_optimize_with_loss_scale.py rename to tests/ut/python/optimizer/test_optimizer_with_loss_scale.py diff --git a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py b/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py similarity index 56% rename from tests/ut/python/optimizer/test_optimize_with_parameter_groups.py rename to tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py index 96c3c936b2..05e58013fa 100644 --- a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py +++ b/tests/ut/python/optimizer/test_optimizer_with_parameter_groups.py @@ -60,8 +60,9 @@ def test_group_lr(): default_lr = 0.1 conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) - group_params = [{'params': conv_params, 'lr': conv_lr}, - {'params': no_conv_params}] + group_params = [{'params': no_conv_params}, + {'params': conv_params, 'lr': conv_lr}, + {'order_params': net.trainable_params()}] net.set_train() loss = nn.SoftmaxCrossEntropyWithLogits() @@ -69,12 +70,15 @@ def test_group_lr(): assert opt.is_group is True assert opt.is_group_lr is True assert opt.dynamic_lr is False - for lr, param in zip(opt.learning_rate, opt.parameters): + assert opt.is_group_params_ordered is True + for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()): if param in conv_params: assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy()) else: assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy()) + assert param.name == order_param.name + net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, opt) _executor.compile(train_network, inputs, label) @@ -89,20 +93,24 @@ def test_group_dynamic_1(): default_lr = (0.1, 0.2, 0.3) conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) - group_params = [{'params': conv_params, 'lr': conv_lr}, - {'params': no_conv_params}] + group_params = [{'params': no_conv_params}, + {'params': conv_params, 'lr': conv_lr}, + {'order_params': net.trainable_params()}] net.set_train() loss = nn.SoftmaxCrossEntropyWithLogits() opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9) assert opt.is_group is True assert opt.dynamic_lr is True - for lr, param in zip(opt.learning_rate, opt.parameters): + assert opt.is_group_params_ordered is True + for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()): if param in conv_params: assert np.all(lr.data.asnumpy() == Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy()) else: 
assert np.all(lr.data.asnumpy() == Tensor(np.array(list(default_lr)).astype(np.float32)).asnumpy()) + assert param.name == order_param.name + net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, opt) _executor.compile(train_network, inputs, label) @@ -127,9 +135,9 @@ def test_group_dynamic_2(): assert opt.dynamic_lr is True for lr, param in zip(opt.learning_rate, opt.parameters): if param in conv_params: - assert np.all(lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32))) + assert np.all(lr.data.asnumpy() == Tensor(np.array(list(conv_lr)).astype(np.float32)).asnumpy()) else: - assert np.all(lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32))) + assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3).astype(np.float32)).asnumpy()) net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, opt) @@ -180,15 +188,18 @@ def test_weight_decay(): default_weight_decay = 0.0 conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) - group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, - {'params': no_conv_params}] + group_params = [{'params': no_conv_params}, + {'params': conv_params, 'weight_decay': conv_weight_decay}, + {'order_params': net.trainable_params()}] net.set_train() loss = nn.SoftmaxCrossEntropyWithLogits() opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay) assert opt.is_group is True assert opt.is_group_lr is False - for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters): + assert opt.is_group_params_ordered is True + for weight_decay, decay_flags, param, order_param in zip( + opt.weight_decay, opt.decay_flags, opt.parameters, net.trainable_params()): if param in conv_params: assert weight_decay == conv_weight_decay assert decay_flags is True @@ -196,6 +207,8 @@ def test_weight_decay(): assert weight_decay == default_weight_decay assert decay_flags is False + assert param.name == order_param.name + net_with_loss = WithLossCell(net, loss) train_network = TrainOneStepCell(net_with_loss, opt) _executor.compile(train_network, inputs, label) @@ -233,6 +246,19 @@ def test_get_lr_parameter_with_group(): assert lr.name == 'lr_' + param.name +def test_get_lr_parameter_with_order_group(): + net = LeNet5() + conv_lr = 0.1 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'lr': conv_lr}, + {'order_params': net.trainable_params()}] + opt = SGD(group_params) + assert opt.is_group_lr is True + for param in opt.parameters: + lr = opt.get_lr_parameter(param) + assert lr.name == 'lr_' + param.name + + def test_get_lr_parameter_with_no_group(): net = LeNet5() conv_weight_decay = 0.8 @@ -250,3 +276,125 @@ def test_get_lr_parameter_with_no_group(): params_error = [1, 2, 3] with pytest.raises(TypeError): opt.get_lr_parameter(params_error) + + +def test_order_params_lr(): + net = LeNet5() + conv_lr = 0.01 + default_lr = 0.1 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'lr': conv_lr}, + {'order_params': net.trainable_params()}] + opt = SGD(group_params, learning_rate=default_lr) + assert opt.is_group is True + assert opt.is_group_lr is True + assert opt.is_group_params_ordered is True + for lr, param, order_param in zip(opt.learning_rate, opt.parameters, 
net.trainable_params()): + if param in conv_params: + assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy()) + else: + assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy()) + + assert param.name == order_param.name + assert lr.name == 'lr_' + param.name + + +def test_order_params_weight_decay(): + net = LeNet5() + conv_weight_decay = 0.01 + default_wd = 0.0 + default_lr = 0.1 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, + {'order_params': net.trainable_params()}] + opt = SGD(group_params, learning_rate=default_lr, weight_decay=default_wd) + assert opt.is_group is True + assert opt.is_group_lr is False + assert opt.is_group_params_ordered is True + assert opt.learning_rate.name == "learning_rate" + assert np.all(opt.learning_rate.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy()) + for weight_decay, decay_flags, param, order_param in zip( + opt.weight_decay, opt.decay_flags, opt.parameters, net.trainable_params()): + if param in conv_params: + assert weight_decay == conv_weight_decay + assert decay_flags is True + else: + assert weight_decay == default_wd + assert decay_flags is False + assert param.name == order_param.name + + +def test_order_params_all_1(): + net = LeNet5() + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + bias_params = list(filter(lambda x: 'bias' in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'weight_decay': 0.01}, + {'params': bias_params, 'lr': 0.01}, + {'order_params': net.trainable_params()}] + opt = SGD(group_params, learning_rate=0.1, weight_decay=0.0) + assert opt.is_group is True + assert opt.is_group_lr is True + assert opt.is_group_params_ordered is True + for weight_decay, decay_flags, lr, param, order_param in zip( + opt.weight_decay, opt.decay_flags, opt.learning_rate, opt.parameters, net.trainable_params()): + if param in conv_params: + assert np.all(lr.data.asnumpy() == Tensor(0.1, mstype.float32).asnumpy()) + assert weight_decay == 0.01 + assert decay_flags is True + elif param in bias_params: + assert np.all(lr.data.asnumpy() == Tensor(0.01, mstype.float32).asnumpy()) + assert weight_decay == 0.0 + assert decay_flags is False + else: + assert np.all(lr.data.asnumpy() == Tensor(0.1, mstype.float32).asnumpy()) + assert weight_decay == 0.0 + assert decay_flags is False + + assert param.name == order_param.name + assert lr.name == 'lr_' + param.name + + +def test_order_params_all_2(): + net = LeNet5() + conv_weight_decay = 0.01 + fc1_lr = (0.5, 0.4, 0.3) + default_lr = 0.1 + default_wd = 0.0 + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + fc1_params = list(filter(lambda x: 'fc1' in x.name, net.trainable_params())) + group_params = [{'params': fc1_params, 'lr': fc1_lr}, + {'params': conv_params, 'weight_decay': conv_weight_decay}, + {'order_params': net.trainable_params()}] + opt = SGD(group_params, learning_rate=default_lr, weight_decay=default_wd) + assert opt.is_group is True + assert opt.is_group_lr is True + assert opt.is_group_params_ordered is True + for weight_decay, decay_flags, lr, param, order_param in zip( + opt.weight_decay, opt.decay_flags, opt.learning_rate, opt.parameters, net.trainable_params()): + if param in conv_params: + assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy()) + assert weight_decay == conv_weight_decay + assert 
decay_flags is True + elif param in fc1_params: + assert np.all(lr.data.asnumpy() == Tensor(fc1_lr, mstype.float32).asnumpy()) + assert weight_decay == default_wd + assert decay_flags is False + else: + assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy()) + assert weight_decay == default_wd + assert decay_flags is False + + assert param.name == order_param.name + assert lr.name == 'lr_' + param.name + + +def test_get_order_params_with_not_include(): + net = LeNet5() + conv_weight_decay = 0.8 + + conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) + no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) + group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, + {'order_params': no_conv_params}] + with pytest.raises(ValueError): + SGD(group_params)
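
The core of this change is `_order_and_adjust_group_params`, which rebuilds the per-parameter learning rates and weight decays so that they follow the value of 'order_params', falling back to the defaults for parameters that are not in any group. Below is a minimal standalone sketch of that behaviour, not MindSpore code: the function name `order_group_params` and the string parameter names are hypothetical stand-ins for `Parameter` objects.

# A minimal sketch of the reordering, assuming plain Python lists and string
# names in place of MindSpore Parameter objects. All names are hypothetical.
def order_group_params(group_names, group_lr, group_wd, ordered_names, default_lr, default_wd):
    """Reorder per-parameter lr/weight decay to follow `ordered_names`."""
    # Start from the defaults for every parameter, in the requested order.
    ordered_lr = [default_lr] * len(ordered_names)
    ordered_wd = [default_wd] * len(ordered_names)
    # Overwrite the defaults for parameters that appeared in a group.
    for name, lr, wd in zip(group_names, group_lr, group_wd):
        index = ordered_names.index(name)
        ordered_lr[index] = lr
        ordered_wd[index] = wd
    return list(ordered_names), ordered_lr, ordered_wd

# 'conv' sets a custom weight decay, 'bias' a custom lr, 'fc' appears only in order_params.
names, lrs, wds = order_group_params(group_names=["conv", "bias"], group_lr=[0.1, 0.01],
                                     group_wd=[0.01, 0.0],
                                     ordered_names=["conv", "fc", "bias"],
                                     default_lr=0.1, default_wd=0.0)
assert names == ["conv", "fc", "bias"]
assert lrs == [0.1, 0.1, 0.01]
assert wds == [0.01, 0.0, 0.0]

The expected values match what test_order_params_all_1 asserts for the real optimizer: parameters listed in a group keep that group's settings, and everything else in 'order_params' uses the default learning rate and weight decay.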