|
|
|
@@ -279,3 +279,69 @@ def test_broadcast_grad(trace_mode): |
|
|
|
func() |
|
|
|
|
|
|
|
worker() |
|
|
|
|
|
|
|
|
|
|
|
def test_2nd_grad_with_manager(): |
|
|
|
x_np = np.random.rand(10).astype("float32") |
|
|
|
x = mge.tensor(x_np) |
|
|
|
|
|
|
|
gm = GradManager().attach([x]) |
|
|
|
gm2 = GradManager().attach([x]) |
|
|
|
|
|
|
|
with gm: |
|
|
|
with gm2: |
|
|
|
y = F.cos(x) |
|
|
|
gm2.backward(y) |
|
|
|
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) |
|
|
|
gm.backward(x.grad) |
|
|
|
np.testing.assert_almost_equal( |
|
|
|
x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5 |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_grad_manager_group(): |
|
|
|
x_np = np.random.rand(10).astype("float32") |
|
|
|
x = mge.tensor(x_np) |
|
|
|
|
|
|
|
gm = GradManager().attach([x]) |
|
|
|
gm2 = GradManager().attach([x]) |
|
|
|
|
|
|
|
with gm | gm2: |
|
|
|
y = F.cos(x) |
|
|
|
gm.backward(y) |
|
|
|
gm2.backward(y) |
|
|
|
np.testing.assert_almost_equal(x.grad.numpy(), -2 * np.sin(x_np), decimal=5) |
|
|
|
|
|
|
|
x.grad = None |
|
|
|
|
|
|
|
|
|
|
|
def test_grad_manager_group_visibility(): |
|
|
|
x_np = np.random.rand(10).astype("float32") |
|
|
|
x = mge.tensor(x_np) |
|
|
|
|
|
|
|
gm = GradManager().attach([x]) |
|
|
|
gm2 = GradManager().attach([x]) |
|
|
|
|
|
|
|
with gm | gm2: |
|
|
|
y = F.cos(x) |
|
|
|
gm2.backward(y) |
|
|
|
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) |
|
|
|
gm.backward(x.grad) |
|
|
|
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) |
|
|
|
|
|
|
|
|
|
|
|
def test_grad_manager_visibility_by_order(): |
|
|
|
x_np = np.random.rand(10).astype("float32") |
|
|
|
x = mge.tensor(x_np) |
|
|
|
|
|
|
|
gm = GradManager().attach([x]) |
|
|
|
gm2 = GradManager().attach([x]) |
|
|
|
|
|
|
|
with gm2: |
|
|
|
with gm: |
|
|
|
y = F.cos(x) |
|
|
|
gm2.backward(y) |
|
|
|
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) |
|
|
|
gm.backward(x.grad) |
|
|
|
|
|
|
|
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) |