
!45 fix lars base class type

Merge pull request !45 from gziyan/fix_lars_incubator
tags/v0.3.0-alpha
mindspore-ci-bot (Gitee) · 5 years ago
parent · commit b6fe48203a
2 changed files with 4 additions and 5 deletions
  1. mindspore/nn/optim/lars.py  +3 -4
  2. mindspore/nn/optim/optimizer.py  +1 -1

mindspore/nn/optim/lars.py  +3 -4

@@ -21,8 +21,7 @@ from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
-from mindspore.nn.cell import Cell
-from .optimizer import grad_scale
+from .optimizer import grad_scale, Optimizer

lars_opt = C.MultitypeFuncGraph("lars_opt")

@@ -61,7 +60,7 @@ def _tensor_run_opt_v2(lars, weight_decay, learning_rate, gradient, weight, deca
    return gradient


-class LARS(Cell):
+class LARS(Optimizer):
    """
    Implements the LARS algorithm with LARSUpdate Operator.

@@ -98,7 +97,7 @@ class LARS(Cell):
    def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
                 decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
                 lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
-        super(LARS, self).__init__(auto_prefix=False)
+        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")])
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.learning_rate = optimizer.learning_rate
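
Note: judging from the constructor in this hunk, LARS is a wrapper around an existing optimizer and reuses its parameters and learning rate. A minimal usage sketch under that assumption (the nn.Dense network and nn.Momentum optimizer are illustrative, not part of this change, and it assumes LARS is exported from mindspore.nn like the other optimizers):

import mindspore.nn as nn

# Any Cell with trainable Parameters works; nn.Dense is only an example.
net = nn.Dense(16, 10)

# Inner optimizer that LARS wraps; LARS picks up its parameters and learning rate.
momentum = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# LARS wrapper defined in lars.py; keyword defaults mirror the signature shown above.
optimizer = nn.LARS(momentum, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False)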


mindspore/nn/optim/optimizer.py  +1 -1

@@ -51,7 +51,7 @@ class Optimizer(Cell):
"""

def __init__(self, learning_rate, parameters):
super(Optimizer, self).__init__()
super(Optimizer, self).__init__(auto_prefix=False)
if isinstance(learning_rate, float):
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT)
elif isinstance(learning_rate, Iterable):
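
Note: this one-line change is what lets the lars.py change above work. LARS used to pass auto_prefix=False to Cell directly; now that it goes through Optimizer.__init__, the base class sets auto_prefix=False itself (presumably so that wrapping an optimizer does not re-prefix the model's parameter names), and LARS satisfies the (learning_rate, parameters) signature with a placeholder Parameter while the real state still comes from the wrapped optimizer. A framework-free Python sketch of that wrapper relationship, not the actual MindSpore code:

class Optimizer:
    # Base class: every optimizer supplies a learning rate and a parameter list.
    def __init__(self, learning_rate, parameters):
        self.learning_rate = learning_rate
        self.parameters = parameters

class LARS(Optimizer):
    # Wrapper: satisfies the base signature with a placeholder, then delegates
    # to the wrapped optimizer, mirroring
    # super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")]).
    def __init__(self, optimizer):
        super().__init__(0.0, ["trivial placeholder"])
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.learning_rate = optimizer.learning_rate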

