Browse Source

!1555 fix bug in lamb warmup step check

Merge pull request !1555 from wangnan39/fix_bug_in_check_lamb_warmup_step
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
2a6a3e012c
2 changed files with 13 additions and 15 deletions
  1. +1
    -2
      mindspore/nn/optim/lamb.py
  2. +12
    -13
      tests/ut/python/nn/optim/test_lamb.py

+ 1
- 2
mindspore/nn/optim/lamb.py View File

@@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
def _check_param_value(decay_steps, warmup_steps, start_learning_rate, def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs.""" """Check the type of inputs."""
_ = warmup_steps
validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name) validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name) validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
@@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
validator.check_float_positive('power', power, prim_name) validator.check_float_positive('power', power, prim_name)
validator.check_float_legal_value('power', power, prim_name) validator.check_float_legal_value('power', power, prim_name)
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
validator.check_value_type("beta1", beta1, [float], prim_name) validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name)


+ 12
- 13
tests/ut/python/nn/optim/test_lamb.py View File

@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
""" test lamb """ """ test lamb """
import numpy as np import numpy as np
import pytest


import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
@@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell):
return x return x




def test_lamb_1():
""" test_Lamb_1 """
def test_lamb_compile():
""" test_Lamb_compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32)) inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32)) label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net() net = Net()
net.set_train() net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits() loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5)
optimizer = Lamb(net.trainable_params(), decay_steps=10)


net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer) train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label) _executor.compile(train_network, inputs, label)




def test_lamb_2():
""" test_Lamb_2 """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
def test_lamb_error():
net = Net() net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)
with pytest.raises(TypeError):
Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)


net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
with pytest.raises(TypeError):
Lamb(net.get_parameters(), decay_steps=1.0)

with pytest.raises(ValueError):
Lamb(net.get_parameters(), decay_steps=0)

Loading…
Cancel
Save