GitOrigin-RevId: 7ed0447bfe
tags/v1.7.0
| @@ -225,7 +225,7 @@ def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor: | |||
| assert inp.size > 0, "size needs to be greater than 0" | |||
| op = ShuffleRNG(seed=seed, handle=handle) | |||
| output, _ = apply(op, inp) | |||
| inp._reset(output) | |||
| return output | |||
| class RNG: | |||
| @@ -554,12 +554,15 @@ class RNG: | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle) | |||
| def permutation(self, n: int, *, dtype: str = "int32"): | |||
| r"""Generates a random permutation of integers from :math:`0` to :math:`n - 1`. | |||
| def permutation(self, n: Union[int, Tensor], *, dtype: str = "int32"): | |||
| r"""Randomly permute a sequence, or return a permuted range. | |||
| If ``n`` is a multi-dimensional tensor, it is only shuffled along its first index. | |||
| Args: | |||
| n: the upper bound. Must be larger than 0. | |||
| dtype: the output data type. int32, int16 and float32 are supported. Default: int32 | |||
| n: If ``n`` is an integer, random permutation of integers from :math:`0` to :math:`n - 1`. | |||
| If ``n`` is an tensor, make a copy and shuffle the elements randomly. | |||
| dtype: the output data type when ``n`` is an integer. | |||
| int32, int16 and float32 are supported. Default: int32 | |||
| Returns: | |||
| the output tensor. | |||
| @@ -568,13 +571,18 @@ class RNG: | |||
| .. testcode:: | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.random as rand | |||
| x = rand.permutation(n=10, dtype="int32") | |||
| x = rand.permutation(10, dtype="int32") | |||
| print(x.numpy()) | |||
| x = rand.permutation(10, dtype="float32") | |||
| print(x.numpy()) | |||
| x = rand.permutation(n=10, dtype="float32") | |||
| x = mge.tensor(np.arange(18)).reshape(6,3) | |||
| x = rand.permutation(x) | |||
| print(x.numpy()) | |||
| Outputs: | |||
| @@ -584,11 +592,20 @@ class RNG: | |||
| [4 5 0 7 3 8 6 1 9 2] | |||
| [3. 4. 9. 0. 6. 8. 7. 1. 5. 2.] | |||
| [[12 13 14] | |||
| [ 3 4 5] | |||
| [15 16 17] | |||
| [ 0 1 2] | |||
| [ 9 10 11] | |||
| [ 6 7 8]] | |||
| """ | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| return _permutation( | |||
| n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype | |||
| ) | |||
| if isinstance(n, int): | |||
| return _permutation( | |||
| n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype | |||
| ) | |||
| assert isinstance(n, Tensor) | |||
| return _shuffle(inp=n, seed=_seed, handle=self._handle) | |||
| def shuffle(self, inp: Tensor): | |||
| r"""Modify a sequence in-place by shuffling its contents. | |||
| @@ -627,7 +644,7 @@ class RNG: | |||
| [ 6. 7. 8.]] | |||
| """ | |||
| _seed = self._seed() if callable(self._seed) else self._seed | |||
| _shuffle(inp=inp, seed=_seed, handle=self._handle) | |||
| inp._reset(_shuffle(inp=inp, seed=_seed, handle=self._handle)) | |||
| def __del__(self): | |||
| if self._handle != 0: | |||
| @@ -28,6 +28,7 @@ from megengine.core.ops.builtin import ( | |||
| UniformRNG, | |||
| ) | |||
| from megengine.device import get_device_count | |||
| from megengine.jit import trace | |||
| from megengine.random import RNG | |||
| from megengine.random import seed as set_global_seed | |||
| from megengine.random import uniform | |||
| @@ -370,21 +371,22 @@ def test_PoissonRNG(): | |||
| @pytest.mark.skipif( | |||
| get_device_count("xpu") <= 1, reason="xpu counts need > 1", | |||
| ) | |||
| def test_PermutationRNG(): | |||
| @pytest.mark.parametrize("symbolic", [True, False]) | |||
| def test_PermutationRNG(symbolic): | |||
| m1 = RNG(seed=111, device="xpu0") | |||
| m2 = RNG(seed=111, device="xpu1") | |||
| m3 = RNG(seed=222, device="xpu0") | |||
| out1 = m1.permutation(n=1000) | |||
| out1 = m1.permutation(1000) | |||
| out1_ = m1.uniform(size=(1000,)) | |||
| out2 = m2.permutation(n=1000) | |||
| out3 = m3.permutation(n=1000) | |||
| out2 = m2.permutation(1000) | |||
| out3 = m3.permutation(1000) | |||
| np.testing.assert_equal(out1.numpy(), out2.numpy()) | |||
| assert out1.device == "xpu0" and out2.device == "xpu1" | |||
| assert not (out1.numpy() == out3.numpy()).all() | |||
| assert not (out1.numpy() == out1_.numpy()).all() | |||
| out = m1.permutation(n=1000) | |||
| out = m1.permutation(1000) | |||
| out_shp = out.shape | |||
| if isinstance(out_shp, tuple): | |||
| assert out_shp == (1000,) | |||
| @@ -397,6 +399,24 @@ def test_PermutationRNG(): | |||
| assert sum_result(out, lambda x: x) < 500 | |||
| assert sum_result(out, np.sort) == 1000 | |||
| def func(): | |||
| out = m1.permutation(Tensor(7)) | |||
| out_shp = out.shape | |||
| if isinstance(out_shp, tuple): | |||
| assert out_shp == (1,) | |||
| else: | |||
| assert all(out.shape.numpy() == np.array([1])) | |||
| n, m = 6, 3 | |||
| out = m1.permutation(Tensor(np.arange(n * m), dtype="float32").reshape(n, m)) | |||
| out_shp = out.shape | |||
| if isinstance(out_shp, tuple): | |||
| assert out_shp == (n, m) | |||
| else: | |||
| assert all(out.shape.numpy() == np.array([n, m])) | |||
| func = trace(symbolic=symbolic)(func) | |||
| func() | |||
| @pytest.mark.skipif( | |||
| get_device_count("xpu") <= 1, reason="xpu counts need > 1", | |||
| @@ -214,8 +214,12 @@ ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param, | |||
| const OperatorNodeConfig& config) | |||
| : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) { | |||
| add_input({data}); | |||
| add_output(None)->dtype(data->dtype()); | |||
| add_output(None)->dtype(dtype::Int32{}); | |||
| add_output(None) | |||
| ->dtype(data->dtype()) | |||
| .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||
| add_output(None) | |||
| ->dtype(dtype::Int32{}) | |||
| .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||
| cg::add_workspace_output(this); | |||
| add_equivalence_component<ScalarHash<void*>>(this); | |||
| } | |||
| @@ -266,12 +270,27 @@ void ShuffleRNGForward::add_input_layout_constraint() { | |||
| }; | |||
| void ShuffleRNGForward::scn_do_execute() { | |||
| auto&& ret = output(0); | |||
| if (ret->layout().is_empty()) { | |||
| mgb_assert(ret->dev_tensor().empty()); | |||
| return; | |||
| } | |||
| m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(), | |||
| output(0)->dev_tensor().as_megdnn(), | |||
| output(1)->dev_tensor().as_megdnn(), | |||
| get_megdnn_workspace_from_var(output(2))); | |||
| } | |||
| cg::OperatorNodeBase::NodeProp* ShuffleRNGForward::do_make_node_prop() const { | |||
| auto prop = Super::do_make_node_prop(); | |||
| prop->add_flag(NodeProp::Flag::IMPURE_FUNC); | |||
| for (auto i : input()) { | |||
| prop->add_dep_type_existing_var(i, | |||
| NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||
| } | |||
| return prop; | |||
| } | |||
| #if MGB_ENABLE_GRAD | |||
| MGB_IMPL_OPR_GRAD(ShuffleRNGForward) { | |||
| mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); | |||