| @@ -403,9 +403,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||||
| sens=None): | sens=None): | ||||
| """Defines the computation performed.""" | """Defines the computation performed.""" | ||||
| weights = self.weights | weights = self.weights | ||||
| # alloc status | |||||
| init = self.alloc_status() | |||||
| self.clear_before_grad(init) | |||||
| loss = self.network(input_ids, | loss = self.network(input_ids, | ||||
| input_mask, | input_mask, | ||||
| token_type_id, | token_type_id, | ||||
| @@ -417,6 +414,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||||
| scaling_sens = self.loss_scale | scaling_sens = self.loss_scale | ||||
| else: | else: | ||||
| scaling_sens = sens | scaling_sens = sens | ||||
| # alloc status and clear should be right before grad operation | |||||
| init = self.alloc_status() | |||||
| self.clear_before_grad(init) | |||||
| grads = self.grad(self.network, weights)(input_ids, | grads = self.grad(self.network, weights)(input_ids, | ||||
| input_mask, | input_mask, | ||||
| token_type_id, | token_type_id, | ||||