| @@ -22,6 +22,7 @@ from .multitype_ops import _constexpr_utils as const_utils | |||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ..._checkparam import check_int_positive | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| # set graph-level RNG seed | # set graph-level RNG seed | ||||
| @@ -29,11 +30,36 @@ _GRAPH_SEED = 0 | |||||
| @constexpr | @constexpr | ||||
| def set_seed(seed): | def set_seed(seed): | ||||
| """ | |||||
| Set the graph-level seed. | |||||
| Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. | |||||
| If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a | |||||
| random seed. | |||||
| Args: | |||||
| seed(Int): the graph-level seed value that to be set. | |||||
| Examples: | |||||
| >>> C.set_seed(10) | |||||
| """ | |||||
| check_int_positive(seed) | |||||
| global _GRAPH_SEED | global _GRAPH_SEED | ||||
| _GRAPH_SEED = seed | _GRAPH_SEED = seed | ||||
| @constexpr | @constexpr | ||||
| def get_seed(): | def get_seed(): | ||||
| """ | |||||
| Get the graph-level seed. | |||||
| Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. | |||||
| If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a | |||||
| random seed. | |||||
| Returns: | |||||
| Interger. The current graph-level seed. | |||||
| Examples: | |||||
| >>> C.get_seed(10) | |||||
| """ | |||||
| return _GRAPH_SEED | return _GRAPH_SEED | ||||
| @@ -58,7 +84,6 @@ def normal(shape, mean, stddev, seed=0): | |||||
| >>> shape = (4, 16) | >>> shape = (4, 16) | ||||
| >>> mean = Tensor(1.0, mstype.float32) | >>> mean = Tensor(1.0, mstype.float32) | ||||
| >>> stddev = Tensor(1.0, mstype.float32) | >>> stddev = Tensor(1.0, mstype.float32) | ||||
| >>> C.set_seed(10) | |||||
| >>> output = C.normal(shape, mean, stddev, seed=5) | >>> output = C.normal(shape, mean, stddev, seed=5) | ||||
| """ | """ | ||||
| mean_dtype = F.dtype(mean) | mean_dtype = F.dtype(mean) | ||||