From 907d1efcaee5ea96fb694124ea45e073bcdc2b1e Mon Sep 17 00:00:00 2001 From: l00591931 Date: Mon, 6 Dec 2021 10:40:44 +0800 Subject: [PATCH] Fix code docs example result format --- mindspore/nn/grad/cell_grad.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/mindspore/nn/grad/cell_grad.py b/mindspore/nn/grad/cell_grad.py index 1c5f131675..faa435a982 100644 --- a/mindspore/nn/grad/cell_grad.py +++ b/mindspore/nn/grad/cell_grad.py @@ -91,9 +91,11 @@ class Jvp(Cell): >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) >>> output = Jvp(Net())(x, y, (v, v)) >>> print(output[0]) - [[2, 10], [30, 68]] + [[ 2. 10.] + [30. 68.]] >>> print(output[1]) - [[4, 13], [28, 49]] + [[ 4. 13.] + [28. 49.]] """ def __init__(self, fn): super(Jvp, self).__init__() @@ -203,9 +205,14 @@ class Vjp(Cell): >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) >>> output = Vjp(Net())(x, y, v) >>> print(output[0]) - [[2, 10], [30, 68]] - >>> print(output[1]) - ([[3, 12], [27, 48]], [[1, 1], [1, 1]]) + [[ 2. 10.] + [30. 68.]] + >>> print(output[1][0]) + [[ 3. 12.] + [27. 48.]] + >>> print(output[1][1]) + [[1. 1.] + [1. 1.]] """ def __init__(self, fn):