|
|
|
@@ -195,12 +195,12 @@ class Optimizer(Cell): |
|
|
|
params = self.parameters |
|
|
|
if self.is_group: |
|
|
|
if self.exec_weight_decay: |
|
|
|
gradients = self.hyper_map(F.partial(_apply_decay), self.weight_decay, self.decay_flags, |
|
|
|
params, gradients) |
|
|
|
gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags, |
|
|
|
params, gradients) |
|
|
|
else: |
|
|
|
if self.weight_decay > 0: |
|
|
|
gradients = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_flags, |
|
|
|
params, gradients) |
|
|
|
gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags, |
|
|
|
params, gradients) |
|
|
|
|
|
|
|
return gradients |
|
|
|
|
|
|
|
@@ -479,10 +479,20 @@ class Optimizer(Cell): |
|
|
|
|
|
|
|
|
|
|
|
# Module-level AddN primitive instance; the sparse decay overload below calls
# it with a 2-element tuple (weight * weight_decay, gradient[1]).
op_add = P.AddN() |
|
|
|
# Module-level GatherV2 primitive instance; the sparse decay overload below
# calls it as op_gather(weight, gradient[0], 0).
# NOTE(review): the trailing 0 is presumably the gather axis - confirm
# against the GatherV2 documentation.
op_gather = P.GatherV2() |
|
|
|
|
|
|
|
# Multitype function graph named "apply_decay"; the functions decorated with
# @_apply_decay.register(...) below add overloads keyed on argument types.
_apply_decay = C.MultitypeFuncGraph("apply_decay") |
|
|
|
|
|
|
|
|
|
|
|
@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
    """Apply weight decay to a gradient passed as a 3-element tuple.

    This is the ``_apply_decay`` overload registered for the
    ("Number", "Bool", "Tensor", "Tuple") argument signature.

    Args:
        weight_decay: Scalar factor multiplied with the gathered weight.
        if_apply: When false, the gradient is returned untouched.
        weight: Parameter tensor the gradient belongs to.
        gradient: Tuple indexed at positions 0, 1 and 2.
            NOTE(review): presumably (indices, values, dense_shape) of a
            sparse gradient - confirm against the callers.

    Returns:
        ``gradient`` itself when ``if_apply`` is false; otherwise a new
        3-tuple whose element 1 is ``gathered_weight * weight_decay +
        gradient[1]`` and whose elements 0 and 2 are passed through.
    """
    if not if_apply:
        return gradient

    # Only the weight entries selected by gradient[0] take part in the decay.
    indices = gradient[0]
    gathered_weight = op_gather(weight, indices, 0)
    decay_term = gathered_weight * weight_decay
    decayed_values = op_add((decay_term, gradient[1]))
    return indices, decayed_values, gradient[2]
|
|
|
|
|
|
|
|
|
|
|
@_apply_decay.register("Number", "Bool", "Tensor", "Tensor") |
|
|
|
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): |
|
|
|
"""Get grad with weight_decay.""" |
|
|
|
|