You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_grad_scaler.py 759 B

12345678910111213141516171819202122232425262728
  1. import numpy as np
  2. import megengine as mge
  3. from megengine.amp import GradScaler
  4. from megengine.autodiff import GradManager
  5. from megengine.jit import trace
  6. def test_grad_scaler():
  7. def f():
  8. gm = GradManager()
  9. scaler = GradScaler()
  10. x = mge.tensor(1.0)
  11. for _ in range(3):
  12. with gm:
  13. y = x + 1
  14. gm.attach(y)
  15. loss = y + 1
  16. scaler.backward(gm, loss, unscale_grad=False)
  17. np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor)
  18. scaler.unscale(gm.attached_tensors())
  19. np.testing.assert_equal(y.grad.numpy(), 1)
  20. # test handle None elements
  21. scaler.unscale(gm.attached_tensors())
  22. f()
  23. trace(f)()