You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rng.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import pytest
  11. import megengine.functional as F
  12. from megengine import Tensor, jit, random
  13. from megengine.core._imperative_rt import CompNode
  14. from megengine.core._imperative_rt.core2 import apply
  15. from megengine.core._imperative_rt.ops import (
  16. delete_rng_handle,
  17. get_global_rng_seed,
  18. new_rng_handle,
  19. )
  20. from megengine.core.ops.builtin import (
  21. BetaRNG,
  22. GammaRNG,
  23. GaussianRNG,
  24. PermutationRNG,
  25. PoissonRNG,
  26. UniformRNG,
  27. )
  28. from megengine.device import get_device_count
  29. from megengine.random import RNG
  30. from megengine.random import seed as set_global_seed
  31. from megengine.random import uniform
  @pytest.mark.skipif(
      get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  )
  def test_gaussian_op():
      """Check GaussianRNG sample mean/std, output device and dtype.

      The first draw uses the global seed on the default comp node; the second
      draw runs on ``xpu2`` through an explicitly created RNG handle.
      """
      set_global_seed(1024)
      shape = Tensor((8, 9, 11, 12), dtype="int32")
      op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32")
      (output,) = apply(op, shape)
      samples = output.numpy()
      assert np.fabs(samples.mean() - 1.0) < 1e-1
      assert np.fabs(np.sqrt(samples.var()) - 3.0) < 1e-1
      assert str(output.device) == str(CompNode("xpux"))
      assert output.dtype == np.float32

      cn = CompNode("xpu2")
      seed = 233333
      h = new_rng_handle(cn, seed)
      try:
          op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h)
          (output,) = apply(op, shape)
      finally:
          # Release the handle even if op construction or apply raises.
          delete_rng_handle(h)
      samples = output.numpy()
      assert np.fabs(samples.mean() - 3.0) < 1e-1
      assert np.fabs(np.sqrt(samples.var()) - 1.0) < 1e-1
      assert str(output.device) == str(cn)
      assert output.dtype == np.float32
  @pytest.mark.skipif(
      get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  )
  def test_uniform_op():
      """Check UniformRNG sample mean, output device and dtype.

      The first draw uses the global seed on the default comp node; the second
      draw runs on ``xpu2`` through an explicitly created RNG handle.
      """
      set_global_seed(1024)
      shape = Tensor((8, 9, 11, 12), dtype="int32")
      op = UniformRNG(seed=get_global_rng_seed(), dtype="float32")
      (output,) = apply(op, shape)
      assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
      assert str(output.device) == str(CompNode("xpux"))
      assert output.dtype == np.float32

      cn = CompNode("xpu2")
      seed = 233333
      h = new_rng_handle(cn, seed)
      try:
          op = UniformRNG(seed=seed, dtype="float32", handle=h)
          (output,) = apply(op, shape)
      finally:
          # Release the handle even if op construction or apply raises.
          delete_rng_handle(h)
      assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
      assert str(output.device) == str(cn)
      assert output.dtype == np.float32
  @pytest.mark.skipif(
      get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  )
  def test_gamma_op():
      """Check GammaRNG sample mean/std and output device.

      With shape k and scale theta the expected mean is k * theta and the
      expected std is sqrt(k) * theta. The first draw uses the global seed on
      the default comp node; the second runs on ``xpu2`` via an RNG handle.
      """
      set_global_seed(1024)
      _shape, _scale = 2, 0.8
      _expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale
      shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32")
      scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32")
      op = GammaRNG(seed=get_global_rng_seed(), handle=0)
      (output,) = apply(op, shape, scale)
      samples = output.numpy()
      assert np.fabs(samples.mean() - _expected_mean) < 1e-1
      assert np.fabs(np.sqrt(samples.var()) - _expected_std) < 1e-1
      assert str(output.device) == str(CompNode("xpux"))

      cn = CompNode("xpu2")
      seed = 233333
      h = new_rng_handle(cn, seed)
      try:
          shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2")
          scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2")
          op = GammaRNG(seed=seed, handle=h)
          (output,) = apply(op, shape, scale)
      finally:
          # Release the handle even if tensor creation or apply raises.
          delete_rng_handle(h)
      samples = output.numpy()
      assert np.fabs(samples.mean() - _expected_mean) < 1e-1
      assert np.fabs(np.sqrt(samples.var()) - _expected_std) < 1e-1
      assert str(output.device) == str(cn)
  @pytest.mark.skipif(
      get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  )
  def test_beta_op():
      """Check BetaRNG sample mean/std and output device.

      Expected mean is a / (a + b); expected std is
      sqrt(a * b / ((a + b) ** 2 * (a + b + 1))). The first draw uses the
      global seed on the default comp node; the second runs on ``xpu2`` via an
      RNG handle.
      """
      set_global_seed(1024)
      _alpha, _beta = 2, 0.8
      _expected_mean = _alpha / (_alpha + _beta)
      _expected_std = np.sqrt(
          _alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1))
      )
      alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32")
      beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32")
      op = BetaRNG(seed=get_global_rng_seed())
      (output,) = apply(op, alpha, beta)
      samples = output.numpy()
      assert np.fabs(samples.mean() - _expected_mean) < 1e-1
      assert np.fabs(np.sqrt(samples.var()) - _expected_std) < 1e-1
      assert str(output.device) == str(CompNode("xpux"))

      cn = CompNode("xpu2")
      seed = 233333
      h = new_rng_handle(cn, seed)
      try:
          alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn)
          beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn)
          op = BetaRNG(seed=seed, handle=h)
          (output,) = apply(op, alpha, beta)
      finally:
          # Release the handle even if tensor creation or apply raises.
          delete_rng_handle(h)
      samples = output.numpy()
      assert np.fabs(samples.mean() - _expected_mean) < 1e-1
      assert np.fabs(np.sqrt(samples.var()) - _expected_std) < 1e-1
      assert str(output.device) == str(cn)
  @pytest.mark.skipif(
      get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  )
  def test_poisson_op():
      """Check PoissonRNG sample mean/std (both ~lam and ~sqrt(lam)) and device.

      The first draw uses the global seed on the default comp node; the second
      runs on ``xpu2`` via an explicitly created RNG handle.
      """
      set_global_seed(1024)
      lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
      op = PoissonRNG(seed=get_global_rng_seed())
      (output,) = apply(op, lam)
      samples = output.numpy()
      assert np.fabs(samples.mean() - 2.0) < 1e-1
      assert np.fabs(np.sqrt(samples.var()) - np.sqrt(2.0)) < 1e-1
      assert str(output.device) == str(CompNode("xpux"))

      cn = CompNode("xpu2")
      seed = 233333
      h = new_rng_handle(cn, seed)
      try:
          lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn)
          op = PoissonRNG(seed=seed, handle=h)
          (output,) = apply(op, lam)
      finally:
          # Release the handle even if tensor creation or apply raises.
          delete_rng_handle(h)
      samples = output.numpy()
      assert np.fabs(samples.mean() - 2.0) < 1e-1
      assert np.fabs(np.sqrt(samples.var()) - np.sqrt(2.0)) < 1e-1
      assert str(output.device) == str(cn)
  @pytest.mark.skipif(
      get_device_count("xpu") <= 2, reason="xpu counts need > 2",
  )
  def test_permutation_op():
      """Check PermutationRNG output for float32, int32 and int16 dtypes.

      For each dtype, a permutation of ``range(n)`` is drawn on the default
      comp node (global seed) and on ``xpu2`` through an explicit RNG handle.
      """
      set_global_seed(1024)
      n = 1000

      def count_fixed_points(res, fun):
          # Number of positions i where fun(res)[i] == i.
          return sum(int(i == v) for i, v in enumerate(fun(res.numpy())))

      def check_dtype(dtype):
          # Not named test_* so it is not mistaken for a collected test.
          shape = Tensor((n,), dtype="int32")
          op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype)
          (output,) = apply(op, shape)
          # A shuffled sequence should leave few elements in place ...
          assert count_fixed_points(output, lambda x: x) < 500
          # ... and sorting it must recover 0..n-1 exactly.
          assert count_fixed_points(output, np.sort) == n
          assert str(output.device) == str(CompNode("xpux"))
          assert output.dtype == dtype

          cn = CompNode("xpu2")
          seed = 233333
          h = new_rng_handle(cn, seed)
          try:
              op = PermutationRNG(seed=seed, handle=h, dtype=dtype)
              (output,) = apply(op, shape)
          finally:
              # Release the handle even if op construction or apply raises.
              delete_rng_handle(h)
          assert count_fixed_points(output, lambda x: x) < 500
          assert count_fixed_points(output, np.sort) == n
          assert str(output.device) == str(cn)
          assert output.dtype == dtype

      for dtype in (np.float32, np.int32, np.int16):
          check_dtype(dtype)
  @pytest.mark.skipif(
      get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  )
  def test_UniformRNG():
      """Check RNG.uniform seeding across devices, placement, shape and mean.

      Generators with equal seeds on different devices must agree; a different
      seed, or the next draw from the same generator, must not.
      """
      gen_a = RNG(seed=111, device="xpu0")
      gen_b = RNG(seed=111, device="xpu1")
      gen_c = RNG(seed=222, device="xpu0")
      first = gen_a.uniform(size=(100,))
      first_next = gen_a.uniform(size=(100,))
      twin = gen_b.uniform(size=(100,))
      other = gen_c.uniform(size=(100,))

      np.testing.assert_equal(first.numpy(), twin.numpy())
      assert first.device == "xpu0" and twin.device == "xpu1"
      assert not (first.numpy() == other.numpy()).all()
      assert not (first.numpy() == first_next.numpy()).all()

      low, high = -234, 123
      expected_shape = (20, 30, 40)
      sample = gen_a.uniform(low=low, high=high, size=expected_shape)
      got_shape = sample.shape
      if not isinstance(got_shape, tuple):
          assert all(sample.shape.numpy() == np.array(list(expected_shape)))
      else:
          assert got_shape == expected_shape
      midpoint = (low + high) / 2
      span = high - low
      assert np.abs(sample.mean().numpy() - midpoint) / span < 0.1
  @pytest.mark.skipif(
      get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  )
  def test_NormalRNG():
      """Check RNG.normal seeding across devices, placement, shape and moments.

      Generators with equal seeds on different devices must agree; a different
      seed, or the next draw from the same generator, must not.
      """
      m1 = RNG(seed=111, device="xpu0")
      m2 = RNG(seed=111, device="xpu1")
      m3 = RNG(seed=222, device="xpu0")
      out1 = m1.normal(size=(100,))
      # Next draw from the same generator *and distribution*. Comparing against
      # a draw from a different distribution would not show that the
      # generator's state actually advances between calls.
      out1_ = m1.normal(size=(100,))
      out2 = m2.normal(size=(100,))
      out3 = m3.normal(size=(100,))
      np.testing.assert_equal(out1.numpy(), out2.numpy())
      assert out1.device == "xpu0" and out2.device == "xpu1"
      assert not (out1.numpy() == out3.numpy()).all()
      assert not (out1.numpy() == out1_.numpy()).all()

      mean = -1
      std = 2
      out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
      out_shp = out.shape
      if isinstance(out_shp, tuple):
          assert out_shp == (20, 30, 40)
      else:
          assert all(out.shape.numpy() == np.array([20, 30, 40]))
      assert np.abs(out.mean().numpy() - mean) / std < 0.1
      assert np.abs(np.std(out.numpy()) - std) < 0.1
  @pytest.mark.skipif(
      get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  )
  def test_GammaRNG():
      """Check RNG.gamma seeding, placement, broadcast shape and moments.

      Expected per-parameter mean is shape * scale and std is
      sqrt(shape) * scale; ``shape`` (2x3) and ``scale`` (3,) broadcast to 2x3.
      """
      m1 = RNG(seed=111, device="xpu0")
      m2 = RNG(seed=111, device="xpu1")
      m3 = RNG(seed=222, device="xpu0")
      out1 = m1.gamma(2, size=(100,))
      # Next draw from the same generator *and distribution*. Comparing against
      # a draw from a different distribution would not show that the
      # generator's state actually advances between calls.
      out1_ = m1.gamma(2, size=(100,))
      out2 = m2.gamma(2, size=(100,))
      out3 = m3.gamma(2, size=(100,))
      np.testing.assert_equal(out1.numpy(), out2.numpy())
      assert out1.device == "xpu0" and out2.device == "xpu1"
      assert not (out1.numpy() == out3.numpy()).all()
      assert not (out1.numpy() == out1_.numpy()).all()

      shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
      scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
      expected_mean = (shape * scale).numpy()
      expected_std = (F.sqrt(shape) * scale).numpy()
      out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40))
      out_shp = out.shape
      if isinstance(out_shp, tuple):
          assert out_shp == (20, 30, 40, 2, 3)
      else:
          assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3]))
      assert (
          np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
      ).mean() < 0.1
      assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  @pytest.mark.skipif(
      get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  )
  def test_BetaRNG():
      """Check RNG.beta seeding, placement, broadcast shape and moments.

      Expected mean is a / (a + b) and std is
      sqrt(a * b / ((a + b) ** 2 * (a + b + 1))); ``alpha`` (2x3) and
      ``beta`` (3,) broadcast to 2x3.
      """
      m1 = RNG(seed=111, device="xpu0")
      m2 = RNG(seed=111, device="xpu1")
      m3 = RNG(seed=222, device="xpu0")
      out1 = m1.beta(2, 1, size=(100,))
      # Next draw from the same generator *and distribution*. Comparing against
      # a draw from a different distribution would not show that the
      # generator's state actually advances between calls.
      out1_ = m1.beta(2, 1, size=(100,))
      out2 = m2.beta(2, 1, size=(100,))
      out3 = m3.beta(2, 1, size=(100,))
      np.testing.assert_equal(out1.numpy(), out2.numpy())
      assert out1.device == "xpu0" and out2.device == "xpu1"
      assert not (out1.numpy() == out3.numpy()).all()
      assert not (out1.numpy() == out1_.numpy()).all()

      alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0")
      beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0")
      expected_mean = (alpha / (alpha + beta)).numpy()
      expected_std = (
          F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1)))
      ).numpy()
      out = m1.beta(alpha=alpha, beta=beta, size=(20, 30))
      out_shp = out.shape
      if isinstance(out_shp, tuple):
          assert out_shp == (20, 30, 2, 3)
      else:
          assert all(out.shape.numpy() == np.array([20, 30, 2, 3]))
      assert (
          np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std
      ).mean() < 0.1
      assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1
  @pytest.mark.skipif(
      get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  )
  def test_PoissonRNG():
      """Check RNG.poisson seeding, placement, output shape and moments.

      The output shape is ``size`` followed by the shape of ``lam``; per-rate
      mean should be ~lam and std ~sqrt(lam).
      """
      gen_a = RNG(seed=111, device="xpu0")
      gen_b = RNG(seed=111, device="xpu1")
      gen_c = RNG(seed=222, device="xpu0")
      rates = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32)

      first = gen_a.poisson(rates.to("xpu0"), size=(100,))
      twin = gen_b.poisson(rates.to("xpu1"), size=(100,))
      other = gen_c.poisson(rates.to("xpu0"), size=(100,))
      np.testing.assert_equal(first.numpy(), twin.numpy())
      assert first.device == "xpu0" and twin.device == "xpu1"
      assert not (first.numpy() == other.numpy()).all()

      sample = gen_a.poisson(rates.to("xpu0"), size=(20, 30))
      got_shape = sample.shape
      want_shape = (20, 30) + rates._tuple_shape
      if not isinstance(got_shape, tuple):
          assert all(sample.shape.numpy() == np.array(want_shape))
      else:
          assert got_shape == want_shape

      rates_np = rates.numpy()
      root = np.sqrt(rates_np)
      rel_mean_err = np.abs(sample.mean(axis=(0, 1)).numpy() - rates_np) / root
      assert rel_mean_err.mean() < 0.1
      assert np.abs(np.std(sample.numpy(), axis=(0, 1)) - root).mean() < 0.1
  @pytest.mark.skipif(
      get_device_count("xpu") <= 1, reason="xpu counts need > 1",
  )
  def test_PermutationRNG():
      """Check RNG.permutation seeding, placement, shape and permutation validity."""
      m1 = RNG(seed=111, device="xpu0")
      m2 = RNG(seed=111, device="xpu1")
      m3 = RNG(seed=222, device="xpu0")
      out1 = m1.permutation(n=1000)
      # Next draw from the same generator *and distribution*. Comparing against
      # a uniform draw would not show that the generator's state actually
      # advances between calls.
      out1_ = m1.permutation(n=1000)
      out2 = m2.permutation(n=1000)
      out3 = m3.permutation(n=1000)
      np.testing.assert_equal(out1.numpy(), out2.numpy())
      assert out1.device == "xpu0" and out2.device == "xpu1"
      assert not (out1.numpy() == out3.numpy()).all()
      assert not (out1.numpy() == out1_.numpy()).all()

      out = m1.permutation(n=1000)
      out_shp = out.shape
      if isinstance(out_shp, tuple):
          assert out_shp == (1000,)
      else:
          assert all(out.shape.numpy() == np.array([1000]))

      def count_fixed_points(res, fun):
          # Number of positions i where fun(res)[i] == i.
          return sum(int(i == v) for i, v in enumerate(fun(res.numpy())))

      # A shuffled sequence should leave few elements in place ...
      assert count_fixed_points(out, lambda x: x) < 500
      # ... and sorting it must recover 0..999 exactly.
      assert count_fixed_points(out, np.sort) == 1000
  def test_seed():
      """Check that the global seed makes ``uniform`` reproducible.

      Re-seeding with the same value replays the first draw; consecutive draws
      differ, and a different seed gives a different first draw.
      """
      draw_shape = [10, 10]

      set_global_seed(10)
      first = uniform(size=draw_shape)
      following = uniform(size=draw_shape)
      assert not (first.numpy() == following.numpy()).all()

      set_global_seed(10)
      replay = uniform(size=draw_shape)
      np.testing.assert_equal(first.numpy(), replay.numpy())

      set_global_seed(11)
      reseeded = uniform(size=draw_shape)
      assert not (first.numpy() == reseeded.numpy()).all()
  @pytest.mark.parametrize("is_symbolic", [None, False, True])
  def test_rng_empty_tensor(is_symbolic):
      """Check that every random sampler handles zero-sized outputs.

      ``is_symbolic`` is None for eager execution (one run), otherwise the
      samplers are wrapped in ``jit.trace(symbolic=is_symbolic)`` and run three
      times (presumably to cover both tracing and replay — confirm).
      """
      set_global_seed(1024)
      shapes = [(0,), (0, 0, 0), (10, 0, 10)]
      repeats = 1 if is_symbolic is None else 3

      def fn(shape):
          # Tuple elements are evaluated left to right, in sampler order.
          return (
              random.uniform(0, 1, shape),
              random.normal(0, 1, shape),
              random.gamma(2, 1, shape),
              random.beta(2, 1, shape),
              random.poisson(2, shape),
          )

      for shape in shapes:
          runner = fn if is_symbolic is None else jit.trace(symbolic=is_symbolic)(fn)
          for _ in range(repeats):
              for result in runner(shape):
                  np.testing.assert_equal(result.numpy().shape, shape)

      def fn2(n):
          return random.permutation(n=n)

      perm_runner = fn2 if is_symbolic is None else jit.trace(symbolic=is_symbolic)(fn2)
      for _ in range(repeats):
          np.testing.assert_equal(perm_runner(0).numpy().shape, (0,))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。