| @@ -315,6 +315,15 @@ def tensor_grad_scale(scale, grad): | |||
| return grad * reciprocal(scale) | |||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||
| grad_overflow = P.FloatStatus() | |||
| @_grad_overflow.register("Tensor") | |||
| def _tensor_grad_overflow(grad): | |||
| return grad_overflow(grad) | |||
| class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||
| """ | |||
| Encapsulation class of bert network training. | |||
| @@ -347,9 +356,16 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) | |||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||
| self.cast = P.Cast() | |||
| self.alloc_status = P.NPUAllocFloatStatus() | |||
| self.get_status = P.NPUGetFloatStatus() | |||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||
| if context.get_context("device_target") == "GPU": | |||
| self.gpu_target = True | |||
| self.float_status = P.FloatStatus() | |||
| self.addn = P.AddN() | |||
| self.reshape = P.Reshape() | |||
| else: | |||
| self.gpu_target = False | |||
| self.alloc_status = P.NPUAllocFloatStatus() | |||
| self.get_status = P.NPUGetFloatStatus() | |||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||
| self.base = Tensor(1, mstype.float32) | |||
| @@ -383,9 +399,11 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||
| scaling_sens = self.loss_scale | |||
| else: | |||
| scaling_sens = sens | |||
| # alloc status and clear should be right before gradoperation | |||
| init = self.alloc_status() | |||
| self.clear_before_grad(init) | |||
| init = False | |||
| if not self.gpu_target: | |||
| # alloc status and clear should be right before gradoperation | |||
| init = self.alloc_status() | |||
| self.clear_before_grad(init) | |||
| grads = self.grad(self.network, weights)(input_ids, | |||
| input_mask, | |||
| token_type_id, | |||
| @@ -399,8 +417,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||
| grads = self.grad_reducer(grads) | |||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) | |||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||
| self.get_status(init) | |||
| flag_sum = self.reduce_sum(init, (0,)) | |||
| if not self.gpu_target: | |||
| self.get_status(init) | |||
| flag_sum = self.reduce_sum(init, (0,)) | |||
| else: | |||
| flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | |||
| flag_sum = self.addn(flag_sum) | |||
| flag_sum = self.reshape(flag_sum, (())) | |||
| if self.is_distributed: | |||
| # sum overflow flag over devices | |||
| flag_reduce = self.allreduce(flag_sum) | |||