|
|
|
@@ -134,10 +134,6 @@ class SGD(Optimizer): |
|
|
|
if isinstance(momentum, float) and momentum < 0.0: |
|
|
|
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) |
|
|
|
|
|
|
|
if nesterov and (momentum <= 0.0 or dampening != 0.0): |
|
|
|
raise ValueError("If use nesterov, momentum must be positive and dampening must equal to 0.0," |
|
|
|
"but got momentum {}, dampening {}".format(momentum, dampening)) |
|
|
|
|
|
|
|
if isinstance(dampening, int): |
|
|
|
dampening = float(dampening) |
|
|
|
if not isinstance(dampening, float): |
|
|
|
@@ -151,6 +147,10 @@ class SGD(Optimizer): |
|
|
|
weight_decay = float(weight_decay) |
|
|
|
|
|
|
|
validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) |
|
|
|
|
|
|
|
if nesterov and (momentum <= 0.0 or dampening != 0.0): |
|
|
|
raise ValueError("If use nesterov, momentum must be positive and dampening must equal to 0.0," |
|
|
|
"but got momentum {}, dampening {}".format(momentum, dampening)) |
|
|
|
self.nesterov = nesterov |
|
|
|
|
|
|
|
self.opt = P.SGD(dampening, weight_decay, nesterov) |
|
|
|
|