| @@ -27,7 +27,7 @@ from megengine.core.ops.builtin import ( | |||||
| UniformRNG, | UniformRNG, | ||||
| ) | ) | ||||
| from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
| from megengine.random import RNG | |||||
| from megengine.random import RNG, seed, uniform | |||||
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||
| @@ -387,3 +387,18 @@ def test_PermutationRNG(): | |||||
| assert sum_result(out, lambda x: x) < 500 | assert sum_result(out, lambda x: x) < 500 | ||||
| assert sum_result(out, np.sort) == 1000 | assert sum_result(out, np.sort) == 1000 | ||||
| def test_seed(): | |||||
| seed(10) | |||||
| out1 = uniform(size=[10, 10]) | |||||
| out2 = uniform(size=[10, 10]) | |||||
| assert not (out1.numpy() == out2.numpy()).all() | |||||
| seed(10) | |||||
| out3 = uniform(size=[10, 10]) | |||||
| np.testing.assert_equal(out1.numpy(), out3.numpy()) | |||||
| seed(11) | |||||
| out4 = uniform(size=[10, 10]) | |||||
| assert not (out1.numpy() == out4.numpy()).all() | |||||
| @@ -127,10 +127,8 @@ public: | |||||
| auto&& glob_handle = glob_default_handles[comp_node]; | auto&& glob_handle = glob_default_handles[comp_node]; | ||||
| if (!glob_handle) { | if (!glob_handle) { | ||||
| glob_handle = inst().do_new_handle(comp_node, glob_default_seed); | glob_handle = inst().do_new_handle(comp_node, glob_default_seed); | ||||
| } else if (get_seed(glob_handle) != glob_default_seed) { | |||||
| inst().DnnOpManagerBase::delete_handle(glob_handle); | |||||
| glob_handle = inst().do_new_handle(comp_node, glob_default_seed); | |||||
| } | } | ||||
| mgb_assert(get_seed(glob_handle) == glob_default_seed); | |||||
| return glob_handle; | return glob_handle; | ||||
| } | } | ||||
| @@ -141,6 +139,13 @@ public: | |||||
| static void set_glob_default_seed(uint64_t seed) { | static void set_glob_default_seed(uint64_t seed) { | ||||
| MGB_LOCK_GUARD(sm_mtx); | MGB_LOCK_GUARD(sm_mtx); | ||||
| for(auto && elem : glob_default_handles){ | |||||
| mgb_assert(elem.first.valid()); | |||||
| if(elem.second){ | |||||
| inst().DnnOpManagerBase::delete_handle(elem.second); | |||||
| } | |||||
| elem.second = inst().do_new_handle(elem.first, seed); | |||||
| } | |||||
| glob_default_seed = seed; | glob_default_seed = seed; | ||||
| } | } | ||||