|
|
|
@@ -22,6 +22,7 @@ from .multitype_ops import _constexpr_utils as const_utils |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ...common.tensor import Tensor |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ..._checkparam import check_int_positive |
|
|
|
from ..._checkparam import Rel |
|
|
|
|
|
|
|
# set graph-level RNG seed |
|
|
|
@@ -29,11 +30,36 @@ _GRAPH_SEED = 0 |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def set_seed(seed): |
|
|
|
""" |
|
|
|
Set the graph-level seed. |
|
|
|
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. |
|
|
|
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a |
|
|
|
random seed. |
|
|
|
|
|
|
|
Args: |
|
|
|
seed(Int): the graph-level seed value that to be set. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> C.set_seed(10) |
|
|
|
""" |
|
|
|
check_int_positive(seed) |
|
|
|
global _GRAPH_SEED |
|
|
|
_GRAPH_SEED = seed |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def get_seed(): |
|
|
|
""" |
|
|
|
Get the graph-level seed. |
|
|
|
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. |
|
|
|
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a |
|
|
|
random seed. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Interger. The current graph-level seed. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> C.get_seed(10) |
|
|
|
""" |
|
|
|
return _GRAPH_SEED |
|
|
|
|
|
|
|
|
|
|
|
@@ -58,7 +84,6 @@ def normal(shape, mean, stddev, seed=0): |
|
|
|
>>> shape = (4, 16) |
|
|
|
>>> mean = Tensor(1.0, mstype.float32) |
|
|
|
>>> stddev = Tensor(1.0, mstype.float32) |
|
|
|
>>> C.set_seed(10) |
|
|
|
>>> output = C.normal(shape, mean, stddev, seed=5) |
|
|
|
""" |
|
|
|
mean_dtype = F.dtype(mean) |
|
|
|
|