|
|
|
@@ -81,7 +81,7 @@ class Optimizer(Cell): |
|
|
|
raise ValueError("Optimizer got an empty parameter list.") |
|
|
|
|
|
|
|
if not isinstance(parameters[0], (dict, Parameter)): |
|
|
|
raise ValueError("Only a list of Parameter or dict can be supported.") |
|
|
|
raise TypeError("Only a list of Parameter or dict can be supported.") |
|
|
|
|
|
|
|
if isinstance(loss_scale, int): |
|
|
|
loss_scale = float(loss_scale) |
|
|
|
@@ -258,9 +258,9 @@ class Optimizer(Cell): |
|
|
|
|
|
|
|
for param in group_param['params']: |
|
|
|
validator.check_value_type("parameter", param, [Parameter], self.cls_name) |
|
|
|
if param in params_store: |
|
|
|
if param.name in params_store: |
|
|
|
raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.") |
|
|
|
params_store.append(param) |
|
|
|
params_store.append(param.name) |
|
|
|
self.group_lr.append(Parameter(lr, name="lr_" + param.name)) |
|
|
|
self.group_weight_decay.append(weight_decay_) |
|
|
|
|
|
|
|
@@ -298,18 +298,22 @@ class Optimizer(Cell): |
|
|
|
Parameter, single `Parameter` or `list[Parameter]` according to the input type. |
|
|
|
""" |
|
|
|
if not isinstance(param, (Parameter, list)): |
|
|
|
raise TypeError(f"The 'param' only support 'Parameter' or 'list' type.") |
|
|
|
raise TypeError(f"The parameter only support 'Parameter' or 'list' type.") |
|
|
|
|
|
|
|
if isinstance(param, list): |
|
|
|
lr = [] |
|
|
|
for p in param: |
|
|
|
validator.check_value_type("parameter", p, [Parameter], self.cls_name) |
|
|
|
if p not in self.parameters: |
|
|
|
raise ValueError(f"The parameter {p.name} is not in optimizer.") |
|
|
|
if self.is_group_lr: |
|
|
|
index = self.parameters.index(p) |
|
|
|
lr.append(self.learning_rate[index]) |
|
|
|
else: |
|
|
|
lr.append(self.learning_rate) |
|
|
|
else: |
|
|
|
if param not in self.parameters: |
|
|
|
raise ValueError(f"The parameter {param.name} is not in optimizer.") |
|
|
|
if self.is_group_lr: |
|
|
|
index = self.parameters.index(param) |
|
|
|
lr = self.learning_rate[index] |
|
|
|
|