| @@ -134,10 +134,6 @@ class SGD(Optimizer): | |||||
| if isinstance(momentum, float) and momentum < 0.0: | if isinstance(momentum, float) and momentum < 0.0: | ||||
| raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) | raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) | ||||
| if nesterov and (momentum <= 0.0 or dampening != 0.0): | |||||
| raise ValueError("If use nesterov, momentum must be positive and dampening must equal to 0.0," | |||||
| "but got momentum {}, dampening {}".format(momentum, dampening)) | |||||
| if isinstance(dampening, int): | if isinstance(dampening, int): | ||||
| dampening = float(dampening) | dampening = float(dampening) | ||||
| if not isinstance(dampening, float): | if not isinstance(dampening, float): | ||||
| @@ -151,6 +147,10 @@ class SGD(Optimizer): | |||||
| weight_decay = float(weight_decay) | weight_decay = float(weight_decay) | ||||
| validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) | validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) | ||||
| if nesterov and (momentum <= 0.0 or dampening != 0.0): | |||||
| raise ValueError("If use nesterov, momentum must be positive and dampening must equal to 0.0," | |||||
| "but got momentum {}, dampening {}".format(momentum, dampening)) | |||||
| self.nesterov = nesterov | self.nesterov = nesterov | ||||
| self.opt = P.SGD(dampening, weight_decay, nesterov) | self.opt = P.SGD(dampening, weight_decay, nesterov) | ||||