Browse Source

Fix code docs example result format

tags/v1.6.0
l00591931 4 years ago
parent
commit
907d1efcae
1 changed files with 12 additions and 5 deletions
  1. +12
    -5
      mindspore/nn/grad/cell_grad.py

+ 12
- 5
mindspore/nn/grad/cell_grad.py View File

@@ -91,9 +91,11 @@ class Jvp(Cell):
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> output = Jvp(Net())(x, y, (v, v))
>>> print(output[0])
[[2, 10], [30, 68]]
[[ 2. 10.]
[30. 68.]]
>>> print(output[1])
[[4, 13], [28, 49]]
[[ 4. 13.]
[28. 49.]]
"""
def __init__(self, fn):
super(Jvp, self).__init__()
@@ -203,9 +205,14 @@ class Vjp(Cell):
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> output = Vjp(Net())(x, y, v)
>>> print(output[0])
[[2, 10], [30, 68]]
>>> print(output[1])
([[3, 12], [27, 48]], [[1, 1], [1, 1]])
[[ 2. 10.]
[30. 68.]]
>>> print(output[1][0])
[[ 3. 12.]
[27. 48.]]
>>> print(output[1][1])
[[1. 1.]
[1. 1.]]
"""

def __init__(self, fn):


Loading…
Cancel
Save