| @@ -0,0 +1,300 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """dynamic learning rate""" | |||
| import math | |||
| from mindspore._checkparam import ParamValidator as validator | |||
| from mindspore._checkparam import Rel | |||
| def piecewise_constant_lr(milestone, learning_rates): | |||
| r""" | |||
| Get piecewise constant learning rate. | |||
| Calculate the learning rate by the given `milestone` and `learning_rates`. Let the value of `milestone` be | |||
| :math:`(M_1, M_2, ..., M_N)` and the value of `learning_rates` be :math:`(x_1, x_2, ..., x_N)`. N is the length of | |||
| `milestone`. Let the output learning rate be `y`. | |||
| .. math:: | |||
| y[i] = x_t,\ \text{for}\ i \in [M_{t-1}, M_t) | |||
| Args: | |||
| milestone (list[int]): A list of milestone values. This list must be monotonically increasing. | |||
| learning_rates (list[float]): A list of learning rates. | |||
| Returns: | |||
| list[float]. The size of the list is :math:`M_N`. | |||
| Examples: | |||
| >>> milestone = [2, 5, 10] | |||
| >>> learning_rates = [0.1, 0.05, 0.01] | |||
| >>> lr = piecewise_constant_lr(milestone, learning_rates) | |||
| [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01] | |||
| """ | |||
| validator.check_type('milestone', milestone, (tuple, list)) | |||
| validator.check_type('learning_rates', learning_rates, (tuple, list)) | |||
| if len(milestone) != len(learning_rates): | |||
| raise ValueError('The size of `milestone` must be the same as the size of `learning_rates`.') | |||
| lr = [] | |||
| last_item = 0 | |||
| for i, item in enumerate(milestone): | |||
| validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT) | |||
| validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float]) | |||
| if item < last_item: | |||
| raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') | |||
| lr += [learning_rates[i]] * (item - last_item) | |||
| last_item = item | |||
| return lr | |||
| def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair): | |||
| validator.check_integer('total_step', total_step, 0, Rel.GT) | |||
| validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) | |||
| validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) | |||
| validator.check_float_positive('learning_rate', learning_rate) | |||
| validator.check_float_positive('decay_rate', decay_rate) | |||
| validator.check_type('is_stair', is_stair, [bool]) | |||
| def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False): | |||
| r""" | |||
| Calculate the learning rate based on the exponential decay function. | |||
| For the i-th step, the formula for computing decayed_learning_rate[i] is: | |||
| .. math:: | |||
| decayed\_learning\_rate[i] = learning\_rate * decay\_rate^{\frac{current\_epoch}{decay\_epoch}} | |||
| Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`. | |||
| Args: | |||
| learning_rate (float): The initial value of learning rate. | |||
| decay_rate (float): The decay rate. | |||
| total_step (int): The total number of steps. | |||
| step_per_epoch (int): The number of steps per epoch. | |||
| decay_epoch (int): A value used to calculate decayed learning rate. | |||
| is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False. | |||
| Returns: | |||
| list[float]. The size of the list is `total_step`. | |||
| Examples: | |||
| >>> learning_rate = 0.1 | |||
| >>> decay_rate = 0.9 | |||
| >>> total_step = 6 | |||
| >>> step_per_epoch = 2 | |||
| >>> decay_epoch = 1 | |||
| >>> lr = exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch) | |||
| [0.1, 0.1, 0.09000000000000001, 0.09000000000000001, 0.08100000000000002, 0.08100000000000002] | |||
| """ | |||
| _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair) | |||
| lr = [] | |||
| for i in range(total_step): | |||
| if is_stair: | |||
| lr.append(learning_rate * decay_rate ** math.floor(math.floor(i / step_per_epoch) / decay_epoch)) | |||
| else: | |||
| lr.append(learning_rate * decay_rate ** (math.floor(i / step_per_epoch) / decay_epoch)) | |||
| return lr | |||
| def natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False): | |||
| r""" | |||
| Calculate the learning rate based on the natural exponential decay function. | |||
| For the i-th step, the formula for computing decayed_learning_rate[i] is: | |||
| .. math:: | |||
| decayed\_learning\_rate[i] = learning\_rate * e^{-decay\_rate * current\_epoch} | |||
| Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`. | |||
| Args: | |||
| learning_rate (float): The initial value of learning rate. | |||
| decay_rate (float): The decay rate. | |||
| total_step (int): The total number of steps. | |||
| step_per_epoch (int): The number of steps per epoch. | |||
| decay_epoch (int): A value used to calculate decayed learning rate. | |||
| is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False. | |||
| Returns: | |||
| list[float]. The size of the list is `total_step`. | |||
| Examples: | |||
| >>> learning_rate = 0.1 | |||
| >>> decay_rate = 0.9 | |||
| >>> total_step = 6 | |||
| >>> step_per_epoch = 2 | |||
| >>> decay_epoch = 2 | |||
| >>> lr = natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True) | |||
| [0.1, 0.1, 0.1, 0.1, 0.016529888822158657, 0.016529888822158657] | |||
| """ | |||
| _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair) | |||
| function = lambda x, y: x | |||
| if is_stair: | |||
| function = lambda x, y: math.floor(x / y) * y | |||
| lr = [] | |||
| for i in range(total_step): | |||
| lr.append(learning_rate * math.e ** (-decay_rate * function(math.floor(i / step_per_epoch), decay_epoch))) | |||
| return lr | |||
| def inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False): | |||
| r""" | |||
| Calculate the learning rate based on the inverse-time decay function. | |||
| For the i-th step, the formula for computing decayed_learning_rate[i] is: | |||
| .. math:: | |||
| decayed\_learning\_rate[i] = learning\_rate / (1 + decay\_rate * current\_epoch / decay\_epoch) | |||
| Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`. | |||
| Args: | |||
| learning_rate (float): The initial value of learning rate. | |||
| decay_rate (float): The decay rate. | |||
| total_step (int): The total number of steps. | |||
| step_per_epoch (int): The number of steps per epoch. | |||
| decay_epoch (int): A value used to calculate decayed learning rate. | |||
| is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False. | |||
| Returns: | |||
| list[float]. The size of the list is `total_step`. | |||
| Examples: | |||
| >>> learning_rate = 0.1 | |||
| >>> decay_rate = 0.5 | |||
| >>> total_step = 6 | |||
| >>> step_per_epoch = 1 | |||
| >>> decay_epoch = 1 | |||
| >>> lr = inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True) | |||
| [0.1, 0.06666666666666667, 0.05, 0.04, 0.03333333333333333, 0.028571428571428574] | |||
| """ | |||
| _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair) | |||
| lr = [] | |||
| for i in range(total_step): | |||
| if is_stair: | |||
| lr.append(learning_rate / (1 + decay_rate * math.floor(math.floor(i / step_per_epoch) / decay_epoch))) | |||
| else: | |||
| lr.append(learning_rate / (1 + decay_rate * math.floor(i / step_per_epoch) / decay_epoch)) | |||
| return lr | |||
| def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): | |||
| r""" | |||
| Calculate the learning rate based on the cosine decay function. | |||
| For the i-th step, the formula for computing decayed_learning_rate[i] is: | |||
| .. math:: | |||
| decayed\_learning\_rate[i] = min\_learning\_rate + 0.5 * (max\_learning\_rate - min\_learning\_rate) * | |||
| (1 + cos(\frac{current\_epoch}{decay\_epoch}\pi)) | |||
| Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`. | |||
| Args: | |||
| min_lr (float): The minimum value of learning rate. | |||
| max_lr (float): The maximum value of learning rate. | |||
| total_step (int): The total number of steps. | |||
| step_per_epoch (int): The number of steps per epoch. | |||
| decay_epoch (int): A value used to calculate decayed learning rate. | |||
| Returns: | |||
| list[float]. The size of the list is `total_step`. | |||
| Examples: | |||
| >>> min_lr = 0.01 | |||
| >>> max_lr = 0.1 | |||
| >>> total_step = 6 | |||
| >>> step_per_epoch = 2 | |||
| >>> decay_epoch = 2 | |||
| >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) | |||
| [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] | |||
| """ | |||
| validator.check_float_positive('min_lr', min_lr) | |||
| validator.check_float_positive('max_lr', max_lr) | |||
| validator.check_integer('total_step', total_step, 0, Rel.GT) | |||
| validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) | |||
| validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) | |||
| delta = 0.5 * (max_lr - min_lr) | |||
| lr = [] | |||
| for i in range(total_step): | |||
| tmp_epoch = min(math.floor(i / step_per_epoch), decay_epoch) | |||
| lr.append(min_lr + delta * (1 + math.cos(math.pi * tmp_epoch / decay_epoch))) | |||
| return lr | |||
| def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power, | |||
| update_decay_epoch=False): | |||
| r""" | |||
| Calculate the learning rate based on the polynomial decay function. | |||
| For the i-th step, the formula for computing decayed_learning_rate[i] is: | |||
| .. math:: | |||
| decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) * | |||
| (1 - tmp\_epoch / decay\_epoch)^{power} + end\_learning\_rate | |||
| Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch), current\_epoch=floor(\frac{i}{step\_per\_epoch})`. | |||
| If `update_decay_epoch` is true, update the value of `decay_epoch` every epoch. The formula is | |||
| :math:`decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)` | |||
| Args: | |||
| learning_rate (float): The initial value of learning rate. | |||
| end_learning_rate (float): The end value of learning rate. | |||
| total_step (int): The total number of steps. | |||
| step_per_epoch (int): The number of steps per epoch. | |||
| decay_epoch (int): A value used to calculate decayed learning rate. | |||
| power (float): A value used to calculate decayed learning rate. | |||
| update_decay_epoch (bool): If true, update `decay_epoch`. Default: False. | |||
| Returns: | |||
| list[float]. The size of the list is `total_step`. | |||
| Examples: | |||
| >>> learning_rate = 0.1 | |||
| >>> end_learning_rate = 0.01 | |||
| >>> total_step = 6 | |||
| >>> step_per_epoch = 2 | |||
| >>> decay_epoch = 2 | |||
| >>> power = 0.5 | |||
| >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) | |||
| [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] | |||
| """ | |||
| validator.check_float_positive('learning_rate', learning_rate) | |||
| validator.check_float_positive('end_learning_rate', end_learning_rate) | |||
| validator.check_integer('total_step', total_step, 0, Rel.GT) | |||
| validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) | |||
| validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) | |||
| validator.check_type('power', power, [float]) | |||
| validator.check_type('update_decay_epoch', update_decay_epoch, [bool]) | |||
| function = lambda x, y: (x, min(x, y)) | |||
| if update_decay_epoch: | |||
| function = lambda x, y: (x * max(math.ceil(y / x), 1), y) | |||
| lr = [] | |||
| delta = learning_rate - end_learning_rate | |||
| for i in range(total_step): | |||
| current_epoch = math.floor(i / step_per_epoch) | |||
| decay_epoch, tmp_epoch = function(decay_epoch, current_epoch) | |||
| lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate) | |||
| return lr | |||
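The schedules above are plain Python lists with one learning-rate value per training step. They are intended to be passed to an optimizer as a dynamic learning rate; the refactored `Optimizer` base class further below converts such an iterable to a 1-D tensor and indexes it by `global_step` in `get_lr()`. A minimal usage sketch follows; the `nn.Dense(10, 5)` network and the hyper-parameter values are illustrative assumptions, not part of the patch.

```python
from mindspore import nn
from mindspore.nn import dynamic_lr as dr

# A toy network; any Cell with trainable parameters would do.
net = nn.Dense(10, 5)

# One learning-rate value per step: 6 steps, 2 steps per epoch, decay every epoch.
lr_schedule = dr.exponential_decay_lr(0.1, 0.9, total_step=6, step_per_epoch=2, decay_epoch=1)

# Passing the list (instead of a float) makes the optimizer use
# lr_schedule[global_step] at each training step.
opt = nn.Momentum(net.trainable_params(), learning_rate=lr_schedule, momentum=0.9)
```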
| @@ -13,7 +13,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """adam""" | |||
| from typing import Iterable | |||
| import numpy as np | |||
| from mindspore.common import dtype as mstype | |||
| @@ -25,7 +24,7 @@ from mindspore.common.parameter import Parameter | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._checkparam import ParamValidator as validator | |||
| from mindspore._checkparam import Rel | |||
| from .optimizer import Optimizer, apply_decay, grad_scale | |||
| from .optimizer import Optimizer | |||
| _learning_rate_update_func = ['linear', 'cos', 'sin'] | |||
| @@ -168,22 +167,13 @@ class Adam(Optimizer): | |||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, | |||
| use_nesterov=False, weight_decay=0.0, loss_scale=1.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(Adam, self).__init__(learning_rate, params) | |||
| super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | |||
| _check_param_value(beta1, beta2, eps, weight_decay) | |||
| validator.check_type("use_locking", use_locking, [bool]) | |||
| validator.check_type("use_nesterov", use_nesterov, [bool]) | |||
| validator.check_type("loss_scale", loss_scale, [float]) | |||
| validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT) | |||
| self.dynamic_lr = False | |||
| if isinstance(learning_rate, Iterable) or \ | |||
| (isinstance(learning_rate, Tensor) and learning_rate.dim() == 1): | |||
| self.dynamic_lr = True | |||
| self.gather = P.GatherV2() | |||
| self.assignadd = P.AssignAdd() | |||
| self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step") | |||
| self.axis = 0 | |||
| self.beta1 = Tensor(beta1, mstype.float32) | |||
| self.beta2 = Tensor(beta2, mstype.float32) | |||
| self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power") | |||
| @@ -196,8 +186,6 @@ class Adam(Optimizer): | |||
| self.decay_tf = tuple(decay_filter(x) for x in self.parameters) | |||
| self.hyper_map = C.HyperMap() | |||
| self.opt = P.Adam(use_locking, use_nesterov) | |||
| self.weight_decay = weight_decay * loss_scale | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| self.pow = P.Pow() | |||
| self.sqrt = P.Sqrt() | |||
| @@ -208,15 +196,9 @@ class Adam(Optimizer): | |||
| params = self.parameters | |||
| moment1 = self.moment1 | |||
| moment2 = self.moment2 | |||
| if self.weight_decay > 0: | |||
| gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients) | |||
| if self.reciprocal_scale != 1.0: | |||
| gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) | |||
| lr = self.learning_rate | |||
| if self.dynamic_lr: | |||
| lr = self.gather(self.learning_rate, self.global_step, self.axis) | |||
| F.control_depend(lr, self.assignadd(self.global_step, self.one)) | |||
| gradients = self.decay_weight(gradients) | |||
| gradients = self.scale_grad(gradients) | |||
| lr = self.get_lr() | |||
| beta1_power = self.beta1_power * self.beta1 | |||
| self.beta1_power = beta1_power | |||
| @@ -13,14 +13,9 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """momentum""" | |||
| from typing import Iterable | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common import Tensor | |||
| from .optimizer import Optimizer, apply_decay, grad_scale | |||
| from .optimizer import Optimizer | |||
| momentum_opt = C.MultitypeFuncGraph("momentum_opt") | |||
| @@ -88,43 +83,20 @@ class Momentum(Optimizer): | |||
| """ | |||
| def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(Momentum, self).__init__(learning_rate, params) | |||
| super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | |||
| if isinstance(momentum, float) and momentum < 0.0: | |||
| raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) | |||
| if isinstance(learning_rate, Iterable) or \ | |||
| (isinstance(learning_rate, Tensor) and learning_rate.dim() == 1): | |||
| self.dynamic_lr = True | |||
| self.gather = P.GatherV2() | |||
| self.assignadd = P.AssignAdd() | |||
| self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step") | |||
| self.axis = 0 | |||
| else: | |||
| self.dynamic_lr = False | |||
| self.gather = None | |||
| self.assignadd = None | |||
| self.global_step = None | |||
| self.axis = None | |||
| self.momentum = Parameter(momentum, name="momentum") | |||
| self.params = self.parameters | |||
| self.moments = self.params.clone(prefix="moments", init='zeros') | |||
| self.decay_tf = tuple(decay_filter(x) for x in self.parameters) | |||
| self.hyper_map = C.HyperMap() | |||
| self.opt = P.ApplyMomentum() | |||
| self.weight_decay = weight_decay * loss_scale | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| self.one = Tensor(1, mstype.int32) | |||
| def construct(self, gradients): | |||
| params = self.params | |||
| moments = self.moments | |||
| if self.weight_decay > 0: | |||
| gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients) | |||
| if self.reciprocal_scale != 1.0: | |||
| gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) | |||
| if self.dynamic_lr: | |||
| lr = self.gather(self.learning_rate, self.global_step, self.axis) | |||
| F.control_depend(lr, self.assignadd(self.global_step, self.one)) | |||
| else: | |||
| lr = self.learning_rate | |||
| gradients = self.decay_weight(gradients) | |||
| gradients = self.scale_grad(gradients) | |||
| lr = self.get_lr() | |||
| success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments) | |||
| return success | |||
| @@ -17,9 +17,11 @@ from typing import Iterable | |||
| import numpy as np | |||
| import mindspore | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.common.parameter import Parameter, ParameterTuple | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore._checkparam import ParamValidator as validator | |||
| from mindspore._checkparam import Rel | |||
| from mindspore.common.tensor import Tensor | |||
| @@ -42,34 +44,110 @@ class Optimizer(Cell): | |||
| Args: | |||
| learning_rate (float): A floating point value for the learning rate. Should be greater than 0. | |||
| parameters (list): A list of parameters to be updated. Each element in `parameters` | |||
| should be class mindspore.Parameter. | |||
| should be class mindspore.Parameter. | |||
| weight_decay (float): A floating point value for the weight decay. Default: 0.0. | |||
| loss_scale (float): A floating point value for the loss scale. Default: 1.0. Should be greater than 0. | |||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda | |||
| x: 'beta' not in x.name and 'gamma' not in x.name. | |||
| Raises: | |||
| ValueError: If the learning_rate is a Tensor, but the dimension of the tensor is greater than 1. | |||
| TypeError: If the learning_rate is not any of the three types: float, Tensor, Iterable. | |||
| """ | |||
| def __init__(self, learning_rate, parameters): | |||
| def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(Optimizer, self).__init__() | |||
| if isinstance(learning_rate, float): | |||
| self.dynamic_lr = False | |||
| self.gather = None | |||
| self.assignadd = None | |||
| self.global_step = None | |||
| validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT) | |||
| elif isinstance(learning_rate, Iterable): | |||
| learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32)) | |||
| elif isinstance(learning_rate, Tensor): | |||
| if learning_rate.dim() > 1: | |||
| raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`," | |||
| f"but got {learning_rate.dim()}.") | |||
| else: | |||
| raise TypeError("Learning rate should be float, Tensor or Iterable.") | |||
| self.dynamic_lr = True | |||
| self.gather = P.GatherV2() | |||
| self.assignadd = P.AssignAdd() | |||
| self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') | |||
| if isinstance(learning_rate, Iterable): | |||
| learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32)) | |||
| elif isinstance(learning_rate, Tensor): | |||
| if learning_rate.dim() > 1: | |||
| raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`," | |||
| f"but got {learning_rate.dim()}.") | |||
| if learning_rate.dim() == 1 and learning_rate.size() < 2: | |||
| logger.warning("If want to use the dynamic learning rate, please make sure that the number " | |||
| "of elements in the list, tuple or tensor passed is greater than 1.") | |||
| else: | |||
| raise TypeError("Learning rate should be float, Tensor or Iterable.") | |||
| if loss_scale <= 0.0: | |||
| raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale)) | |||
| if weight_decay < 0.0: | |||
| raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay)) | |||
| if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1 and learning_rate.size() < 2: | |||
| logger.warning("If want to use the dynamic learning rate, please make sure that " | |||
| "the number of elements in the list, tuple or tensor passed is greater than 1.") | |||
| self.learning_rate = Parameter(learning_rate, name="learning_rate") | |||
| self.parameters = ParameterTuple(parameters) | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| self.weight_decay = weight_decay * loss_scale | |||
| self.decay_flags = tuple(decay_filter(x) for x in self.parameters) | |||
| if not self.parameters: | |||
| raise ValueError("optimizer got an empty parameter list.") | |||
| def decay_weight(self, gradients): | |||
| """ | |||
| Weight decay. | |||
| An approach to reduce the overfitting of a deep learning neural network model. | |||
| Args: | |||
| gradients (tuple[Tensor]): The gradients of `self.parameters`, with the same shape as | |||
| `self.parameters`. | |||
| Returns: | |||
| tuple[Tensor], The gradients after weight decay. | |||
| """ | |||
| if self.weight_decay > 0: | |||
| params = self.parameters  # set in Optimizer.__init__; subclasses may not define self.params | |||
| gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients) | |||
| return gradients | |||
| def scale_grad(self, gradients): | |||
| """ | |||
| Loss scale for mixed precision. | |||
| An approach used in mixed precision training to improve the speed and energy efficiency of training deep | |||
| neural networks. | |||
| Args: | |||
| gradients (tuple[Tensor]): The gradients of `self.parameters`, with the same shape as | |||
| `self.parameters`. | |||
| Returns: | |||
| tuple[Tensor], The gradients after loss scale. | |||
| """ | |||
| if self.reciprocal_scale != 1.0: | |||
| gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) | |||
| return gradients | |||
| def get_lr(self): | |||
| """ | |||
| Get the learning rate of the current step. | |||
| Returns: | |||
| float, the learning rate of the current step. | |||
| """ | |||
| lr = self.learning_rate | |||
| if self.dynamic_lr: | |||
| lr = self.gather(self.learning_rate, self.global_step, 0) | |||
| F.control_depend(lr, self.assignadd(self.global_step, 1)) | |||
| return lr | |||
| def construct(self, *hyper_params): | |||
| raise NotImplementedError | |||
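With weight decay, loss scaling and the dynamic learning rate bookkeeping now handled by the base class, every subclass's `construct` reduces to the same three helper calls followed by the optimizer-specific update, as the Momentum, RMSProp and SGD hunks in this patch show. A condensed sketch of that shared pattern; the update line reuses names from the Momentum file (`momentum_opt`, `self.opt`, `self.params`, `self.moments`) and is an excerpt-style illustration rather than a standalone implementation.

```python
def construct(self, gradients):
    # 1. Apply weight decay to the gradients selected by decay_filter
    #    (a no-op when weight_decay == 0).
    gradients = self.decay_weight(gradients)
    # 2. Undo the loss scale (a no-op when loss_scale == 1.0).
    gradients = self.scale_grad(gradients)
    # 3. Current learning rate: the scalar value, or lr[global_step] when a
    #    list, tuple or 1-D tensor was passed as learning_rate.
    lr = self.get_lr()
    # 4. Optimizer-specific update, here the Momentum version from this patch.
    return self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum),
                          gradients, self.params, self.moments)
```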
| @@ -14,12 +14,8 @@ | |||
| # ============================================================================ | |||
| """rmsprop""" | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore._checkparam import ParamValidator as validator | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common import Tensor | |||
| from .optimizer import Optimizer, grad_scale, apply_decay | |||
| from .optimizer import Optimizer | |||
| rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | |||
| centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | |||
| @@ -138,7 +134,7 @@ class RMSProp(Optimizer): | |||
| def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, | |||
| use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(RMSProp, self).__init__(learning_rate, params) | |||
| super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | |||
| if isinstance(momentum, float) and momentum < 0.0: | |||
| raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) | |||
| @@ -157,15 +153,6 @@ class RMSProp(Optimizer): | |||
| else: | |||
| self.opt = P.ApplyRMSProp(use_locking) | |||
| self.dynamic_lr = False | |||
| if not isinstance(learning_rate, float): | |||
| self.dynamic_lr = True | |||
| self.gather = P.GatherV2() | |||
| self.assignadd = P.AssignAdd() | |||
| self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step") | |||
| self.axis = 0 | |||
| self.one = Tensor(1, mstype.int32) | |||
| self.momentum = momentum | |||
| self.ms = self.parameters.clone(prefix="mean_square", init='zeros') | |||
| @@ -173,21 +160,12 @@ class RMSProp(Optimizer): | |||
| self.hyper_map = C.HyperMap() | |||
| self.decay = decay | |||
| self.decay_tf = tuple(decay_filter(x) for x in self.parameters) | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| self.weight_decay = weight_decay * loss_scale | |||
| def construct(self, gradients): | |||
| params = self.parameters | |||
| if self.weight_decay > 0: | |||
| gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients) | |||
| if self.reciprocal_scale != 1.0: | |||
| gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) | |||
| if self.dynamic_lr: | |||
| lr = self.gather(self.learning_rate, self.global_step, self.axis) | |||
| F.control_depend(lr, self.assignadd(self.global_step, self.one)) | |||
| else: | |||
| lr = self.learning_rate | |||
| gradients = self.decay_weight(gradients) | |||
| gradients = self.scale_grad(gradients) | |||
| lr = self.get_lr() | |||
| if self.centered: | |||
| success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon, | |||
| self.momentum), params, self.mg, self.ms, self.moment, gradients) | |||
| @@ -14,11 +14,9 @@ | |||
| # ============================================================================ | |||
| """sgd""" | |||
| from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore._checkparam import ParamValidator as validator | |||
| import mindspore.common.dtype as mstype | |||
| from .optimizer import Optimizer, grad_scale | |||
| from .optimizer import Optimizer | |||
| sgd_opt = C.MultitypeFuncGraph("sgd_opt") | |||
| @@ -83,7 +81,7 @@ class SGD(Optimizer): | |||
| def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, | |||
| loss_scale=1.0): | |||
| super(SGD, self).__init__(learning_rate, params) | |||
| super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale) | |||
| if isinstance(momentum, float) and momentum < 0.0: | |||
| raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) | |||
| @@ -92,44 +90,22 @@ class SGD(Optimizer): | |||
| raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening)) | |||
| self.dampening = dampening | |||
| if weight_decay < 0.0: | |||
| raise ValueError("weight_decay should be at least 0.0, but got weight_decay {}".format(weight_decay)) | |||
| self.weight_decay = weight_decay | |||
| validator.check_type("nesterov", nesterov, [bool]) | |||
| self.nesterov = nesterov | |||
| self.opt = P.SGD(dampening, weight_decay, nesterov) | |||
| self.dynamic_lr = False | |||
| self.gather = None | |||
| self.global_step = None | |||
| self.axis = None | |||
| if not isinstance(learning_rate, float): | |||
| self.dynamic_lr = True | |||
| self.gather = P.GatherV2() | |||
| self.assignadd = P.AssignAdd() | |||
| self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step") | |||
| self.axis = 0 | |||
| self.momentum = Parameter(momentum, name="momentum") | |||
| self.params = self.parameters | |||
| self.accum = self.params.clone(prefix="accum", init='zeros') | |||
| self.stat = self.params.clone(prefix="stat", init='ones') | |||
| self.accum = self.parameters.clone(prefix="accum", init='zeros') | |||
| self.stat = self.parameters.clone(prefix="stat", init='ones') | |||
| self.hyper_map = C.HyperMap() | |||
| self.weight_decay = weight_decay * loss_scale | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| def construct(self, gradients): | |||
| params = self.params | |||
| params = self.parameters | |||
| accum = self.accum | |||
| stat = self.stat | |||
| if self.reciprocal_scale != 1.0: | |||
| gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) | |||
| if self.dynamic_lr: | |||
| lr = self.gather(self.learning_rate, self.global_step, self.axis) | |||
| F.control_depend(lr, self.assignadd(self.global_step, 1)) | |||
| else: | |||
| lr = self.learning_rate | |||
| gradients = self.decay_weight(gradients) | |||
| gradients = self.scale_grad(gradients) | |||
| lr = self.get_lr() | |||
| success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat) | |||
| return success | |||
| @@ -15,17 +15,11 @@ | |||
| """ test optimizer """ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR | |||
| from mindspore import Tensor | |||
| from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR | |||
| from mindspore.common.parameter import Parameter | |||
| gradient = Tensor(np.zeros([1, 2, 3])) | |||
| accumulation = gradient | |||
| variable = accumulation | |||
| paramsTensor = Tensor(np.zeros([1, 2, 3])) | |||
| class IterableObjc: | |||
| def __iter__(self): | |||
| cont = 0 | |||
| @@ -56,6 +50,7 @@ class TestAdam(): | |||
| def test_construct(self): | |||
| with pytest.raises(TypeError): | |||
| gradient = Tensor(np.zeros([1, 2, 3])) | |||
| adam = Adam(params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, | |||
| use_nesterov=False, weight_decay=0.0, loss_scale=1.0) | |||
| adam.construct(gradient) | |||
| @@ -105,4 +100,5 @@ class TestUnsupportParam(): | |||
| def test_Sgd_init(self): | |||
| with pytest.raises(TypeError): | |||
| paramsTensor = Tensor(np.zeros([1, 2, 3])) | |||
| SGD(paramsTensor) | |||
| @@ -0,0 +1,234 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ Test Dynamic Learning Rate """ | |||
| import pytest | |||
| import mindspore | |||
| from mindspore.nn import dynamic_lr as dr | |||
| milestone = [10, 20, 30] | |||
| learning_rates = [0.1, 0.05, 0.01] | |||
| learning_rate = 0.1 | |||
| end_learning_rate = 0.01 | |||
| decay_rate = 0.9 | |||
| total_step = 30 | |||
| step_per_epoch = 3 | |||
| decay_epoch = 2 | |||
| min_lr = 0.01 | |||
| max_lr = 0.1 | |||
| power = 0.5 | |||
| class TestInputs: | |||
| def test_milestone1(self): | |||
| milestone1 = 1 | |||
| with pytest.raises(ValueError): | |||
| dr.piecewise_constant_lr(milestone1, learning_rates) | |||
| def test_milestone2(self): | |||
| milestone1 = [20, 10, 1] | |||
| with pytest.raises(ValueError): | |||
| dr.piecewise_constant_lr(milestone1, learning_rates) | |||
| milestone2 = [1.0, 2.0, True] | |||
| with pytest.raises(ValueError): | |||
| dr.piecewise_constant_lr(milestone2, learning_rates) | |||
| def test_learning_rates1(self): | |||
| lr = True | |||
| with pytest.raises(ValueError): | |||
| dr.piecewise_constant_lr(milestone, lr) | |||
| def test_learning_rates2(self): | |||
| lr = [1, 2, 1] | |||
| with pytest.raises(ValueError): | |||
| dr.piecewise_constant_lr(milestone, lr) | |||
| def test_learning_rate_type(self): | |||
| lr = True | |||
| with pytest.raises(TypeError): | |||
| dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch) | |||
| with pytest.raises(TypeError): | |||
| dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) | |||
| def test_learning_rate_value(self): | |||
| lr = -1.0 | |||
| with pytest.raises(ValueError): | |||
| dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch) | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) | |||
| def test_end_learning_rate_type(self): | |||
| lr = True | |||
| with pytest.raises(TypeError): | |||
| dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power) | |||
| def test_end_learning_rate_value(self): | |||
| lr = -1.0 | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power) | |||
| def test_decay_rate_type(self): | |||
| rate = 'a' | |||
| with pytest.raises(TypeError): | |||
| dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch) | |||
| def test_decay_rate_value(self): | |||
| rate = -1.0 | |||
| with pytest.raises(ValueError): | |||
| dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch) | |||
| def test_total_step1(self): | |||
| total_step1 = 2.0 | |||
| with pytest.raises(ValueError): | |||
| dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch) | |||
| with pytest.raises(ValueError): | |||
| dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch) | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power) | |||
| def test_total_step2(self): | |||
| total_step1 = -1 | |||
| with pytest.raises(ValueError): | |||
| dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch) | |||
| with pytest.raises(ValueError): | |||
| dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch) | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power) | |||
| def test_step_per_epoch1(self): | |||
| step_per_epoch1 = True | |||
| with pytest.raises(ValueError): | |||
| dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch) | |||
| with pytest.raises(ValueError): | |||
| dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch) | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power) | |||
| def test_step_per_epoch2(self): | |||
| step_per_epoch1 = -1 | |||
| with pytest.raises(ValueError): | |||
| dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch) | |||
| with pytest.raises(ValueError): | |||
| dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch) | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power) | |||
| def test_decay_epoch1(self): | |||
| decay_epoch1 = 'm' | |||
| with pytest.raises(ValueError): | |||
| dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1) | |||
| with pytest.raises(ValueError): | |||
| dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1) | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power) | |||
| def test_decay_epoch2(self): | |||
| decay_epoch1 = -1 | |||
| with pytest.raises(ValueError): | |||
| dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1) | |||
| with pytest.raises(ValueError): | |||
| dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1) | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power) | |||
| def test_is_stair(self): | |||
| is_stair = 1 | |||
| with pytest.raises(ValueError): | |||
| dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair) | |||
| def test_min_lr_type(self): | |||
| min_lr1 = True | |||
| with pytest.raises(TypeError): | |||
| dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch) | |||
| def test_min_lr_value(self): | |||
| min_lr1 = -1.0 | |||
| with pytest.raises(ValueError): | |||
| dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch) | |||
| def test_max_lr_type(self): | |||
| max_lr1 = 'a' | |||
| with pytest.raises(TypeError): | |||
| dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch) | |||
| def test_max_lr_value(self): | |||
| max_lr1 = -1.0 | |||
| with pytest.raises(ValueError): | |||
| dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch) | |||
| def test_power(self): | |||
| power1 = True | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1) | |||
| def test_update_decay_epoch(self): | |||
| update_decay_epoch = 1 | |||
| with pytest.raises(ValueError): | |||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, | |||
| power, update_decay_epoch) | |||
| def test_learning_rate(): | |||
| lr = dr.piecewise_constant_lr(milestone, learning_rates) | |||
| assert len(lr) == milestone[-1] | |||
| def test_exponential_decay(): | |||
| lr1 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch) | |||
| assert len(lr1) == total_step | |||
| lr2 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True) | |||
| assert len(lr2) == total_step | |||
| def test_natural_exp_decay(): | |||
| lr1 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch) | |||
| assert len(lr1) == total_step | |||
| lr2 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True) | |||
| assert len(lr2) == total_step | |||
| def test_inverse_decay(): | |||
| lr1 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch) | |||
| assert len(lr1) == total_step | |||
| lr2 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True) | |||
| assert len(lr2) == total_step | |||
| def test_cosine_decay(): | |||
| lr = dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) | |||
| assert len(lr) == total_step | |||
| def test_polynomial_decay(): | |||
| lr1 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) | |||
| assert len(lr1) == total_step | |||
| lr2 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power, | |||
| True) | |||
| assert len(lr2) == total_step | |||