Browse Source

!77 Fix bug of BERT pre-training

Merge pull request !77 from amongo/FixBugOfBertPreTraining
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
43adf281a2
1 changed file with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py

+ 3
- 3
mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py View File

@@ -403,9 +403,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
sens=None):
"""Defines the computation performed."""
weights = self.weights
# alloc status
init = self.alloc_status()
self.clear_before_grad(init)
loss = self.network(input_ids,
input_mask,
token_type_id,
@@ -417,6 +414,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,


Loading…
Cancel
Save