Browse Source

!3863 Modify logic

Merge pull request !3863 from lijiaqi/bug
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
45ae76e86a
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      mindspore/nn/optim/sgd.py

+ 4
- 4
mindspore/nn/optim/sgd.py View File

@@ -134,10 +134,6 @@ class SGD(Optimizer):
if isinstance(momentum, float) and momentum < 0.0: if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))


if nesterov and (momentum <= 0.0 or dampening != 0.0):
raise ValueError("If use nesterov, momentum must be positive and dampening must equal to 0.0,"
"but got momentum {}, dampening {}".format(momentum, dampening))

if isinstance(dampening, int): if isinstance(dampening, int):
dampening = float(dampening) dampening = float(dampening)
if not isinstance(dampening, float): if not isinstance(dampening, float):
@@ -151,6 +147,10 @@ class SGD(Optimizer):
weight_decay = float(weight_decay) weight_decay = float(weight_decay)


validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)

if nesterov and (momentum <= 0.0 or dampening != 0.0):
raise ValueError("If use nesterov, momentum must be positive and dampening must equal to 0.0,"
"but got momentum {}, dampening {}".format(momentum, dampening))
self.nesterov = nesterov self.nesterov = nesterov


self.opt = P.SGD(dampening, weight_decay, nesterov) self.opt = P.SGD(dampening, weight_decay, nesterov)


Loading…
Cancel
Save