@@ -21,8 +21,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
-from mindspore.nn.cell import Cell
-from .optimizer import grad_scale
+from .optimizer import grad_scale, Optimizer
 
 
 lars_opt = C.MultitypeFuncGraph("lars_opt")
@@ -61,7 +60,7 @@ def _tensor_run_opt_v2(lars, weight_decay, learning_rate, gradient, weight, deca
     return gradient
 
 
-class LARS(Cell):
+class LARS(Optimizer):
     """
     Implements the LARS algorithm with LARSUpdate Operator.
 
@@ -98,7 +97,7 @@ class LARS(Cell):
     def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
                  decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
                  lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
-        super(LARS, self).__init__(auto_prefix=False)
+        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")])
         self.opt = optimizer
         self.parameters = optimizer.parameters
         self.learning_rate = optimizer.learning_rate
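
For context, a minimal usage sketch (not part of the diff) of the LARS wrapper whose constructor appears above: the caller-facing API is unchanged by moving the base class from Cell to Optimizer. The nn.Dense and nn.Momentum choices here are placeholder assumptions, not taken from this change.

# Usage sketch, assuming mindspore.nn exports LARS with the signature shown in the diff.
import mindspore.nn as nn

net = nn.Dense(16, 10)  # placeholder network
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
# LARS wraps the inner optimizer; after this change LARS is itself an Optimizer subclass.
opt_lars = nn.LARS(opt, epsilon=1e-05, hyperpara=0.001,
                   lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name)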