|
|
|
@@ -138,14 +138,14 @@ class Optimizer(Cell): |
|
|
|
if self.is_group: |
|
|
|
self.parameters = ParameterTuple(self.group_params) |
|
|
|
self.weight_decay = tuple(self.group_weight_decay) |
|
|
|
self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float16) for x in self.group_weight_decay) |
|
|
|
self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay) |
|
|
|
decay_filter = lambda x: x > 0 |
|
|
|
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) |
|
|
|
self.exec_weight_decay = any(self.decay_flags) |
|
|
|
else: |
|
|
|
self.parameters = ParameterTuple(parameters) |
|
|
|
self.weight_decay = weight_decay * loss_scale |
|
|
|
self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float16) |
|
|
|
self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32) |
|
|
|
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name |
|
|
|
self.decay_flags = tuple(decay_filter(x) for x in self.parameters) |
|
|
|
self.exec_weight_decay = self.weight_decay > 0 |
|
|
|
@@ -156,8 +156,9 @@ class Optimizer(Cell): |
|
|
|
break |
|
|
|
ps_filter = lambda x: x.is_param_ps |
|
|
|
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) |
|
|
|
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float16) |
|
|
|
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32) |
|
|
|
self.need_scale = loss_scale != 1.0 |
|
|
|
self.global_step_increase_tensor = Tensor(1, mstype.int32) |
|
|
|
self.param_length = len(self.parameters) |
|
|
|
self.map_ = C.Map() |
|
|
|
if context.get_auto_parallel_context("enable_parallel_optimizer"): |
|
|
|
@@ -441,7 +442,7 @@ class Optimizer(Cell): |
|
|
|
else: |
|
|
|
lr = self.learning_rate(self.global_step) |
|
|
|
|
|
|
|
F.control_depend(lr, self.assignadd(self.global_step, 1)) |
|
|
|
F.control_depend(lr, self.assignadd(self.global_step, self.global_step_increase_tensor)) |
|
|
|
return lr |
|
|
|
|
|
|
|
def get_lr_parameter(self, param): |
|
|
|
@@ -542,7 +543,7 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient): |
|
|
|
"""Get grad with weight_decay.""" |
|
|
|
if if_apply: |
|
|
|
indices = gradient.indices |
|
|
|
values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values)) |
|
|
|
values = op_add((op_gather(weight, indices, 0) * F.cast(weight_decay, F.dtype(weight)), gradient.values)) |
|
|
|
shape = gradient.dense_shape |
|
|
|
return RowTensor(indices, values, shape) |
|
|
|
return gradient |
|
|
|
@@ -552,7 +553,7 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient): |
|
|
|
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): |
|
|
|
"""Get grad with weight_decay.""" |
|
|
|
if if_apply: |
|
|
|
return op_add((op_mul(weight, weight_decay), gradient)) |
|
|
|
return op_add((op_mul(weight, F.cast(weight_decay, F.dtype(weight))), gradient)) |
|
|
|
return gradient |
|
|
|
|
|
|
|
|
|
|
|
@@ -564,17 +565,17 @@ def tensor_grad_scale(scale, grad): |
|
|
|
"""Get grad with scale.""" |
|
|
|
if scale == 1.0: |
|
|
|
return grad |
|
|
|
return op_mul(grad, scale) |
|
|
|
return op_mul(grad, F.cast(scale, F.dtype(grad))) |
|
|
|
|
|
|
|
@_grad_scale.register("Tensor", "Tensor") |
|
|
|
def tensor_grad_scale_with_tensor(scale, grad): |
|
|
|
"""Get grad with scale.""" |
|
|
|
return op_mul(grad, scale) |
|
|
|
return op_mul(grad, F.cast(scale, F.dtype(grad))) |
|
|
|
|
|
|
|
@_grad_scale.register("Tensor", "RowTensor") |
|
|
|
def tensor_grad_scale_with_sparse(scale, grad): |
|
|
|
"""Get grad with scale.""" |
|
|
|
return RowTensor(grad.indices, grad.values * scale, grad.dense_shape) |
|
|
|
return RowTensor(grad.indices, grad.values * F.cast(scale, F.dtype(grad.values)), grad.dense_shape) |
|
|
|
|
|
|
|
|
|
|
|
@_indices_deduplicate.register("RowTensor") |
|
|
|
|