|
|
|
@@ -91,9 +91,11 @@ class Jvp(Cell): |
|
|
|
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) |
|
|
|
>>> output = Jvp(Net())(x, y, (v, v)) |
|
|
|
>>> print(output[0]) |
|
|
|
[[2, 10], [30, 68]] |
|
|
|
[[ 2. 10.] |
|
|
|
[30. 68.]] |
|
|
|
>>> print(output[1]) |
|
|
|
[[4, 13], [28, 49]] |
|
|
|
[[ 4. 13.] |
|
|
|
[28. 49.]] |
|
|
|
""" |
|
|
|
def __init__(self, fn): |
|
|
|
super(Jvp, self).__init__() |
|
|
|
@@ -203,9 +205,14 @@ class Vjp(Cell): |
|
|
|
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) |
|
|
|
>>> output = Vjp(Net())(x, y, v) |
|
|
|
>>> print(output[0]) |
|
|
|
[[2, 10], [30, 68]] |
|
|
|
>>> print(output[1]) |
|
|
|
([[3, 12], [27, 48]], [[1, 1], [1, 1]]) |
|
|
|
[[ 2. 10.] |
|
|
|
[30. 68.]] |
|
|
|
>>> print(output[1][0]) |
|
|
|
[[ 3. 12.] |
|
|
|
[27. 48.]] |
|
|
|
>>> print(output[1][1]) |
|
|
|
[[1. 1.] |
|
|
|
[1. 1.]] |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, fn): |
|
|
|
|