|
|
|
@@ -352,6 +352,7 @@ class TrainOneStepCell(Cell): |
|
|
|
weights = self.weights |
|
|
|
loss = self.network(*inputs) |
|
|
|
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) |
|
|
|
sens = F.depend(sens, loss) |
|
|
|
grads = self.grad(self.network, weights)(*inputs, sens) |
|
|
|
grads = self.grad_reducer(grads) |
|
|
|
loss = F.depend(loss, self.optimizer(grads)) |
|
|
|
|