|
|
|
@@ -27,13 +27,16 @@ from megengine.core.ops.builtin import ( |
|
|
|
UniformRNG, |
|
|
|
) |
|
|
|
from megengine.device import get_device_count |
|
|
|
from megengine.random import RNG, seed, uniform |
|
|
|
from megengine.random import RNG |
|
|
|
from megengine.random import seed as set_global_seed |
|
|
|
from megengine.random import uniform |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif( |
|
|
|
get_device_count("xpu") <= 2, reason="xpu counts need > 2", |
|
|
|
) |
|
|
|
def test_gaussian_op(): |
|
|
|
set_global_seed(1024) |
|
|
|
shape = ( |
|
|
|
8, |
|
|
|
9, |
|
|
|
@@ -64,6 +67,7 @@ def test_gaussian_op(): |
|
|
|
get_device_count("xpu") <= 2, reason="xpu counts need > 2", |
|
|
|
) |
|
|
|
def test_uniform_op(): |
|
|
|
set_global_seed(1024) |
|
|
|
shape = ( |
|
|
|
8, |
|
|
|
9, |
|
|
|
@@ -92,6 +96,7 @@ def test_uniform_op(): |
|
|
|
get_device_count("xpu") <= 2, reason="xpu counts need > 2", |
|
|
|
) |
|
|
|
def test_gamma_op(): |
|
|
|
set_global_seed(1024) |
|
|
|
_shape, _scale = 2, 0.8 |
|
|
|
_expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale |
|
|
|
|
|
|
|
@@ -120,6 +125,7 @@ def test_gamma_op(): |
|
|
|
get_device_count("xpu") <= 2, reason="xpu counts need > 2", |
|
|
|
) |
|
|
|
def test_beta_op(): |
|
|
|
set_global_seed(1024) |
|
|
|
_alpha, _beta = 2, 0.8 |
|
|
|
_expected_mean = _alpha / (_alpha + _beta) |
|
|
|
_expected_std = np.sqrt( |
|
|
|
@@ -151,6 +157,7 @@ def test_beta_op(): |
|
|
|
get_device_count("xpu") <= 2, reason="xpu counts need > 2", |
|
|
|
) |
|
|
|
def test_poisson_op(): |
|
|
|
set_global_seed(1024) |
|
|
|
lam = F.full([8, 9, 11, 12], value=2, dtype="float32") |
|
|
|
op = PoissonRNG(seed=get_global_rng_seed()) |
|
|
|
(output,) = apply(op, lam) |
|
|
|
@@ -174,6 +181,7 @@ def test_poisson_op(): |
|
|
|
get_device_count("xpu") <= 2, reason="xpu counts need > 2", |
|
|
|
) |
|
|
|
def test_permutation_op(): |
|
|
|
set_global_seed(1024) |
|
|
|
n = 1000 |
|
|
|
|
|
|
|
def test_permutation_op_dtype(dtype): |
|
|
|
@@ -390,22 +398,23 @@ def test_PermutationRNG(): |
|
|
|
|
|
|
|
|
|
|
|
def test_seed(): |
|
|
|
seed(10) |
|
|
|
set_global_seed(10) |
|
|
|
out1 = uniform(size=[10, 10]) |
|
|
|
out2 = uniform(size=[10, 10]) |
|
|
|
assert not (out1.numpy() == out2.numpy()).all() |
|
|
|
|
|
|
|
seed(10) |
|
|
|
set_global_seed(10) |
|
|
|
out3 = uniform(size=[10, 10]) |
|
|
|
np.testing.assert_equal(out1.numpy(), out3.numpy()) |
|
|
|
|
|
|
|
seed(11) |
|
|
|
set_global_seed(11) |
|
|
|
out4 = uniform(size=[10, 10]) |
|
|
|
assert not (out1.numpy() == out4.numpy()).all() |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("is_symbolic", [None, False, True]) |
|
|
|
def test_rng_empty_tensor(is_symbolic): |
|
|
|
set_global_seed(1024) |
|
|
|
shapes = [ |
|
|
|
(0,), |
|
|
|
(0, 0, 0), |
|
|
|
|