@@ -298,7 +298,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         loss = self.network(*inputs)
         scaling_sens = self.scale_sense
 
-        status, scaling_sens = self.start_overflow(loss, scaling_sens)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
 
         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
         grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
@@ -307,7 +307,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         grads = self.grad_reducer(grads)
         # get the overflow buffer
-        cond = self.detect_overflow(status, grads)
+        cond = self.get_overflow_status(status, grads)
         overflow = self.process_loss_scale(cond)
         # if there is no overflow, do optimize
         if not overflow:
@@ -322,7 +322,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         else:
             raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
 
-    def start_overflow(self, pre_cond, compute_input):
+    def start_overflow_check(self, pre_cond, compute_input):
         """
         Start floating-point overflow detection. Create and clear the overflow detection state.
@@ -355,9 +355,9 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
             compute_input = F.depend(compute_input, clear_status)
         return status, compute_input
 
-    def detect_overflow(self, status, compute_output):
+    def get_overflow_status(self, status, compute_output):
         """
-        Detect floating-point overflow status.
+        Get floating-point overflow status.
 
         Get overflow results after executing the target process for overflow detection.
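
For context, a minimal usage sketch of the cell after this rename. It is an assumption-laden illustration, not part of the change: `MyNet`, the data shapes, the optimizer hyper-parameters, and the fixed loss-scale value are hypothetical placeholders; only the cell's renamed overflow-check flow follows the diff above.

# Hedged usage sketch, assuming a MindSpore build that carries this rename.
# MyNet and all hyper-parameters below are hypothetical placeholders.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype

net = MyNet()                                            # hypothetical user network
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_loss = nn.WithLossCell(net, loss_fn)
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# scale_sense: a scalar Tensor gives a fixed loss scale; an update cell gives a dynamic one.
scale_sense = Tensor(1024, mstype.float32)
train_cell = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scale_sense)

data = Tensor(np.random.randn(32, 16).astype(np.float32))
label = Tensor(np.random.randint(0, 10, size=(32,)).astype(np.int32))

# One training step: per the construct changes above, the cell calls
# start_overflow_check before the backward pass, checks the gradients with
# get_overflow_status, and skips the optimizer update when an overflow is found.
loss, overflow, scaling_sens = train_cell(data, label)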