@@ -25,6 +25,7 @@ from ...ops import operations as P
 from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual, \
     ControlDepend
 from ...common import dtype as mstype
+import mindspore.context as context
 
 _grad_scale = C.MultitypeFuncGraph("grad_scale")
 reciprocal = P.Reciprocal()
@@ -34,6 +35,12 @@ reciprocal = P.Reciprocal()
 def tensor_grad_scale(scale, grad):
     return grad * F.cast(reciprocal(scale), F.dtype(grad))
 
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)
 
 class DynamicLossScaleUpdateCell(Cell):
     r"""
@@ -195,9 +202,15 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
-        self.alloc_status = NPUAllocFloatStatus()
-        self.get_status = NPUGetFloatStatus()
-        self.clear_status = NPUClearFloatStatus()
+        if context.get_context("device_target") == "GPU":
+            self.gpu_target = True
+            self.float_status = P.FloatStatus()
+            self.addn = P.AddN()
+        else:
+            self.gpu_target = False
+            self.alloc_status = NPUAllocFloatStatus()
+            self.get_status = NPUGetFloatStatus()
+            self.clear_status = NPUClearFloatStatus()
         self.reduce_sum = ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = LessEqual()
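
The branch above is the core of the change: the NPUAllocFloatStatus / NPUGetFloatStatus / NPUClearFloatStatus operators manage a device-side overflow status buffer available on Ascend, so on GPU the cell instead keeps FloatStatus and AddN and checks every gradient tensor directly (see the construct hunks below). A minimal sketch of the target check driving the branch, not taken from the patch:

# Illustrative only: the same device_target check used in __init__ above.
import mindspore.context as context

context.set_context(device_target="GPU")   # "Ascend" would select the NPU status-buffer path
gpu_target = context.get_context("device_target") == "GPU"
print(gpu_target)                          # True -> per-gradient FloatStatus + AddN overflow check
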
@@ -222,10 +235,11 @@ class TrainOneStepWithLossScaleCell(Cell):
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
-        # init overflow buffer
-        init = self.alloc_status()
-        # clear overflow buffer
-        self.clear_status(init)
+        if not self.gpu_target:
+            # init overflow buffer
+            init = self.alloc_status()
+            # clear overflow buffer
+            self.clear_status(init)
         if sens is None:
             scaling_sens = self.loss_scale
         else:
@@ -235,10 +249,14 @@ class TrainOneStepWithLossScaleCell(Cell):
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
-        # get the overflow buffer
-        self.get_status(init)
-        # sum overflow buffer elements, 0:not overflow , >0:overflow
-        flag_sum = self.reduce_sum(init, (0,))
+        if not self.gpu_target:
+            # get the overflow buffer
+            self.get_status(init)
+            # sum overflow buffer elements, 0:not overflow , >0:overflow
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = self.addn(flag_sum)
         if self.is_distributed:
             # sum overflow flag over devices
             flag_reduce = self.allreduce(flag_sum)
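
For context, a minimal usage sketch of the wrapper this patch touches. The network, optimizer settings, and input tensors are hypothetical placeholders, and the sketch assumes the era of the API shown in this diff (construct takes data, label and an optional sens, and the cell is built with a scale_update_cell):

# Hypothetical example -- MyNetWithLoss, data and label are placeholders.
import mindspore.nn as nn

net = MyNetWithLoss()                        # a Cell that returns the loss for (data, label)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12,
                                            scale_factor=2, scale_window=1000)
train_net = nn.TrainOneStepWithLossScaleCell(net, opt, scale_update_cell=update_cell)
train_net.set_train()
output = train_net(data, label)              # returns the loss along with an overflow flag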