| @@ -91,9 +91,11 @@ class Jvp(Cell): | |||||
| >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) | >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) | ||||
| >>> output = Jvp(Net())(x, y, (v, v)) | >>> output = Jvp(Net())(x, y, (v, v)) | ||||
| >>> print(output[0]) | >>> print(output[0]) | ||||
| [[2, 10], [30, 68]] | |||||
| [[ 2. 10.] | |||||
| [30. 68.]] | |||||
| >>> print(output[1]) | >>> print(output[1]) | ||||
| [[4, 13], [28, 49]] | |||||
| [[ 4. 13.] | |||||
| [28. 49.]] | |||||
| """ | """ | ||||
| def __init__(self, fn): | def __init__(self, fn): | ||||
| super(Jvp, self).__init__() | super(Jvp, self).__init__() | ||||
| @@ -203,9 +205,14 @@ class Vjp(Cell): | |||||
| >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) | >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) | ||||
| >>> output = Vjp(Net())(x, y, v) | >>> output = Vjp(Net())(x, y, v) | ||||
| >>> print(output[0]) | >>> print(output[0]) | ||||
| [[2, 10], [30, 68]] | |||||
| >>> print(output[1]) | |||||
| ([[3, 12], [27, 48]], [[1, 1], [1, 1]]) | |||||
| [[ 2. 10.] | |||||
| [30. 68.]] | |||||
| >>> print(output[1][0]) | |||||
| [[ 3. 12.] | |||||
| [27. 48.]] | |||||
| >>> print(output[1][1]) | |||||
| [[1. 1.] | |||||
| [1. 1.]] | |||||
| """ | """ | ||||
| def __init__(self, fn): | def __init__(self, fn): | ||||