|
|
|
@@ -360,16 +360,18 @@ class Optimizer(Cell): |
|
|
|
if len(ordered_parameters) != len(self.group_params): |
|
|
|
raise ValueError(f"The value of 'order_params' should be same with all group parameters.") |
|
|
|
|
|
|
|
ordered_params = [None] * params_length |
|
|
|
ordered_learning_rate = [None] * params_length |
|
|
|
ordered_weight_decay = [None] * params_length |
|
|
|
params_name = [param.name for param in ordered_parameters] |
|
|
|
|
|
|
|
for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay): |
|
|
|
index = params_name.index(param.name) |
|
|
|
ordered_params[index] = param |
|
|
|
ordered_learning_rate[index] = lr |
|
|
|
ordered_weight_decay[index] = wd |
|
|
|
|
|
|
|
self.group_params = list(ordered_parameters) |
|
|
|
self.group_params = ordered_params |
|
|
|
self.group_lr = ordered_learning_rate |
|
|
|
self.group_weight_decay = ordered_weight_decay |
|
|
|
|
|
|
|
|