diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 003e257149..b143a736bc 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -352,6 +352,7 @@ class TrainOneStepCell(Cell): weights = self.weights loss = self.network(*inputs) sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + sens = F.depend(sens, loss) grads = self.grad(self.network, weights)(*inputs, sens) grads = self.grad_reducer(grads) loss = F.depend(loss, self.optimizer(grads))