|
|
|
@@ -127,11 +127,11 @@ if __name__ == '__main__': |
|
|
|
lr = Tensor(lr) |
|
|
|
|
|
|
|
# define opt |
|
|
|
decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainalbe_params())) |
|
|
|
no_decayed_params = [param for param in net.trainalbe_params() if param not in decayed_params] |
|
|
|
decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params())) |
|
|
|
no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params] |
|
|
|
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, |
|
|
|
{'params': no_decayed_params}, |
|
|
|
{'order_params': net.trainalbe_params()}] |
|
|
|
{'order_params': net.trainable_params()}] |
|
|
|
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale) |
|
|
|
# define loss, model |
|
|
|
if target == "Ascend": |
|
|
|
|