@@ -556,11 +556,24 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
        return F.depend(ret, succ)

cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
add_grads = C.MultitypeFuncGraph("add_grads")


@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
    return accu_grad + cast(grad, mstype.float32)

update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")

@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))

accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")

@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
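
# A minimal usage sketch (an assumption, not taken from this file): the MultitypeFuncGraph
# helpers above are meant to be applied elementwise over gradient tuples through C.HyperMap,
# where accu_grads and grads are same-length tuples of tensors:
#
#     hyper_map = C.HyperMap()
#     summed = hyper_map(add_grads, accu_grads, grads)            # accu + cast(grad, float32)
#     done = hyper_map(update_accu_grads, accu_grads, grads)      # accu <- cast(grad, float32)
#     done = hyper_map(accumulate_accu_grads, accu_grads, grads)  # accu += cast(grad, float32)
#
# Gradients are cast to float32 so that accumulation under loss scaling keeps full precision.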
@@ -575,13 +588,17 @@ def _reset_accu_grads(accu_grad):
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network. After that, the construct
    function can be called to create the backward graph. To mimic higher batch size, gradients are
    accumulated N times before weight update.
    function can be called to create the backward graph.

    To mimic higher batch size, gradients are accumulated N times before weight update.

    For distribution mode, allreduce will only be implemented in the weight-update step,
    i.e. the sub-step after gradients have been accumulated N times.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
@@ -591,7 +608,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
                                  batch_size * accumulation_steps. Default: 1.
    """
    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
        super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
        super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
@@ -680,7 +697,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
                                                 self.cast(scaling_sens,
                                                           mstype.float32))

        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        self.get_status(init)
@@ -716,3 +733,151 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
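        # The tuple returned above (assumed interpretation): mean_loss is the running mean of
        # the losses accumulated so far, overflow reports whether a float overflow was seen
        # (in which case the optimizer step is skipped), and scaling_sens is the loss scale
        # that was applied to this step.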


class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    To mimic higher batch size, gradients are accumulated N times before weight update.

    For distribution mode, allreduce will be implemented after each sub-step and the trailing
    communication time will be hidden by the backend optimization pass.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                  batch_size * accumulation_steps. Default: 1.
    """
    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
        super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.enable_global_norm = enable_global_norm
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32))
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))

        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
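
    # scale_update_cell is typically a loss-scale manager cell; a plausible construction
    # (an assumption, it is not shown in this diff) would be:
    #
    #     update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32,
    #                                                 scale_factor=2,
    #                                                 scale_window=1000)
    #
    # Its get_loss_scale() seeds self.loss_scale above, and calling it in construct()
    # via self.loss_scaling_manager adjusts the scale whenever an overflow is detected.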
    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
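        # Worked example of the bookkeeping above (illustrative): with accumulation_steps = 4,
        #
        #     sub-step:       1      2      3      4      5   ...
        #     local_step:     1      2      3      4      1   ...
        #     is_accu_step:   True   True   True   False  True ...
        #
        # local_step counts up until it equals accumulation_steps; at that point is_accu_step
        # is False, the weights are updated further below, and the counter restarts from one.
        # accu_loss holds the sum of the losses in the window, so mean_loss is their mean.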
        # alloc status and clear should be right before GradOperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))


        accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
        scaling = scaling_sens * self.degree * self.accumulation_steps
        grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
        grads = self.grad_reducer(grads)
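        # Numeric illustration (assuming grad_scale divides by its scale argument): with
        # loss_scale = 1024, degree = 8 devices and accumulation_steps = 4, the accumulated
        # gradients are divided by 1024 * 8 * 4 = 32768 before the allreduce, so the reduced
        # result approximates the mean gradient of the global batch.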
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        flag_reduce = self.overflow_reducer(flag_sum)
        overflow = self.less_equal(self.base, flag_reduce)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
        overflow = self.reshape(overflow, (()))
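        # accu_overflow acts as a sticky flag: once any sub-step in the current accumulation
        # window overflows, the logical_or above keeps overflow True for the rest of the
        # window, and the flag is cleared again at the weight-update sub-step, where
        # is_accu_step is False.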
        if is_accu_step:
            succ = False
            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
            succ = F.depend(succ, accu_succ)
        else:
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                if self.enable_global_norm:
                    grads = C.clip_by_global_norm(grads, 1.0, None)
                else:
                    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

                succ = self.optimizer(grads)

            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            succ = F.depend(succ, accu_succ)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
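
# Usage sketch (assumed wiring, not part of this diff): either accumulation cell wraps the
# BERT network-with-loss and an optimizer in the same way, e.g.
#
#     train_cell = BertTrainAccumulationAllReducePostWithLossScaleCell(
#         net_with_loss, optimizer=optimizer, scale_update_cell=update_cell,
#         accumulation_steps=4)
#
# where net_with_loss, optimizer and update_cell stand in for the network with loss, a
# MindSpore optimizer and a loss-scale update cell like the one sketched earlier; the cell
# is then executed once per micro-batch and returns (mean_loss, overflow, scaling_sens).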