# limitations under the License.
# ============================================================================
"""lars optimizer"""
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore.common import Tensor, Parameter, dtype as mstype
from .optimizer import _grad_scale, Optimizer

_lars_opt = C.MultitypeFuncGraph("lars_opt")
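

# HyperMap selects the overload registered below by argument type: `learning_rate` arrives as a
# Tensor and `weight_decay` as a Python number, matching how LARS.construct binds them.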
@_lars_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Bool", "Bool")
def _tensor_run_opt(lars, learning_rate, weight_decay, gradient, weight, decay_flag, lars_flag):
    """Apply lars optimizer to the weight parameter."""
    if lars_flag:
        # Compute the squared norms of the weight and the gradient, then apply the LARSUpdate op.
        op_reduce_sum = P.SquareSumAll()
        w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient)
        if decay_flag:
            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate)
        else:
            num_zero = 0.0
            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, num_zero, learning_rate)
        return grad_t

    return gradient


def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name):
    """Check the types of the LARS wrapper arguments."""
    validator.check_value_type("optimizer", optimizer, Optimizer, prim_name)
    if "Adam" in optimizer.cls_name or "Lamb" in optimizer.cls_name:
        raise TypeError("LARS can not be used with ", optimizer.cls_name)
    validator.check_value_type("epsilon", epsilon, [float], prim_name)
    validator.check_value_type("coefficient", coefficient, [float], prim_name)
    validator.check_value_type("use_clip", use_clip, [bool], prim_name)


class LARS(Optimizer):
    """
    Implements the LARS algorithm with the LARSUpdate Operator.

    Args:
        optimizer (Optimizer): MindSpore optimizer for which to wrap and modify gradients.
        epsilon (float): Term added to the denominator to improve numerical stability. Default: 1e-05.
        coefficient (float): Trust coefficient for calculating the local learning rate. Default: 0.001.
        use_clip (bool): Whether to use the clip operation when calculating the local learning rate.
            Default: False.
        lars_filter (Function): A function to determine whether to apply the LARS algorithm. Default:
            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is
          the same as the `params` in the optimizer.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
        >>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02)
        >>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)
    """

    def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False,
                 lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
        _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)
        self.opt = optimizer
        self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
        self.cast = P.Cast()
        self.parameters = optimizer.parameters
        # The wrapped optimizer's learning rate is only needed when use_clip is True;
        # otherwise a placeholder learning rate parameter is enough for the LARSUpdate op.
        if use_clip is True:
            self.learning_rate = optimizer.learning_rate
            self.dynamic_lr = optimizer.dynamic_lr
            self.gather = optimizer.gather
            self.assignadd = optimizer.assignadd
            self.global_step = optimizer.global_step
        else:
            self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
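        # Take over loss scaling and weight decay from the wrapped optimizer so each is
        # applied exactly once, inside the LARS update rather than again in `optimizer`.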
        self.reciprocal_scale = optimizer.reciprocal_scale
        optimizer.reciprocal_scale = 1.0
        self.is_group = optimizer.is_group
        if self.is_group:
            self.weight_decay = tuple(map(lambda x: x / optimizer.loss_scale, optimizer.weight_decay))
        else:
            self.weight_decay = optimizer.weight_decay / optimizer.loss_scale
        optimizer.exec_weight_decay = False
        optimizer.weight_decay = 0.0
        self.decay_flags = optimizer.decay_flags
        self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        params = self.parameters
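        # With a dynamic learning rate, look up the value for the current step and advance
        # the global step counter as a side effect.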
        if self.dynamic_lr:
            lr = self.gather(self.learning_rate, self.global_step, 0)
            F.control_depend(lr, self.assignadd(self.global_step, 1))
        else:
            lr = self.learning_rate
        if self.reciprocal_scale != 1.0:
            gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
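        # In the group case each parameter carries its own weight decay, so it is mapped along
        # with the gradients; otherwise a single scalar decay is bound once via F.partial.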
        if self.is_group:
            grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr), self.weight_decay,
                                    gradients, params, self.decay_flags, self.lars_flag)
        else:
            grad_t = self.hyper_map(F.partial(_lars_opt, self.lars, lr, self.weight_decay),
                                    gradients, params, self.decay_flags, self.lars_flag)
        success = self.opt(grad_t)

        return success