| @@ -178,20 +178,18 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): | |||||
| net.trainable_params())) | net.trainable_params())) | ||||
| no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params] | no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params] | ||||
| group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, | group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, | ||||
| {'params': no_decayed_params}, | |||||
| {'params': no_decayed_params, 'weight_decay': 0.0}, | |||||
| {'order_params': net.trainable_params()}] | {'order_params': net.trainable_params()}] | ||||
| if config.use_lars: | if config.use_lars: | ||||
| momentum = nn.Momentum(group_params, lr, config.momentum, | momentum = nn.Momentum(group_params, lr, config.momentum, | ||||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale, | |||||
| use_nesterov=config.use_nesterov) | |||||
| loss_scale=config.loss_scale, use_nesterov=config.use_nesterov) | |||||
| opt = nn.LARS(momentum, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient, | opt = nn.LARS(momentum, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient, | ||||
| lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name) | lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name) | ||||
| else: | else: | ||||
| opt = nn.Momentum(group_params, lr, config.momentum, | opt = nn.Momentum(group_params, lr, config.momentum, | ||||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale, | |||||
| use_nesterov=config.use_nesterov) | |||||
| loss_scale=config.loss_scale, use_nesterov=config.use_nesterov) | |||||
| # model | # model | ||||
| model = Model(net, loss_fn=loss, optimizer=opt, | model = Model(net, loss_fn=loss, optimizer=opt, | ||||