
!12 Fix dtype bug for loss_scale and weight_decay

Merge pull request !12 from seatea/dynamic-loss-scale

Tag: v0.2.0-alpha
mindspore-ci-bot (Gitee), 5 years ago
Commit: 062b744b19
2 changed files with 2 additions and 2 deletions:

  1. mindspore/nn/optim/optimizer.py  (+1, -1)
  2. mindspore/nn/wrap/loss_scale.py  (+1, -1)

mindspore/nn/optim/optimizer.py (+1, -1)

@@ -84,7 +84,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        return op_add((gradient, weight * F.scalar_to_array(weight_decay)))
+        return op_add((gradient, weight * weight_decay))
     return gradient
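
The dtype bug on the weight-decay side is presumably that F.scalar_to_array wraps the Python weight_decay scalar in a default float32 tensor, so the product with a float16 parameter is promoted to float32 and no longer matches the float16 gradient it is added to; multiplying by the raw Python scalar keeps the parameter's own dtype. A minimal NumPy sketch of the same promotion rules, purely as an illustration (not MindSpore code; the variable names just mirror the hunk above):

import numpy as np

weight = np.ones((2, 2), dtype=np.float16)     # e.g. a float16 parameter
gradient = np.ones((2, 2), dtype=np.float16)   # matching float16 gradient
weight_decay = 0.01                            # plain Python scalar

# Multiplying by the Python scalar keeps the parameter's dtype ...
decay_scalar = weight * weight_decay
print(decay_scalar.dtype)                      # float16, matches `gradient`

# ... while multiplying by a float32 tensor (what wrapping the scalar in an
# array would produce) promotes the result to float32.
decay_array = weight * np.full((1,), weight_decay, dtype=np.float32)
print(decay_array.dtype)                       # float32, mismatches `gradient`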




mindspore/nn/wrap/loss_scale.py (+1, -1)

@@ -32,7 +32,7 @@ reciprocal = P.Reciprocal()

 @_grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale(scale, grad):
-    return grad * reciprocal(scale)
+    return grad * F.cast(reciprocal(scale), F.dtype(grad))


 class DynamicLossScaleUpdateCell(Cell):
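
On the loss-scale side, the scale is typically kept as a float32 tensor while gradients can be float16, so reciprocal(scale) comes out as float32 and grad * reciprocal(scale) mixes dtypes; the fix casts the reciprocal to F.dtype(grad) before the multiply so the unscaled gradient stays in the gradient's own dtype. A small NumPy sketch of the same idea, again only as an illustration of the dtype behaviour (not MindSpore code):

import numpy as np

grad = np.ones((3,), dtype=np.float16)           # float16 gradient
scale = np.full((1,), 1024.0, dtype=np.float32)  # float32 loss scale

# Without a cast, the float32 reciprocal drags the result up to float32.
unscaled_mixed = grad * (1.0 / scale)
print(unscaled_mixed.dtype)                      # float32

# Casting the reciprocal back to the gradient's dtype first, analogous to the
# F.cast(reciprocal(scale), F.dtype(grad)) change above, keeps float16.
unscaled = grad * (1.0 / scale).astype(grad.dtype)
print(unscaled.dtype)                            # float16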

