|
|
|
@@ -212,3 +212,41 @@ def test_group_repeat_param(): |
|
|
|
{'params': no_conv_params}] |
|
|
|
with pytest.raises(RuntimeError): |
|
|
|
Adam(group_params, learning_rate=default_lr) |
|
|
|
|
|
|
|
|
|
|
|
def test_get_lr_parameter_with_group(): |
|
|
|
net = LeNet5() |
|
|
|
conv_lr = 0.1 |
|
|
|
default_lr = 0.3 |
|
|
|
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) |
|
|
|
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) |
|
|
|
group_params = [{'params': conv_params, 'lr': conv_lr}, |
|
|
|
{'params': no_conv_params, 'lr': default_lr}] |
|
|
|
opt = SGD(group_params) |
|
|
|
assert opt.is_group_lr is True |
|
|
|
for param in opt.parameters: |
|
|
|
lr = opt.get_lr_parameter(param) |
|
|
|
assert lr.name == 'lr_' + param.name |
|
|
|
|
|
|
|
lr_list = opt.get_lr_parameter(conv_params) |
|
|
|
for lr, param in zip(lr_list, conv_params): |
|
|
|
assert lr.name == 'lr_' + param.name |
|
|
|
|
|
|
|
|
|
|
|
def test_get_lr_parameter_with_no_group(): |
|
|
|
net = LeNet5() |
|
|
|
conv_weight_decay = 0.8 |
|
|
|
|
|
|
|
conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) |
|
|
|
no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) |
|
|
|
group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, |
|
|
|
{'params': no_conv_params}] |
|
|
|
opt = SGD(group_params) |
|
|
|
assert opt.is_group_lr is False |
|
|
|
for param in opt.parameters: |
|
|
|
lr = opt.get_lr_parameter(param) |
|
|
|
assert lr.name == opt.learning_rate.name |
|
|
|
|
|
|
|
params_error = [1, 2, 3] |
|
|
|
with pytest.raises(TypeError): |
|
|
|
opt.get_lr_parameter(params_error) |