| @@ -298,7 +298,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): | |||
| loss = self.network(*inputs) | |||
| scaling_sens = self.scale_sense | |||
| status, scaling_sens = self.start_overflow(loss, scaling_sens) | |||
| status, scaling_sens = self.start_overflow_check(loss, scaling_sens) | |||
| scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) | |||
| grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled) | |||
| @@ -307,7 +307,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): | |||
| grads = self.grad_reducer(grads) | |||
| # get the overflow buffer | |||
| cond = self.detect_overflow(status, grads) | |||
| cond = self.get_overflow_status(status, grads) | |||
| overflow = self.process_loss_scale(cond) | |||
| # if there is no overflow, do optimize | |||
| if not overflow: | |||
| @@ -322,7 +322,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): | |||
| else: | |||
| raise TypeError("The input type must be Tensor, but got {}".format(type(sens))) | |||
| def start_overflow(self, pre_cond, compute_input): | |||
| def start_overflow_check(self, pre_cond, compute_input): | |||
| """ | |||
| Start floating-point overflow detection. Create and clear the overflow detection state. | |||
| @@ -355,9 +355,9 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): | |||
| compute_input = F.depend(compute_input, clear_status) | |||
| return status, compute_input | |||
| def detect_overflow(self, status, compute_output): | |||
| def get_overflow_status(self, status, compute_output): | |||
| """ | |||
| Detect floating-point overflow status. | |||
| Get floating-point overflow status. | |||
| Get overflow results after executing the target process for overflow detection. | |||
| @@ -378,7 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): | |||
| scaling_sens = self.loss_scale | |||
| else: | |||
| scaling_sens = sens | |||
| status, scaling_sens = self.start_overflow(loss, scaling_sens) | |||
| status, scaling_sens = self.start_overflow_check(loss, scaling_sens) | |||
| grads = self.grad(self.network, weights)(input_ids, | |||
| input_mask, | |||
| token_type_id, | |||
| @@ -393,7 +393,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): | |||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) | |||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||
| cond = self.detect_overflow(status, grads) | |||
| cond = self.get_overflow_status(status, grads) | |||
| overflow = cond | |||
| if sens is None: | |||
| overflow = self.loss_scaling_manager(self.loss_scale, cond) | |||
| @@ -454,7 +454,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell) | |||
| else: | |||
| scaling_sens = sens | |||
| status, scaling_sens = self.start_overflow(loss, scaling_sens) | |||
| status, scaling_sens = self.start_overflow_check(loss, scaling_sens) | |||
| grads = self.grad(self.network, weights)(input_ids, | |||
| input_mask, | |||
| token_type_id, | |||
| @@ -468,7 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell) | |||
| grads = self.grad_reducer(grads) | |||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) | |||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||
| cond = self.detect_overflow(status, grads) | |||
| cond = self.get_overflow_status(status, grads) | |||
| overflow = cond | |||
| if self.loss_scaling_manager is not None: | |||
| overflow = self.loss_scaling_manager(scaling_sens, cond) | |||