@@ -135,18 +135,21 @@ class Optimizer(Cell):
         if self.is_group:
             self.parameters = ParameterTuple(self.group_params)
             self.weight_decay = tuple(self.group_weight_decay)
+            self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
             self.exec_weight_decay = any(self.decay_flags)
         else:
             self.parameters = ParameterTuple(parameters)
             self.weight_decay = weight_decay * loss_scale
+            self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
             decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
             self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
             self.exec_weight_decay = self.weight_decay > 0
         ps_filter = lambda x: x.is_param_ps
         self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
-        self.reciprocal_scale = 1.0 / loss_scale
+        self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
+        self.need_scale = loss_scale != 1.0
         self.param_length = len(self.parameters)
         self.map_ = C.Map()
         if context.get_auto_parallel_context("enable_parallel_optimizer"):
@@ -215,10 +218,10 @@ class Optimizer(Cell):
         if self.exec_weight_decay:
             params = self.parameters
             if self.is_group:
-                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
+                gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
                                       params, gradients)
             else:
-                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
+                gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
                                       params, gradients)

         return gradients
@@ -238,7 +241,7 @@ class Optimizer(Cell):
             tuple[Tensor], The gradients after loss scale.

         """
-        if self.reciprocal_scale != 1.0:
+        if self.need_scale:
             gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)

         return gradients

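The hunk above swaps the float comparison self.reciprocal_scale != 1.0 for the precomputed self.need_scale flag, matching the earlier hunk that turns reciprocal_scale into a float32 Tensor. As a quick illustration (not part of the patch), the value conversions introduced in the first hunk look roughly like this; the numbers are made up and the import paths assume a MindSpore 1.x layout:

from mindspore import Tensor
from mindspore.common import dtype as mstype

loss_scale = 1024.0
group_weight_decay = [0.01, 0.0]  # made-up per-group values

# Non-group path: a single scalar weight decay, mirrored as a float32 Tensor.
weight_decay = 0.01 * loss_scale
weight_decay_tensor = Tensor(weight_decay, mstype.float32)

# Group path: one float32 Tensor per parameter group, plus per-group decay flags.
weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in group_weight_decay)
decay_flags = tuple(x > 0 for x in group_weight_decay)  # (True, False)

# Loss-scale handling: the reciprocal becomes a Tensor, while a plain Python bool
# records up front whether scaling is needed at all.
reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
need_scale = loss_scale != 1.0

The plain bool need_scale lets the loss-scaling step skip the map_ call entirely when the loss scale is 1.0, without having to compare a Tensor against a Python number.
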
@@ -522,11 +525,12 @@ class Optimizer(Cell):

 op_add = P.AddN()
 op_gather = P.GatherV2()
+op_mul = P.Mul()

 _apply_decay = C.MultitypeFuncGraph("apply_decay")


-@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
+@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
 def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
@@ -537,11 +541,11 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
     return gradient


-@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
+@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        return op_add((weight * weight_decay, gradient))
+        return op_add((op_mul(weight, weight_decay), gradient))
     return gradient

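The dense _apply_decay overload above is re-registered with a Tensor weight_decay and multiplies through the op_mul primitive instead of the * operator. A small standalone illustration, not taken from the patch, of how a scalar float32 Tensor decay broadcasts over a weight through P.Mul; the values are made up, and PyNative mode is assumed so the primitive runs eagerly:

from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)  # run the primitive eagerly

op_mul = P.Mul()
weight = Tensor([0.5, -1.0, 2.0], mstype.float32)  # made-up weight values
weight_decay = Tensor(0.01, mstype.float32)        # scalar decay as a Tensor
decayed = op_mul(weight, weight_decay)             # broadcasts the scalar over the weight
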
@@ -553,14 +557,16 @@ def tensor_grad_scale(scale, grad):
     """Get grad with scale."""
     if scale == 1.0:
         return grad
-    return grad * scale
+    return op_mul(grad, scale)
+
+
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale_with_tensor(scale, grad):
+    """Get grad with scale."""
+    return op_mul(grad, scale)


 @_grad_scale.register("Number", "RowTensor")
+@_grad_scale.register("Tensor", "RowTensor")
 def tensor_grad_scale_with_sparse(scale, grad):
     """Get grad with scale."""
     if scale == 1.0:
         return grad
     return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)

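For reference, _apply_decay and _grad_scale are MultitypeFuncGraph objects, so the string signatures passed to register() pick an overload based on the runtime types of the arguments; registering Tensor variants is what lets the Tensor-valued weight decay and reciprocal scale flow through the existing map_ calls unchanged. The sketch below, which is not part of the patch, mimics that dispatch with a toy graph and Cell (the names scale_fn and ScaleGrads are made up); it assumes MindSpore 1.x APIs:

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C, functional as F

scale_fn = C.MultitypeFuncGraph("scale_fn")


@scale_fn.register("Number", "Tensor")
def _scale_by_number(scale, grad):
    """Overload chosen when the scale is a Python number."""
    return grad * scale


@scale_fn.register("Tensor", "Tensor")
def _scale_by_tensor(scale, grad):
    """Overload chosen when the scale is a Tensor."""
    return grad * scale


class ScaleGrads(nn.Cell):
    """Applies one scale to every gradient in a tuple, like the optimizer's map_ call."""
    def __init__(self):
        super(ScaleGrads, self).__init__()
        self.map_ = C.Map()

    def construct(self, scale, gradients):
        return self.map_(F.partial(scale_fn, scale), gradients)


grads = (Tensor([1.0, 2.0], mstype.float32), Tensor([4.0], mstype.float32))
scaled = ScaleGrads()(Tensor(0.5, mstype.float32), grads)  # resolves to the Tensor overload

Passing a Python float as the first argument would select the Number overload instead, which mirrors how the optimizer behaves before and after this change.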