|
|
|
@@ -88,14 +88,12 @@ class Optimizer(Cell): |
|
|
|
if isinstance(weight_decay, int): |
|
|
|
weight_decay = float(weight_decay) |
|
|
|
|
|
|
|
if not isinstance(weight_decay, float): |
|
|
|
raise TypeError("weight_decay should be a float number!") |
|
|
|
validator.check_float_legal_value('weight_decay', weight_decay, None) |
|
|
|
|
|
|
|
if isinstance(loss_scale, int): |
|
|
|
loss_scale = float(loss_scale) |
|
|
|
|
|
|
|
if not isinstance(loss_scale, float): |
|
|
|
raise TypeError("loss_scale should be a float number!") |
|
|
|
validator.check_float_legal_value('loss_scale', loss_scale, None) |
|
|
|
|
|
|
|
if loss_scale <= 0.0: |
|
|
|
raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale)) |
|
|
|
|