Browse Source

GPT scripts bug fix

tags/v1.1.0
alouhahaha 5 years ago
parent
commit
e5227c9f89
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      model_zoo/official/nlp/gpt/src/gpt_wrapcell.py

+ 2
- 1
model_zoo/official/nlp/gpt/src/gpt_wrapcell.py View File

@@ -26,7 +26,7 @@ from mindspore.communication.management import get_group_size
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from utils import ClipByGlobalNorm
from src.utils import ClipByGlobalNorm
GRADIENT_CLIP_TYPE = 1 GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0 GRADIENT_CLIP_VALUE = 1.0
@@ -77,6 +77,7 @@ class GPTTrainOneStepWithLossScaleCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False): def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
super(GPTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) super(GPTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.enable_global_norm = enable_global_norm self.enable_global_norm = enable_global_norm


Loading…
Cancel
Save