|
|
|
@@ -13,12 +13,14 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
"""lars optimizer""" |
|
|
|
from typing import Iterable |
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
from mindspore.common import Tensor |
|
|
|
from mindspore.common.initializer import initializer |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.ops import composite as C |
|
|
|
from mindspore.ops import functional as F |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
from mindspore.nn.cell import Cell |
|
|
|
from .optimizer import grad_scale |
|
|
|
|
|
|
|
@@ -111,7 +113,8 @@ class LARS(Cell): |
|
|
|
self.gather = None |
|
|
|
self.global_step = None |
|
|
|
self.axis = None |
|
|
|
if not isinstance(self.learning_rate, float): |
|
|
|
if isinstance(self.learning_rate.default_input, Iterable) or \ |
|
|
|
(isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1): |
|
|
|
self.dynamic_lr = True |
|
|
|
self.assignadd = P.AssignAdd() |
|
|
|
self.gather = P.GatherV2() |
|
|
|
@@ -124,7 +127,7 @@ class LARS(Cell): |
|
|
|
lr = self.gather(self.learning_rate, self.global_step, self.axis) |
|
|
|
F.control_depend(lr, self.assignadd(self.global_step, 1)) |
|
|
|
else: |
|
|
|
lr = F.scalar_to_array(self.learning_rate) |
|
|
|
lr = self.learning_rate |
|
|
|
if self.reciprocal_scale != 1.0: |
|
|
|
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients) |
|
|
|
|
|
|
|
|