| @@ -315,6 +315,15 @@ def tensor_grad_scale(scale, grad): | |||||
| return grad * reciprocal(scale) | return grad * reciprocal(scale) | ||||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||||
| grad_overflow = P.FloatStatus() | |||||
| @_grad_overflow.register("Tensor") | |||||
| def _tensor_grad_overflow(grad): | |||||
| return grad_overflow(grad) | |||||
| class BertTrainOneStepWithLossScaleCell(nn.Cell): | class BertTrainOneStepWithLossScaleCell(nn.Cell): | ||||
| """ | """ | ||||
| Encapsulation class of bert network training. | Encapsulation class of bert network training. | ||||
| @@ -347,9 +356,16 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) | self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) | ||||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.alloc_status = P.NPUAllocFloatStatus() | |||||
| self.get_status = P.NPUGetFloatStatus() | |||||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||||
| if context.get_context("device_target") == "GPU": | |||||
| self.gpu_target = True | |||||
| self.float_status = P.FloatStatus() | |||||
| self.addn = P.AddN() | |||||
| self.reshape = P.Reshape() | |||||
| else: | |||||
| self.gpu_target = False | |||||
| self.alloc_status = P.NPUAllocFloatStatus() | |||||
| self.get_status = P.NPUGetFloatStatus() | |||||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | self.reduce_sum = P.ReduceSum(keep_dims=False) | ||||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | self.depend_parameter_use = P.ControlDepend(depend_mode=1) | ||||
| self.base = Tensor(1, mstype.float32) | self.base = Tensor(1, mstype.float32) | ||||
| @@ -383,9 +399,11 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||||
| scaling_sens = self.loss_scale | scaling_sens = self.loss_scale | ||||
| else: | else: | ||||
| scaling_sens = sens | scaling_sens = sens | ||||
| # alloc status and clear should be right before gradoperation | |||||
| init = self.alloc_status() | |||||
| self.clear_before_grad(init) | |||||
| init = False | |||||
| if not self.gpu_target: | |||||
| # alloc status and clear should be right before gradoperation | |||||
| init = self.alloc_status() | |||||
| self.clear_before_grad(init) | |||||
| grads = self.grad(self.network, weights)(input_ids, | grads = self.grad(self.network, weights)(input_ids, | ||||
| input_mask, | input_mask, | ||||
| token_type_id, | token_type_id, | ||||
| @@ -399,8 +417,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||||
| grads = self.grad_reducer(grads) | grads = self.grad_reducer(grads) | ||||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) | grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) | ||||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | ||||
| self.get_status(init) | |||||
| flag_sum = self.reduce_sum(init, (0,)) | |||||
| if not self.gpu_target: | |||||
| self.get_status(init) | |||||
| flag_sum = self.reduce_sum(init, (0,)) | |||||
| else: | |||||
| flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | |||||
| flag_sum = self.addn(flag_sum) | |||||
| flag_sum = self.reshape(flag_sum, (())) | |||||
| if self.is_distributed: | if self.is_distributed: | ||||
| # sum overflow flag over devices | # sum overflow flag over devices | ||||
| flag_reduce = self.allreduce(flag_sum) | flag_reduce = self.allreduce(flag_sum) | ||||