Browse Source

enable use float type learning rate in lars optimizer

tags/v0.3.0-alpha
Ziyan chang zherui 5 years ago
parent
commit
ada46fc219
2 changed files with 24 additions and 4 deletions
  1. +6
    -3
      mindspore/nn/optim/lars.py
  2. +18
    -1
      tests/ut/python/nn/optim/test_lars.py

+ 6
- 3
mindspore/nn/optim/lars.py View File

@@ -13,12 +13,14 @@
# limitations under the License.
# ============================================================================
"""lars optimizer"""
from typing import Iterable
from mindspore.common import dtype as mstype
from mindspore.common import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.nn.cell import Cell
from .optimizer import grad_scale

@@ -111,7 +113,8 @@ class LARS(Cell):
self.gather = None
self.global_step = None
self.axis = None
if not isinstance(self.learning_rate, float):
if isinstance(self.learning_rate.default_input, Iterable) or \
(isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
self.dynamic_lr = True
self.assignadd = P.AssignAdd()
self.gather = P.GatherV2()
@@ -124,7 +127,7 @@ class LARS(Cell):
lr = self.gather(self.learning_rate, self.global_step, self.axis)
F.control_depend(lr, self.assignadd(self.global_step, 1))
else:
lr = F.scalar_to_array(self.learning_rate)
lr = self.learning_rate
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)



+ 18
- 1
tests/ut/python/nn/optim/test_lars.py View File

@@ -46,7 +46,7 @@ class Net(nn.Cell):
return x


def test_lars():
def test_lars_multi_step_lr():
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
@@ -61,3 +61,20 @@ def test_lars():
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)


def test_lars_float_lr():
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()

lr = 0.1
SGD = Momentum(net.trainable_params(), lr, 0.9)
optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
lars_filter=lambda x: 'bn' not in x.name)

net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)

Loading…
Cancel
Save