From 7188a14215d2ceecb30b11116bdae325bb5c9adf Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Thu, 11 Mar 2021 15:44:11 +0800 Subject: [PATCH] modify api detect_overflow name in TrainOneStepWithLossScaleCell --- mindspore/nn/wrap/loss_scale.py | 10 +++++----- .../official/nlp/bert/src/bert_for_pre_training.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py index 793d8eb409..e15a5a69a6 100644 --- a/mindspore/nn/wrap/loss_scale.py +++ b/mindspore/nn/wrap/loss_scale.py @@ -298,7 +298,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): loss = self.network(*inputs) scaling_sens = self.scale_sense - status, scaling_sens = self.start_overflow(loss, scaling_sens) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled) @@ -307,7 +307,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): grads = self.grad_reducer(grads) # get the overflow buffer - cond = self.detect_overflow(status, grads) + cond = self.get_overflow_status(status, grads) overflow = self.process_loss_scale(cond) # if there is no overflow, do optimize if not overflow: @@ -322,7 +322,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): else: raise TypeError("The input type must be Tensor, but got {}".format(type(sens))) - def start_overflow(self, pre_cond, compute_input): + def start_overflow_check(self, pre_cond, compute_input): """ Start floating-point overflow detection. Create and clear the overflow detection state. @@ -355,9 +355,9 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): compute_input = F.depend(compute_input, clear_status) return status, compute_input - def detect_overflow(self, status, compute_output): + def get_overflow_status(self, status, compute_output): """ - Detect floating-point overflow status. + Get floating-point overflow status. Get overflow results after executing the target process for overflow detection. diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py index 314647f2c5..1c15dcd0ec 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -378,7 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): scaling_sens = self.loss_scale else: scaling_sens = sens - status, scaling_sens = self.start_overflow(loss, scaling_sens) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -393,7 +393,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - cond = self.detect_overflow(status, grads) + cond = self.get_overflow_status(status, grads) overflow = cond if sens is None: overflow = self.loss_scaling_manager(self.loss_scale, cond) @@ -454,7 +454,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell) else: scaling_sens = sens - status, scaling_sens = self.start_overflow(loss, scaling_sens) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -468,7 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell) grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - cond = self.detect_overflow(status, grads) + cond = self.get_overflow_status(status, grads) overflow = cond if self.loss_scaling_manager is not None: overflow = self.loss_scaling_manager(scaling_sens, cond)