|
|
|
@@ -26,8 +26,6 @@ from mindspore._checkparam import Validator as validator |
|
|
|
from mindspore._checkparam import Rel |
|
|
|
from .optimizer import Optimizer |
|
|
|
|
|
|
|
_learning_rate_update_func = ['linear', 'cos', 'sin'] |
|
|
|
|
|
|
|
adam_opt = C.MultitypeFuncGraph("adam_opt") |
|
|
|
|
|
|
|
|
|
|
|
@@ -94,10 +92,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): |
|
|
|
|
|
|
|
def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name): |
|
|
|
"""Check the type of inputs.""" |
|
|
|
validator.check_float_positive('learning_rate', learning_rate, prim_name) |
|
|
|
validator.check_float_legal_value('learning_rate', learning_rate, prim_name) |
|
|
|
validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name) |
|
|
|
validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name) |
|
|
|
validator.check_value_type("learning_rate", learning_rate, [float], prim_name) |
|
|
|
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, prim_name) |
|
|
|
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) |
|
|
|
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, prim_name) |
|
|
|
validator.check_float_positive('power', power, prim_name) |
|
|
|
validator.check_float_legal_value('power', power, prim_name) |
|
|
|
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) |
|
|
|
@@ -363,7 +361,7 @@ class AdamWeightDecayDynamicLR(Optimizer): |
|
|
|
eps=1e-6, |
|
|
|
weight_decay=0.0, |
|
|
|
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): |
|
|
|
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) |
|
|
|
super(AdamWeightDecayDynamicLR, self).__init__(0.0, params) |
|
|
|
if self.is_group: |
|
|
|
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") |
|
|
|
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) |
|
|
|
|