| @@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para | |||||
| def _check_param_value(decay_steps, warmup_steps, start_learning_rate, | def _check_param_value(decay_steps, warmup_steps, start_learning_rate, | ||||
| end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): | end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): | ||||
| """Check the type of inputs.""" | """Check the type of inputs.""" | ||||
| _ = warmup_steps | |||||
| validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name) | validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name) | ||||
| validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name) | validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name) | ||||
| validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) | validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) | ||||
| @@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate, | |||||
| validator.check_float_positive('power', power, prim_name) | validator.check_float_positive('power', power, prim_name) | ||||
| validator.check_float_legal_value('power', power, prim_name) | validator.check_float_legal_value('power', power, prim_name) | ||||
| validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) | validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) | ||||
| validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name) | |||||
| validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name) | |||||
| validator.check_value_type("beta1", beta1, [float], prim_name) | validator.check_value_type("beta1", beta1, [float], prim_name) | ||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | validator.check_value_type("beta2", beta2, [float], prim_name) | ||||
| validator.check_value_type("eps", eps, [float], prim_name) | validator.check_value_type("eps", eps, [float], prim_name) | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test lamb """ | """ test lamb """ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor, Parameter | from mindspore import Tensor, Parameter | ||||
| @@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell): | |||||
| return x | return x | ||||
| def test_lamb_1(): | |||||
| """ test_Lamb_1 """ | |||||
| def test_lamb_compile(): | |||||
| """ test_Lamb_compile """ | |||||
| inputs = Tensor(np.ones([1, 64]).astype(np.float32)) | inputs = Tensor(np.ones([1, 64]).astype(np.float32)) | ||||
| label = Tensor(np.zeros([1, 10]).astype(np.float32)) | label = Tensor(np.zeros([1, 10]).astype(np.float32)) | ||||
| net = Net() | net = Net() | ||||
| net.set_train() | net.set_train() | ||||
| loss = nn.SoftmaxCrossEntropyWithLogits() | loss = nn.SoftmaxCrossEntropyWithLogits() | ||||
| optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5) | |||||
| optimizer = Lamb(net.trainable_params(), decay_steps=10) | |||||
| net_with_loss = WithLossCell(net, loss) | net_with_loss = WithLossCell(net, loss) | ||||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | train_network = TrainOneStepCell(net_with_loss, optimizer) | ||||
| _executor.compile(train_network, inputs, label) | _executor.compile(train_network, inputs, label) | ||||
| def test_lamb_2(): | |||||
| """ test_Lamb_2 """ | |||||
| inputs = Tensor(np.ones([1, 64]).astype(np.float32)) | |||||
| label = Tensor(np.zeros([1, 10]).astype(np.float32)) | |||||
| def test_lamb_error(): | |||||
| net = Net() | net = Net() | ||||
| net.set_train() | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||||
| optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0) | |||||
| with pytest.raises(TypeError): | |||||
| Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0) | |||||
| net_with_loss = WithLossCell(net, loss) | |||||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||||
| _executor.compile(train_network, inputs, label) | |||||
| with pytest.raises(TypeError): | |||||
| Lamb(net.get_parameters(), decay_steps=1.0) | |||||
| with pytest.raises(ValueError): | |||||
| Lamb(net.get_parameters(), decay_steps=0) | |||||