
add dynamic lr and enhance optim

commit 7d700295f8 · root, 5 years ago · tags/v0.2.0-alpha
8 changed files with 650 additions and 134 deletions

  1. mindspore/nn/dynamic_lr.py                    +300    -0
  2. mindspore/nn/optim/adam.py                      +5   -23
  3. mindspore/nn/optim/momentum.py                  +5   -33
  4. mindspore/nn/optim/optimizer.py                +90   -12
  5. mindspore/nn/optim/rmsprop.py                   +5   -27
  6. mindspore/nn/optim/sgd.py                       +8   -32
  7. tests/ut/python/nn/optim/test_optimizer.py      +3    -7
  8. tests/ut/python/nn/test_dynamic_lr.py         +234    -0

mindspore/nn/dynamic_lr.py  (+300, -0)

@@ -0,0 +1,300 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dynamic learning rate"""
import math

from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Rel


def piecewise_constant_lr(milestone, learning_rates):
    r"""
    Get piecewise constant learning rate.

    Calculate the learning rate from the given `milestone` and `learning_rates`. Let the value of `milestone` be
    :math:`(M_1, M_2, ..., M_N)` and the value of `learning_rates` be :math:`(x_1, x_2, ..., x_N)`. N is the length
    of `milestone`. Let the output learning rate be `y`.

    .. math::
        y[i] = x_t,\ \text{for}\ i \in [M_{t-1}, M_t)

    Args:
        milestone (list[int]): A list of milestone values. This list must be monotonically increasing.
        learning_rates (list[float]): A list of learning rates.

    Returns:
        list[float]. The size of the list is :math:`M_N`.

    Examples:
        >>> milestone = [2, 5, 10]
        >>> learning_rates = [0.1, 0.05, 0.01]
        >>> lr = piecewise_constant_lr(milestone, learning_rates)
        [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
    """
    validator.check_type('milestone', milestone, (tuple, list))
    validator.check_type('learning_rates', learning_rates, (tuple, list))
    if len(milestone) != len(learning_rates):
        raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.')

    lr = []
    last_item = 0
    for i, item in enumerate(milestone):
        validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT)
        validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float])
        if item < last_item:
            raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
        lr += [learning_rates[i]] * (item - last_item)
        last_item = item

    return lr


def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
    validator.check_integer('total_step', total_step, 0, Rel.GT)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
    validator.check_float_positive('learning_rate', learning_rate)
    validator.check_float_positive('decay_rate', decay_rate)
    validator.check_type('is_stair', is_stair, [bool])


def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculate learning rate based on the exponential decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate * decay\_rate^{\frac{current\_epoch}{decay\_epoch}}

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of the learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.
        is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.9
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 1
        >>> lr = exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
        [0.1, 0.1, 0.09000000000000001, 0.09000000000000001, 0.08100000000000002, 0.08100000000000002]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    lr = []
    for i in range(total_step):
        if is_stair:
            lr.append(learning_rate * decay_rate ** math.floor(math.floor(i / step_per_epoch) / decay_epoch))
        else:
            lr.append(learning_rate * decay_rate ** (math.floor(i / step_per_epoch) / decay_epoch))
    return lr


def natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculate learning rate based on the natural exponential decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate * e^{-decay\_rate * current\_epoch}

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of the learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.
        is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.9
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> lr = natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
        [0.1, 0.1, 0.1, 0.1, 0.016529888822158657, 0.016529888822158657]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    function = lambda x, y: x
    if is_stair:
        function = lambda x, y: math.floor(x / y) * y

    lr = []
    for i in range(total_step):
        lr.append(learning_rate * math.e ** (-decay_rate * function(math.floor(i / step_per_epoch), decay_epoch)))
    return lr


def inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculate learning rate based on the inverse-time decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate / (1 + decay\_rate * current\_epoch / decay\_epoch)

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of the learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.
        is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.5
        >>> total_step = 6
        >>> step_per_epoch = 1
        >>> decay_epoch = 1
        >>> lr = inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
        [0.1, 0.06666666666666667, 0.05, 0.04, 0.03333333333333333, 0.028571428571428574]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    lr = []
    for i in range(total_step):
        if is_stair:
            lr.append(learning_rate / (1 + decay_rate * math.floor(math.floor(i / step_per_epoch) / decay_epoch)))
        else:
            lr.append(learning_rate / (1 + decay_rate * math.floor(i / step_per_epoch) / decay_epoch))
    return lr


def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    r"""
    Calculate learning rate based on the cosine decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = min\_learning\_rate + 0.5 * (max\_learning\_rate - min\_learning\_rate) *
        (1 + cos(\frac{current\_epoch}{decay\_epoch}\pi))

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        min_lr (float): The minimum value of the learning rate.
        max_lr (float): The maximum value of the learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> min_lr = 0.01
        >>> max_lr = 0.1
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
        [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
    """
    validator.check_float_positive('min_lr', min_lr)
    validator.check_float_positive('max_lr', max_lr)
    validator.check_integer('total_step', total_step, 0, Rel.GT)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)

    delta = 0.5 * (max_lr - min_lr)
    lr = []
    for i in range(total_step):
        tmp_epoch = min(math.floor(i / step_per_epoch), decay_epoch)
        lr.append(min_lr + delta * (1 + math.cos(math.pi * tmp_epoch / decay_epoch)))
    return lr


def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
                        update_decay_epoch=False):
    r"""
    Calculate learning rate based on the polynomial decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) *
        (1 - tmp\_epoch / decay\_epoch)^{power} + end\_learning\_rate

    Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch), current\_epoch=floor(\frac{i}{step\_per\_epoch})`.
    If `update_decay_epoch` is true, the value of `decay_epoch` is updated every epoch. The formula is
    :math:`decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)`.

    Args:
        learning_rate (float): The initial value of the learning rate.
        end_learning_rate (float): The end value of the learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): A value used to calculate the decayed learning rate.
        power (float): A value used to calculate the decayed learning rate.
        update_decay_epoch (bool): If true, update `decay_epoch`. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> end_learning_rate = 0.01
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> power = 0.5
        >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
        [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
    """
    validator.check_float_positive('learning_rate', learning_rate)
    validator.check_float_positive('end_learning_rate', end_learning_rate)
    validator.check_integer('total_step', total_step, 0, Rel.GT)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
    validator.check_type('power', power, [float])
    validator.check_type('update_decay_epoch', update_decay_epoch, [bool])

    function = lambda x, y: (x, min(x, y))
    if update_decay_epoch:
        function = lambda x, y: (x * max(math.ceil(y / x), 1), y)

    lr = []
    delta = learning_rate - end_learning_rate
    for i in range(total_step):
        current_epoch = math.floor(i / step_per_epoch)
        decay_epoch, tmp_epoch = function(decay_epoch, current_epoch)
        lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate)
    return lr
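
The schedules above are plain Python helpers that return one value per training step, so they can be inspected before being handed to an optimizer. A minimal usage sketch (output shown only approximately; the exact floats are subject to rounding):

from mindspore.nn import dynamic_lr as dr

lr = dr.exponential_decay_lr(0.1, 0.9, total_step=6, step_per_epoch=2, decay_epoch=1)
print(len(lr))              # 6, one entry per step
print(lr[0], lr[2], lr[4])  # 0.1, 0.09..., 0.081..., decaying once per 2-step epoch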

mindspore/nn/optim/adam.py  (+5, -23)

@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 """adam"""
-from typing import Iterable
 import numpy as np
 
 from mindspore.common import dtype as mstype
@@ -25,7 +24,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import ParamValidator as validator
 from mindspore._checkparam import Rel
-from .optimizer import Optimizer, apply_decay, grad_scale
+from .optimizer import Optimizer
 
 
 _learning_rate_update_func = ['linear', 'cos', 'sin']
@@ -168,22 +167,13 @@ class Adam(Optimizer):
     def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                  use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(Adam, self).__init__(learning_rate, params)
+        super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
         _check_param_value(beta1, beta2, eps, weight_decay)
         validator.check_type("use_locking", use_locking, [bool])
        validator.check_type("use_nesterov", use_nesterov, [bool])
         validator.check_type("loss_scale", loss_scale, [float])
         validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT)
 
-        self.dynamic_lr = False
-        if isinstance(learning_rate, Iterable) or \
-                (isinstance(learning_rate, Tensor) and learning_rate.dim() == 1):
-            self.dynamic_lr = True
-            self.gather = P.GatherV2()
-            self.assignadd = P.AssignAdd()
-            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
-            self.axis = 0
-
         self.beta1 = Tensor(beta1, mstype.float32)
         self.beta2 = Tensor(beta2, mstype.float32)
         self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
@@ -196,8 +186,6 @@ class Adam(Optimizer):
         self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.Adam(use_locking, use_nesterov)
-        self.weight_decay = weight_decay * loss_scale
-        self.reciprocal_scale = 1.0 / loss_scale
 
         self.pow = P.Pow()
         self.sqrt = P.Sqrt()
@@ -208,15 +196,9 @@ class Adam(Optimizer):
         params = self.parameters
         moment1 = self.moment1
         moment2 = self.moment2
-        if self.weight_decay > 0:
-            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
-        if self.reciprocal_scale != 1.0:
-            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
-
-        lr = self.learning_rate
-        if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, self.axis)
-            F.control_depend(lr, self.assignadd(self.global_step, self.one))
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
 
         beta1_power = self.beta1_power * self.beta1
         self.beta1_power = beta1_power

mindspore/nn/optim/momentum.py  (+5, -33)

@@ -13,14 +13,9 @@
 # limitations under the License.
 # ============================================================================
 """momentum"""
-from typing import Iterable
-
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
-import mindspore.common.dtype as mstype
-from mindspore.common import Tensor
-from .optimizer import Optimizer, apply_decay, grad_scale
+from .optimizer import Optimizer
 
 
 momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -88,43 +83,20 @@ class Momentum(Optimizer):
     """
     def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(Momentum, self).__init__(learning_rate, params)
+        super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
-        if isinstance(learning_rate, Iterable) or \
-                (isinstance(learning_rate, Tensor) and learning_rate.dim() == 1):
-            self.dynamic_lr = True
-            self.gather = P.GatherV2()
-            self.assignadd = P.AssignAdd()
-            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
-            self.axis = 0
-        else:
-            self.dynamic_lr = False
-            self.gather = None
-            self.assignadd = None
-            self.global_step = None
-            self.axis = None
         self.momentum = Parameter(momentum, name="momentum")
         self.params = self.parameters
         self.moments = self.params.clone(prefix="moments", init='zeros')
-        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyMomentum()
-        self.weight_decay = weight_decay * loss_scale
-        self.reciprocal_scale = 1.0 / loss_scale
-        self.one = Tensor(1, mstype.int32)
 
     def construct(self, gradients):
         params = self.params
         moments = self.moments
-        if self.weight_decay > 0:
-            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
-        if self.reciprocal_scale != 1.0:
-            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
-        if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, self.axis)
-            F.control_depend(lr, self.assignadd(self.global_step, self.one))
-        else:
-            lr = self.learning_rate
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
         success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
         return success
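
Momentum follows the same pattern, so a piecewise constant schedule, for example, can be handed straight to the constructor. A minimal sketch (again, the toy nn.Dense model is an assumption):

from mindspore import nn
from mindspore.nn import dynamic_lr as dr

net = nn.Dense(3, 2)
lr = dr.piecewise_constant_lr([2, 5, 10], [0.1, 0.05, 0.01])  # 10 per-step values
optim = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)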

mindspore/nn/optim/optimizer.py  (+90, -12)

@@ -17,9 +17,11 @@ from typing import Iterable
 
 import numpy as np
 
+import mindspore
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.nn.cell import Cell
 from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.initializer import initializer
 from mindspore._checkparam import ParamValidator as validator
 from mindspore._checkparam import Rel
 from mindspore.common.tensor import Tensor
@@ -42,34 +44,110 @@ class Optimizer(Cell):
     Args:
         learning_rate (float): A floating point value for the learning rate. Should be greater than 0.
         parameters (list): A list of parameter, which will be updated. The element in `parameters`
-            should be class mindspore.Parameter.
+            should be class mindspore.Parameter.
+        weight_decay (float): A floating point value for the weight decay. Default: 0.0.
+        loss_scale (float): A floating point value for the loss scale. Default: 1.0. Should be greater than 0.
+        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda
+            x: 'beta' not in x.name and 'gamma' not in x.name.
 
     Raises:
         ValueError: If the learning_rate is a Tensor, but the dims of tensor is greater than 1.
         TypeError: If the learning_rate is not any of the three types: float, Tensor, Iterable.
     """
 
-    def __init__(self, learning_rate, parameters):
+    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0,
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
         super(Optimizer, self).__init__()
         if isinstance(learning_rate, float):
+            self.dynamic_lr = False
+            self.gather = None
+            self.assignadd = None
+            self.global_step = None
             validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT)
-        elif isinstance(learning_rate, Iterable):
-            learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32))
-        elif isinstance(learning_rate, Tensor):
-            if learning_rate.dim() > 1:
-                raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`,"
-                                 f"but got {learning_rate.dim()}.")
         else:
-            raise TypeError("Learning rate should be float, Tensor or Iterable.")
+            self.dynamic_lr = True
+            self.gather = P.GatherV2()
+            self.assignadd = P.AssignAdd()
+            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
+            if isinstance(learning_rate, Iterable):
+                learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32))
+            elif isinstance(learning_rate, Tensor):
+                if learning_rate.dim() > 1:
+                    raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`,"
+                                     f"but got {learning_rate.dim()}.")
+                if learning_rate.dim() == 1 and learning_rate.size() < 2:
+                    logger.warning("If want to use the dynamic learning rate, please make sure that the number "
+                                   "of elements in the list, tuple or tensor passed is greater than 1.")
+            else:
+                raise TypeError("Learning rate should be float, Tensor or Iterable.")
+
+        if loss_scale <= 0.0:
+            raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
+        if weight_decay < 0.0:
+            raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))
 
+        if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1 and learning_rate.size() < 2:
+            logger.warning("If want to use the dynamic learning rate, please make sure that "
+                           "the number of elements in the list, tuple or tensor passed is greater than 1.")
         self.learning_rate = Parameter(learning_rate, name="learning_rate")
         self.parameters = ParameterTuple(parameters)
+        self.reciprocal_scale = 1.0 / loss_scale
+        self.weight_decay = weight_decay * loss_scale
+        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
+
         if not self.parameters:
             raise ValueError("optimizer got an empty parameter list.")
 
+    def decay_weight(self, gradients):
+        """
+        Weight decay.
+
+        An approach to reduce the overfitting of a deep learning neural network model.
+
+        Args:
+            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape with
+                `self.parameters`.
+
+        Returns:
+            tuple[Tensor], The gradients after weight decay.
+        """
+        if self.weight_decay > 0:
+            params = self.params
+            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
+
+        return gradients
+
+    def scale_grad(self, gradients):
+        """
+        Loss scale for mixed precision.
+
+        An approach of mixed precision training to improve the speed and energy efficiency of training deep neural
+        network.
+
+        Args:
+            gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape with
+                `self.parameters`.
+
+        Returns:
+            tuple[Tensor], The gradients after loss scale.
+
+        """
+        if self.reciprocal_scale != 1.0:
+            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
+
+        return gradients
+
+    def get_lr(self):
+        """
+        Get the learning rate of current step.
+
+        Returns:
+            float, the learning rate of current step.
+        """
+        lr = self.learning_rate
+        if self.dynamic_lr:
+            lr = self.gather(self.learning_rate, self.global_step, 0)
+            F.control_depend(lr, self.assignadd(self.global_step, 1))
+
+        return lr
+
     def construct(self, *hyper_params):
         raise NotImplementedError
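
Together, decay_weight, scale_grad and get_lr let every subclass reduce its construct to the same three shared calls plus its own update kernel. For a dynamic schedule, get_lr behaves conceptually like the plain-Python sketch below (an illustration only, not the graph-mode implementation):

lr_schedule = [0.1, 0.1, 0.09, 0.09, 0.081, 0.081]  # e.g. produced by exponential_decay_lr
global_step = 0

def get_lr():
    # GatherV2(learning_rate, global_step, 0), then AssignAdd(global_step, 1);
    # control_depend orders the read before the increment.
    global global_step
    lr = lr_schedule[global_step]
    global_step += 1
    return lr

print([get_lr() for _ in range(3)])  # [0.1, 0.1, 0.09]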




mindspore/nn/optim/rmsprop.py  (+5, -27)

@@ -14,12 +14,8 @@
 # ============================================================================
 """rmsprop"""
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter
 from mindspore._checkparam import ParamValidator as validator
-import mindspore.common.dtype as mstype
-from mindspore.common import Tensor
-from .optimizer import Optimizer, grad_scale, apply_decay
+from .optimizer import Optimizer
 
 rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -138,7 +134,7 @@ class RMSProp(Optimizer):
     def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,
                  use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(RMSProp, self).__init__(learning_rate, params)
+        super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
 
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
@@ -157,15 +153,6 @@ class RMSProp(Optimizer):
         else:
             self.opt = P.ApplyRMSProp(use_locking)
 
-        self.dynamic_lr = False
-        if not isinstance(learning_rate, float):
-            self.dynamic_lr = True
-            self.gather = P.GatherV2()
-            self.assignadd = P.AssignAdd()
-            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
-            self.axis = 0
-            self.one = Tensor(1, mstype.int32)
-
         self.momentum = momentum
 
         self.ms = self.parameters.clone(prefix="mean_square", init='zeros')
@@ -173,21 +160,12 @@ class RMSProp(Optimizer):
         self.hyper_map = C.HyperMap()
 
         self.decay = decay
-        self.decay_tf = tuple(decay_filter(x) for x in self.parameters)
-        self.reciprocal_scale = 1.0 / loss_scale
-        self.weight_decay = weight_decay * loss_scale
 
     def construct(self, gradients):
         params = self.parameters
-        if self.weight_decay > 0:
-            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf, params, gradients)
-        if self.reciprocal_scale != 1.0:
-            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
-        if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, self.axis)
-            F.control_depend(lr, self.assignadd(self.global_step, self.one))
-        else:
-            lr = self.learning_rate
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
         if self.centered:
             success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon,
                                                self.momentum), params, self.mg, self.ms, self.moment, gradients)
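
RMSProp picks up the same base-class handling, so any of the new schedules can be supplied, for example a cosine decay (sketch; the toy nn.Dense model is an assumption):

from mindspore import nn
from mindspore.nn import dynamic_lr as dr

net = nn.Dense(3, 2)
lr = dr.cosine_decay_lr(min_lr=0.01, max_lr=0.1, total_step=6, step_per_epoch=2, decay_epoch=2)
optim = nn.RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, momentum=0.9)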


mindspore/nn/optim/sgd.py  (+8, -32)

@@ -14,11 +14,9 @@
 # ============================================================================
 """sgd"""
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore._checkparam import ParamValidator as validator
-import mindspore.common.dtype as mstype
-from .optimizer import Optimizer, grad_scale
+from .optimizer import Optimizer
 
 
 sgd_opt = C.MultitypeFuncGraph("sgd_opt")
@@ -83,7 +81,7 @@ class SGD(Optimizer):
     def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False,
                  loss_scale=1.0):
 
-        super(SGD, self).__init__(learning_rate, params)
+        super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
 
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
@@ -92,44 +90,22 @@ class SGD(Optimizer):
             raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
         self.dampening = dampening
 
-        if weight_decay < 0.0:
-            raise ValueError("weight_decay should be at least 0.0, but got weight_decay {}".format(weight_decay))
-        self.weight_decay = weight_decay
-
         validator.check_type("nesterov", nesterov, [bool])
         self.nesterov = nesterov
 
         self.opt = P.SGD(dampening, weight_decay, nesterov)
 
-        self.dynamic_lr = False
-        self.gather = None
-        self.global_step = None
-        self.axis = None
-        if not isinstance(learning_rate, float):
-            self.dynamic_lr = True
-            self.gather = P.GatherV2()
-            self.assignadd = P.AssignAdd()
-            self.global_step = Parameter(initializer(0, [1], mstype.int32), name="global_step")
-            self.axis = 0
         self.momentum = Parameter(momentum, name="momentum")
-        self.params = self.parameters
-        self.accum = self.params.clone(prefix="accum", init='zeros')
-        self.stat = self.params.clone(prefix="stat", init='ones')
+        self.accum = self.parameters.clone(prefix="accum", init='zeros')
+        self.stat = self.parameters.clone(prefix="stat", init='ones')
         self.hyper_map = C.HyperMap()
 
-        self.weight_decay = weight_decay * loss_scale
-        self.reciprocal_scale = 1.0 / loss_scale
-
     def construct(self, gradients):
-        params = self.params
+        params = self.parameters
        accum = self.accum
         stat = self.stat
-        if self.reciprocal_scale != 1.0:
-            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
-        if self.dynamic_lr:
-            lr = self.gather(self.learning_rate, self.global_step, self.axis)
-            F.control_depend(lr, self.assignadd(self.global_step, 1))
-        else:
-            lr = self.learning_rate
+        gradients = self.decay_weight(gradients)
+        gradients = self.scale_grad(gradients)
+        lr = self.get_lr()
         success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat)
         return success
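
SGD is updated in the same way; for instance, with a polynomial decay schedule (sketch; the toy nn.Dense model is an assumption):

from mindspore import nn
from mindspore.nn import dynamic_lr as dr

net = nn.Dense(3, 2)
lr = dr.polynomial_decay_lr(0.1, 0.01, total_step=6, step_per_epoch=2, decay_epoch=2, power=0.5)
optim = nn.SGD(net.trainable_params(), learning_rate=lr, momentum=0.9)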

tests/ut/python/nn/optim/test_optimizer.py  (+3, -7)

@@ -15,17 +15,11 @@
 """ test optimizer """
 import numpy as np
 import pytest
-from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
 from mindspore import Tensor
+from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
 from mindspore.common.parameter import Parameter
 
 
-gradient = Tensor(np.zeros([1, 2, 3]))
-accumulation = gradient
-variable = accumulation
-
-
-paramsTensor = Tensor(np.zeros([1, 2, 3]))
 class IterableObjc:
     def __iter__(self):
         cont = 0
@@ -56,6 +50,7 @@ class TestAdam():
 
     def test_construct(self):
         with pytest.raises(TypeError):
+            gradient = Tensor(np.zeros([1, 2, 3]))
             adam = Adam(params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                         use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
             adam.construct(gradient)
@@ -105,4 +100,5 @@ class TestUnsupportParam():
 
     def test_Sgd_init(self):
         with pytest.raises(TypeError):
+            paramsTensor = Tensor(np.zeros([1, 2, 3]))
             SGD(paramsTensor)

tests/ut/python/nn/test_dynamic_lr.py  (+234, -0)

@@ -0,0 +1,234 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Test Dynamic Learning Rate """
import pytest
import mindspore
from mindspore.nn import dynamic_lr as dr

milestone = [10, 20, 30]
learning_rates = [0.1, 0.05, 0.01]
learning_rate = 0.1
end_learning_rate = 0.01
decay_rate = 0.9
total_step = 30
step_per_epoch = 3
decay_epoch = 2
min_lr = 0.01
max_lr = 0.1
power = 0.5

class TestInputs:
    def test_milestone1(self):
        milestone1 = 1
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone1, learning_rates)

    def test_milestone2(self):
        milestone1 = [20, 10, 1]
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone1, learning_rates)

        milestone2 = [1.0, 2.0, True]
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone2, learning_rates)

    def test_learning_rates1(self):
        lr = True
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone, lr)

    def test_learning_rates2(self):
        lr = [1, 2, 1]
        with pytest.raises(ValueError):
            dr.piecewise_constant_lr(milestone, lr)

    def test_learning_rate_type(self):
        lr = True
        with pytest.raises(TypeError):
            dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)
        with pytest.raises(TypeError):
            dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)

    def test_learning_rate_value(self):
        lr = -1.0
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(lr, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)

    def test_end_learning_rate_type(self):
        lr = True
        with pytest.raises(TypeError):
            dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power)

    def test_end_learning_rate_value(self):
        lr = -1.0
        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, lr, total_step, step_per_epoch, decay_epoch, power)

    def test_decay_rate_type(self):
        rate = 'a'
        with pytest.raises(TypeError):
            dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch)

    def test_decay_rate_value(self):
        rate = -1.0
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, rate, total_step, step_per_epoch, decay_epoch)

    def test_total_step1(self):
        total_step1 = 2.0
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power)

    def test_total_step2(self):
        total_step1 = -1
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power)

    def test_step_per_epoch1(self):
        step_per_epoch1 = True
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power)

    def test_step_per_epoch2(self):
        step_per_epoch1 = -1
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch)
        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power)

    def test_decay_epoch1(self):
        decay_epoch1 = 'm'
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power)

    def test_decay_epoch2(self):
        decay_epoch1 = -1
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1)

        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1)

        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power)

    def test_is_stair(self):
        is_stair = 1
        with pytest.raises(ValueError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    def test_min_lr_type(self):
        min_lr1 = True
        with pytest.raises(TypeError):
            dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch)

    def test_min_lr_value(self):
        min_lr1 = -1.0
        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr1, max_lr, total_step, step_per_epoch, decay_epoch)

    def test_max_lr_type(self):
        max_lr1 = 'a'
        with pytest.raises(TypeError):
            dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch)

    def test_max_lr_value(self):
        max_lr1 = -1.0
        with pytest.raises(ValueError):
            dr.cosine_decay_lr(min_lr, max_lr1, total_step, step_per_epoch, decay_epoch)

    def test_power(self):
        power1 = True
        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1)

    def test_update_decay_epoch(self):
        update_decay_epoch = 1
        with pytest.raises(ValueError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch,
                                   power, update_decay_epoch)


def test_learning_rate():
    lr = dr.piecewise_constant_lr(milestone, learning_rates)
    assert len(lr) == milestone[-1]


def test_exponential_decay():
    lr1 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
    assert len(lr1) == total_step

    lr2 = dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
    assert len(lr2) == total_step


def test_enatural_exp_decay():
    lr1 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
    assert len(lr1) == total_step

    lr2 = dr.natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
    assert len(lr2) == total_step


def test_inverse_decay():
    lr1 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
    assert len(lr1) == total_step

    lr2 = dr.inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
    assert len(lr2) == total_step


def test_cosine_decay():
    lr = dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
    assert len(lr) == total_step

def test_polynomial_decay():
    lr1 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
    assert len(lr1) == total_step
    lr2 = dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
                                 True)
    assert len(lr2) == total_step
