@@ -29,6 +29,7 @@ from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.ops import _selected_ops
 from .bert_model import BertModel
+from .utils import ClipByGlobalNorm

 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
@@ -348,11 +349,12 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         optimizer (Optimizer): Optimizer for updating the weights.
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None):
+    def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
+        self.enable_global_norm = enable_global_norm
         self.grad = C.GradOperation(get_by_list=True,
                                     sens_param=True)
         self.reducer_flag = False
@@ -419,7 +421,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.enable_global_norm:
+            grads = ClipByGlobalNorm()(grads)
+        else:
+            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
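# ---------------------------------------------------------------------------
# Illustration (not part of the patch): a minimal NumPy sketch of the clipping
# strategy selected above. The existing clip_grad path clips each gradient
# tensor on its own (by value or by norm, depending on GRADIENT_CLIP_TYPE),
# while global-norm clipping rescales all gradients together so their joint
# L2 norm stays within the bound, preserving the direction of the update.
# The real ClipByGlobalNorm cell lives in .utils and may differ in detail.
import numpy as np

def clip_by_global_norm_sketch(grads, clip_norm=1.0):
    """Rescale a list of gradient arrays so their combined L2 norm is <= clip_norm."""
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)  # identity when the norm is already small
    return [g * scale for g in grads]
# ---------------------------------------------------------------------------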
@@ -474,12 +479,13 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                   batch_size * accumulation_steps. Default: 1.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
+    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
         super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.accumulation_steps = accumulation_steps
+        self.enable_global_norm = enable_global_norm
         self.one = Tensor(np.array([1]).astype(np.int32))
         self.zero = Tensor(np.array([0]).astype(np.int32))
         self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
@@ -582,7 +588,10 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
             grads = self.grad_reducer(self.accu_grads)
             scaling = scaling_sens * self.degree * self.accumulation_steps
             grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
-            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+            if self.enable_global_norm:
+                grads = ClipByGlobalNorm()(grads)
+            else:
+                grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
             accu_overflow = self.overflow_reducer(accu_overflow)
             F.control_depend(grads, accu_overflow)
             overflow = self.less_equal(self.base, accu_overflow)
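# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch): how a training script might opt in to
# the new flag. `bert_with_loss`, `optimizer`, and `update_cell` are
# placeholder objects assumed to be built elsewhere; the class and parameter
# names come from this file. With enable_global_norm=True the cell clips
# gradients through ClipByGlobalNorm(); with the default False it keeps the
# original per-tensor clip_grad behaviour, so existing callers are unaffected.
train_cell = BertTrainOneStepWithLossScaleCell(bert_with_loss,
                                               optimizer,
                                               scale_update_cell=update_cell,
                                               enable_global_norm=True)
# ---------------------------------------------------------------------------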