|
|
|
@@ -15,8 +15,11 @@ |
|
|
|
|
|
|
|
"""Operations for random number generatos.""" |
|
|
|
|
|
|
|
from mindspore.ops.primitive import constexpr |
|
|
|
from .. import operations as P |
|
|
|
from .. import functional as F |
|
|
|
from ..primitive import constexpr |
|
|
|
from .multitype_ops import _constexpr_utils as const_utils |
|
|
|
from ...common import dtype as mstype |
|
|
|
|
|
|
|
# set graph-level RNG seed |
|
|
|
_GRAPH_SEED = 0 |
|
|
|
@@ -31,17 +34,17 @@ def get_seed(): |
|
|
|
return _GRAPH_SEED |
|
|
|
|
|
|
|
|
|
|
|
def normal(shape, mean, stddev, seed): |
|
|
|
def normal(shape, mean, stddev, seed=0): |
|
|
|
""" |
|
|
|
Generates random numbers according to the Normal (or Gaussian) random number distribution. |
|
|
|
It is defined as: |
|
|
|
|
|
|
|
Args: |
|
|
|
- **shape** (tuple) - The shape of random tensor to be generated. |
|
|
|
- **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak. |
|
|
|
shape (tuple): The shape of random tensor to be generated. |
|
|
|
mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak. |
|
|
|
With float32 data type. |
|
|
|
- **stddev** (Tensor) - The deviation σ distribution parameter. With float32 data type. |
|
|
|
- **seed** (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. |
|
|
|
stddev (Tensor): The deviation σ distribution parameter. With float32 data type. |
|
|
|
seed (int): Seed is used as entropy source for Random number engines generating pseudo-random numbers. |
|
|
|
Default: 0. |
|
|
|
|
|
|
|
Returns: |
|
|
|
@@ -52,9 +55,13 @@ def normal(shape, mean, stddev, seed): |
|
|
|
>>> shape = (4, 16) |
|
|
|
>>> mean = Tensor(1.0, mstype.float32) |
|
|
|
>>> stddev = Tensor(1.0, mstype.float32) |
|
|
|
>>> C.set_seed(10) |
|
|
|
>>> output = C.normal(shape, mean, stddev, seed=5) |
|
|
|
""" |
|
|
|
set_seed(10) |
|
|
|
mean_dtype = F.dtype(mean) |
|
|
|
stddev_dtype = F.dtype(stddev) |
|
|
|
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal") |
|
|
|
const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal") |
|
|
|
seed1 = get_seed() |
|
|
|
seed2 = seed |
|
|
|
stdnormal = P.StandardNormal(seed1, seed2) |
|
|
|
|