Browse Source

!13163 modify api detect_overflow name in TrainOneStepWithLossScaleCell

From: @wangnan39
Reviewed-by: @kingxian,@hwhewei
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
fa1fbc088c
2 changed files with 9 additions and 9 deletions
  1. +5
    -5
      mindspore/nn/wrap/loss_scale.py
  2. +4
    -4
      model_zoo/official/nlp/bert/src/bert_for_pre_training.py

+ 5
- 5
mindspore/nn/wrap/loss_scale.py View File

@@ -298,7 +298,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
loss = self.network(*inputs)
scaling_sens = self.scale_sense


status, scaling_sens = self.start_overflow(loss, scaling_sens)
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)


scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
@@ -307,7 +307,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
grads = self.grad_reducer(grads)


# get the overflow buffer
cond = self.detect_overflow(status, grads)
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
# if there is no overflow, do optimize
if not overflow:
@@ -322,7 +322,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
else:
raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))


def start_overflow(self, pre_cond, compute_input):
def start_overflow_check(self, pre_cond, compute_input):
""" """
Start floating-point overflow detection. Create and clear the overflow detection state. Start floating-point overflow detection. Create and clear the overflow detection state.


@@ -355,9 +355,9 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
compute_input = F.depend(compute_input, clear_status)
return status, compute_input


def detect_overflow(self, status, compute_output):
def get_overflow_status(self, status, compute_output):
""" """
Detect floating-point overflow status.
Get floating-point overflow status.


Get overflow results after executing the target process for overflow detection.




+ 4
- 4
model_zoo/official/nlp/bert/src/bert_for_pre_training.py View File

@@ -378,7 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
scaling_sens = self.loss_scale
else:
scaling_sens = sens
status, scaling_sens = self.start_overflow(loss, scaling_sens)
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
@@ -393,7 +393,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)


cond = self.detect_overflow(status, grads)
cond = self.get_overflow_status(status, grads)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
@@ -454,7 +454,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell)
else:
scaling_sens = sens


status, scaling_sens = self.start_overflow(loss, scaling_sens)
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
@@ -468,7 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell)
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
cond = self.detect_overflow(status, grads)
cond = self.get_overflow_status(status, grads)
overflow = cond
if self.loss_scaling_manager is not None:
overflow = self.loss_scaling_manager(scaling_sens, cond)


Loading…
Cancel
Save