Browse Source

!6272 fix bug in bert TrainOneStepCell

Merge pull request !6272 from wangnan39/fix_bug_in_bert
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
1e248c770f
2 changed files with 2 additions and 3 deletions
  1. +1
    -1
      model_zoo/official/nlp/bert/src/bert_for_pre_training.py
  2. +1
    -2
      model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py

+ 1
- 1
model_zoo/official/nlp/bert/src/bert_for_pre_training.py View File

@@ -258,7 +258,7 @@ class BertNetworkWithLoss(nn.Cell):
return self.cast(total_loss, mstype.float32)


class BertTrainOneStepCell(nn.Cell):
class BertTrainOneStepCell(nn.TrainOneStepCell):
"""
Encapsulation class of bert network training.



+ 1
- 2
model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py View File

@@ -275,7 +275,7 @@ class BertNetworkWithLoss(nn.Cell):
return self.cast(total_loss, mstype.float32)


class BertTrainOneStepCell(nn.Cell):
class BertTrainOneStepCell(nn.TrainOneStepCell):
"""
Encapsulation class of bert network training.

@@ -287,7 +287,6 @@ class BertTrainOneStepCell(nn.Cell):
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default: 1.0.
"""

def __init__(self, network, optimizer, sens=1.0):
super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
self.cast = P.Cast()


Loading…
Cancel
Save