|
|
@@ -32,7 +32,7 @@ reciprocal = P.Reciprocal() |
|
|
|
|
|
|
|
|
@_grad_scale.register("Tensor", "Tensor") |
|
|
@_grad_scale.register("Tensor", "Tensor") |
|
|
def tensor_grad_scale(scale, grad): |
|
|
def tensor_grad_scale(scale, grad): |
|
|
return grad * reciprocal(scale) |
|
|
|
|
|
|
|
|
return grad * F.cast(reciprocal(scale), F.dtype(grad)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DynamicLossScaleUpdateCell(Cell): |
|
|
class DynamicLossScaleUpdateCell(Cell): |
|
|
|