Browse Source

GPT scripts bug fix

tags/v1.1.0
alouhahaha 5 years ago
parent
commit
e5227c9f89
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      model_zoo/official/nlp/gpt/src/gpt_wrapcell.py

+ 2
- 1
model_zoo/official/nlp/gpt/src/gpt_wrapcell.py View File

@@ -26,7 +26,7 @@ from mindspore.communication.management import get_group_size
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.parameter import Parameter
from utils import ClipByGlobalNorm
from src.utils import ClipByGlobalNorm
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
@@ -77,6 +77,7 @@ class GPTTrainOneStepWithLossScaleCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
super(GPTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.enable_global_norm = enable_global_norm


Loading…
Cancel
Save