|
|
@@ -79,9 +79,10 @@ class Optimizer(Cell): |
|
|
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which |
|
|
the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which |
|
|
in the value of 'order_params' should be in one of group parameters. |
|
|
in the value of 'order_params' should be in one of group parameters. |
|
|
|
|
|
|
|
|
weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0. |
|
|
|
|
|
|
|
|
weight_decay (float): A floating point value for the weight decay. It should be not less than 0 and not |
|
|
|
|
|
greater than 1. |
|
|
If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0. |
|
|
If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0. |
|
|
loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the |
|
|
|
|
|
|
|
|
loss_scale (float): A floating point value for the loss scale. It should be not less than 1. If the |
|
|
type of `loss_scale` input is int, it will be converted to float. Default: 1.0. |
|
|
type of `loss_scale` input is int, it will be converted to float. Default: 1.0. |
|
|
|
|
|
|
|
|
Raises: |
|
|
Raises: |
|
|
@@ -103,12 +104,12 @@ class Optimizer(Cell): |
|
|
if isinstance(loss_scale, int): |
|
|
if isinstance(loss_scale, int): |
|
|
loss_scale = float(loss_scale) |
|
|
loss_scale = float(loss_scale) |
|
|
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) |
|
|
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) |
|
|
validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name) |
|
|
|
|
|
|
|
|
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name) |
|
|
|
|
|
|
|
|
if isinstance(weight_decay, int): |
|
|
if isinstance(weight_decay, int): |
|
|
weight_decay = float(weight_decay) |
|
|
weight_decay = float(weight_decay) |
|
|
validator.check_value_type("weight_decay", weight_decay, [float], self.cls_name) |
|
|
validator.check_value_type("weight_decay", weight_decay, [float], self.cls_name) |
|
|
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) |
|
|
|
|
|
|
|
|
validator.check_number_range("weight_decay", weight_decay, 0.0, 1.0, Rel.INC_BOTH, self.cls_name) |
|
|
|
|
|
|
|
|
self.is_group = False |
|
|
self.is_group = False |
|
|
self.is_group_lr = False |
|
|
self.is_group_lr = False |
|
|
|