|
|
@@ -378,7 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): |
|
|
scaling_sens = self.loss_scale |
|
|
scaling_sens = self.loss_scale |
|
|
else: |
|
|
else: |
|
|
scaling_sens = sens |
|
|
scaling_sens = sens |
|
|
status, scaling_sens = self.start_overflow(loss, scaling_sens) |
|
|
|
|
|
|
|
|
status, scaling_sens = self.start_overflow_check(loss, scaling_sens) |
|
|
grads = self.grad(self.network, weights)(input_ids, |
|
|
grads = self.grad(self.network, weights)(input_ids, |
|
|
input_mask, |
|
|
input_mask, |
|
|
token_type_id, |
|
|
token_type_id, |
|
|
@@ -393,7 +393,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): |
|
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) |
|
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) |
|
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) |
|
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) |
|
|
|
|
|
|
|
|
cond = self.detect_overflow(status, grads) |
|
|
|
|
|
|
|
|
cond = self.get_overflow_status(status, grads) |
|
|
overflow = cond |
|
|
overflow = cond |
|
|
if sens is None: |
|
|
if sens is None: |
|
|
overflow = self.loss_scaling_manager(self.loss_scale, cond) |
|
|
overflow = self.loss_scaling_manager(self.loss_scale, cond) |
|
|
@@ -454,7 +454,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell) |
|
|
else: |
|
|
else: |
|
|
scaling_sens = sens |
|
|
scaling_sens = sens |
|
|
|
|
|
|
|
|
status, scaling_sens = self.start_overflow(loss, scaling_sens) |
|
|
|
|
|
|
|
|
status, scaling_sens = self.start_overflow_check(loss, scaling_sens) |
|
|
grads = self.grad(self.network, weights)(input_ids, |
|
|
grads = self.grad(self.network, weights)(input_ids, |
|
|
input_mask, |
|
|
input_mask, |
|
|
token_type_id, |
|
|
token_type_id, |
|
|
@@ -468,7 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell) |
|
|
grads = self.grad_reducer(grads) |
|
|
grads = self.grad_reducer(grads) |
|
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) |
|
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) |
|
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) |
|
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) |
|
|
cond = self.detect_overflow(status, grads) |
|
|
|
|
|
|
|
|
cond = self.get_overflow_status(status, grads) |
|
|
overflow = cond |
|
|
overflow = cond |
|
|
if self.loss_scaling_manager is not None: |
|
|
if self.loss_scaling_manager is not None: |
|
|
overflow = self.loss_scaling_manager(scaling_sens, cond) |
|
|
overflow = self.loss_scaling_manager(scaling_sens, cond) |
|
|
|