|
|
|
@@ -94,7 +94,7 @@ class LARS(Optimizer): |
|
|
|
self.learning_rate = optimizer.learning_rate |
|
|
|
self.lars = P.LARSUpdate(epsilon, hyperpara, use_clip) |
|
|
|
self.reciprocal_scale = 1.0 / loss_scale |
|
|
|
self.weight_decay = weight_decay * loss_scale |
|
|
|
self.weight_decay = weight_decay |
|
|
|
self.cast = P.Cast() |
|
|
|
self.decay_flag = tuple(decay_filter(x) for x in self.parameters) |
|
|
|
self.lars_flag = tuple(lars_filter(x) for x in self.parameters) |
|
|
|
|