| @@ -26,7 +26,7 @@ from mindspore.communication.management import get_group_size | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from utils import ClipByGlobalNorm | |||||
| from src.utils import ClipByGlobalNorm | |||||
| GRADIENT_CLIP_TYPE = 1 | GRADIENT_CLIP_TYPE = 1 | ||||
| GRADIENT_CLIP_VALUE = 1.0 | GRADIENT_CLIP_VALUE = 1.0 | ||||
| @@ -77,6 +77,7 @@ class GPTTrainOneStepWithLossScaleCell(nn.Cell): | |||||
| def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False): | def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False): | ||||
| super(GPTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) | super(GPTTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) | ||||
| self.network = network | self.network = network | ||||
| self.network.add_flags(defer_inline=True) | |||||
| self.weights = optimizer.parameters | self.weights = optimizer.parameters | ||||
| self.optimizer = optimizer | self.optimizer = optimizer | ||||
| self.enable_global_norm = enable_global_norm | self.enable_global_norm = enable_global_norm | ||||