|
|
|
@@ -403,9 +403,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): |
|
|
|
sens=None): |
|
|
|
"""Defines the computation performed.""" |
|
|
|
weights = self.weights |
|
|
|
# alloc status |
|
|
|
init = self.alloc_status() |
|
|
|
self.clear_before_grad(init) |
|
|
|
loss = self.network(input_ids, |
|
|
|
input_mask, |
|
|
|
token_type_id, |
|
|
|
@@ -417,6 +414,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): |
|
|
|
scaling_sens = self.loss_scale |
|
|
|
else: |
|
|
|
scaling_sens = sens |
|
|
|
# alloc status and clear should be right before gradoperation |
|
|
|
init = self.alloc_status() |
|
|
|
self.clear_before_grad(init) |
|
|
|
grads = self.grad(self.network, weights)(input_ids, |
|
|
|
input_mask, |
|
|
|
token_type_id, |
|
|
|
|