@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -330,7 +330,7 @@ def _tensor_grad_overflow(grad):
     return grad_overflow(grad)


-class BertTrainOneStepWithLossScaleCell(nn.Cell):
+class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
     """
     Encapsulation class of bert network training.

@@ -344,39 +344,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
     """

     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = F.identity
+        super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
+        self.cast = P.Cast()
         self.degree = 1
         if self.reducer_flag:
             self.degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
-        self.cast = P.Cast()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.gpu_target = False
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-            self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
         self.loss_scale = None
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
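Why nearly all of `__init__` can be deleted: `nn.TrainOneStepWithLossScaleCell` already creates the `GradOperation`, stores the network, weights and optimizer, sets `reducer_flag`, and installs a default gradient reducer, and it accepts the loss-scale update cell as its third constructor argument (named `scale_sense` in current MindSpore documentation). The subclass keeps only the BERT-specific pieces: the cast op and a `DistributedGradReducer` built with an explicit `degree`. A self-contained sketch of the base class doing this job on a toy network follows; it uses only standard MindSpore APIs, nothing from this file, and overflow detection may require a GPU or Ascend device on older releases:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

# Toy network + loss; stands in for the BERT network-with-loss cell.
net = nn.Dense(4, 1)
net_with_loss = nn.WithLossCell(net, nn.MSELoss())
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Dynamic loss scaling, playing the role of scale_update_cell in this diff.
update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12,
                                            scale_factor=2, scale_window=50)
train_step = nn.TrainOneStepWithLossScaleCell(net_with_loss, opt, update_cell)
train_step.set_train()

x = Tensor(np.random.randn(8, 4).astype(np.float32))
y = Tensor(np.random.randn(8, 1).astype(np.float32))
# Same triple the BERT cells return via `ret = (loss, cond, scaling_sens)`.
loss, overflow, scaling_sens = train_step(x, y)
```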
@@ -404,13 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
-        init = False
-        if not self.gpu_target:
-            # alloc status and clear should be right before gradoperation
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -424,21 +392,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
+        cond = self.detect_overflow(status, grads)
         overflow = cond
         if sens is None:
             overflow = self.loss_scaling_manager(self.loss_scale, cond)
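Two things are happening in this hunk. First, the per-backend overflow bookkeeping (NPU float-status ops on Ascend, `FloatStatus` plus `AddN` on GPU, plus the distributed `AllReduce` of the flag) collapses into the inherited `start_overflow`/`detect_overflow` pair; those are the names this diff uses, and newer MindSpore releases appear to spell the equivalents `start_overflow_check`/`get_overflow_status`. Second, the surviving `hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)` line does double duty: dividing by the loss scale times `degree` both unscales the gradients and averages the cross-device sum produced by `DistributedGradReducer(..., False, self.degree)`. A minimal, runnable illustration of that `HyperMap`/`MultitypeFuncGraph` machinery; `grad_scale` below is a stand-in re-implementation in the spirit of the one this file registers, not the file's own definition:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import composite as C, functional as F, operations as P

# Stand-in for the file's grad_scale: unscale one gradient by the loss scale.
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()

@grad_scale.register("Tensor", "Tensor")
def _unscale(scale, grad):
    return grad * reciprocal(scale)

hyper_map = C.HyperMap()
scale = Tensor(1024.0, ms.float32)
grads = (Tensor(np.full((2, 2), 2048.0, np.float32)),
         Tensor(np.full((3,), 512.0, np.float32)))
# F.partial binds `scale`; HyperMap then maps the op over the gradient tuple.
unscaled = hyper_map(F.partial(grad_scale, scale), grads)  # tensors of 2.0 and 0.5
```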
@@ -449,7 +404,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)


-class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
+class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
     """
     Encapsulation class of bert network training.

@@ -464,40 +420,12 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = F.identity
+        super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
+        self.cast = P.Cast()
         self.degree = 1
         if self.reducer_flag:
             self.degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
-        self.cast = P.Cast()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.gpu_target = False
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-            self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
         self.loss_scale = None
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
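One removal specific to the Adam variant's constructor: `P.ControlDepend`, which was deprecated and later removed from MindSpore's public ops. Explicit ordering between side-effecting ops is expressed with `F.depend`, the same primitive the deleted overflow bookkeeping used. A tiny runnable reminder of the idiom:

```python
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import functional as F

a = Tensor(1.0, ms.float32)
b = Tensor(2.0, ms.float32)
# F.depend(value, expr) returns `value` while forcing `expr` to execute first;
# the refactored cells rely on this (inside the base class) instead of ControlDepend.
a = F.depend(a, b)
```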
@@ -525,14 +453,8 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
-        init = False
-        if not self.gpu_target:
-            # alloc status and clear should be right before gradoperation
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)

+        status, scaling_sens = self.start_overflow(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -546,21 +468,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
+        cond = self.detect_overflow(status, grads)
         overflow = cond
         if self.loss_scaling_manager is not None:
             overflow = self.loss_scaling_manager(scaling_sens, cond)
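For completeness, a hedged sketch of how a training script might wire up the Adam variant after this change. `net_with_loss` and every hyperparameter below are illustrative assumptions, not values from this diff; only the class name and the `scale_update_cell` argument come from the code above:

```python
from mindspore import nn

# Assumption: net_with_loss is a loss-returning BERT cell defined elsewhere in this file.
optimizer = nn.AdamWeightDecay(net_with_loss.trainable_params(), learning_rate=5e-5)
scale_manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 16,
                                              scale_factor=2, scale_window=1000)
train_cell = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer,
                                                      scale_update_cell=scale_manager)
train_cell.set_train()
```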
|
|
|
|