GitOrigin-RevId: 9a851cd177
tags/v1.1.0
| @@ -52,7 +52,8 @@ def normal( | |||||
| size = (1,) | size = (1,) | ||||
| seed = _random_seed_generator().__next__() | seed = _random_seed_generator().__next__() | ||||
| op = GaussianRNG(seed=seed, mean=mean, std=std) | op = GaussianRNG(seed=seed, mean=mean, std=std) | ||||
| size = Tensor(size, dtype="int32") | |||||
| _ref = Tensor([], dtype="int32") | |||||
| size = utils.astensor1d(size, _ref, dtype="int32") | |||||
| (output,) = apply(op, size) | (output,) = apply(op, size) | ||||
| return output | return output | ||||
| @@ -93,7 +94,8 @@ def uniform( | |||||
| size = (1,) | size = (1,) | ||||
| seed = _random_seed_generator().__next__() | seed = _random_seed_generator().__next__() | ||||
| op = UniformRNG(seed=seed) | op = UniformRNG(seed=seed) | ||||
| size = Tensor(size, dtype="int32") | |||||
| _ref = Tensor([], dtype="int32") | |||||
| size = utils.astensor1d(size, _ref, dtype="int32") | |||||
| (output,) = apply(op, size) | (output,) = apply(op, size) | ||||
| return low + (high - low) * output | return low + (high - low) * output | ||||
| @@ -23,6 +23,7 @@ from megengine.core.tensor.core import apply | |||||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | from megengine.core.tensor.raw_tensor import as_raw_tensor | ||||
| from megengine.functional import exp, log | from megengine.functional import exp, log | ||||
| from megengine.jit import exclude_from_trace, trace | from megengine.jit import exclude_from_trace, trace | ||||
| from megengine.random import normal, uniform | |||||
| def test_trace(): | def test_trace(): | ||||
| @@ -431,3 +432,23 @@ def test_slice(): | |||||
| y = f(x) | y = f(x) | ||||
| np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2]) | np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2]) | ||||
| y + y | y + y | ||||
| def test_random(): | |||||
| def run_test(op): | |||||
| for symbolic_shape in [True, False]: | |||||
| @trace(symbolic=True, symbolic_shape=symbolic_shape) | |||||
| def f(): | |||||
| out = op(size=[10, 10]) | |||||
| out_shape = out.shape | |||||
| assert out_shape is not None | |||||
| if not isinstance(out_shape, tuple): | |||||
| assert out.shape.numpy() is not None | |||||
| return out | |||||
| for _ in range(3): | |||||
| f() | |||||
| run_test(uniform) | |||||
| run_test(normal) | |||||