You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_trace_module.py 833 B

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import numpy as np
  2. from megengine import Tensor
  3. from megengine.experimental.traced_module import trace_module
  4. from megengine.module import Module as M
  5. class MyModule1(M):
  6. def forward(self, x):
  7. y = Tensor(x)
  8. y += 1
  9. x = x + 2
  10. return x, y
  11. class MyModule2(M):
  12. def forward(self, x):
  13. y = Tensor([1, x, 1])
  14. y += 1
  15. x = x + 2
  16. return x, y
  17. def test_trace_module():
  18. x = Tensor(1)
  19. m1 = MyModule1()
  20. tm1 = trace_module(m1, x)
  21. m2 = MyModule2()
  22. tm2 = trace_module(m2, x)
  23. inp = Tensor(2)
  24. gt = m1(inp)
  25. output = tm1(inp)
  26. for a, b in zip(output, gt):
  27. np.testing.assert_equal(a.numpy(), b.numpy())
  28. gt1 = m2(inp)
  29. output1 = tm2(inp)
  30. for a, b in zip(output1, gt1):
  31. np.testing.assert_equal(a.numpy(), b.numpy())

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台