Browse Source

Added notation for graph-level seed access interfaces

tags/v0.7.0-beta
peixu_ren 5 years ago
parent
commit
eb0cd8b4d3
1 changed files with 26 additions and 1 deletions
  1. +26
    -1
      mindspore/ops/composite/random_ops.py

+ 26
- 1
mindspore/ops/composite/random_ops.py View File

@@ -22,6 +22,7 @@ from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..._checkparam import Validator as validator
from ..._checkparam import check_int_positive
from ..._checkparam import Rel

# set graph-level RNG seed
@@ -29,11 +30,36 @@ _GRAPH_SEED = 0

@constexpr
def set_seed(seed):
"""
Set the graph-level seed.
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
random seed.

Args:
seed(Int): the graph-level seed value that to be set.

Examples:
>>> C.set_seed(10)
"""
check_int_positive(seed)
global _GRAPH_SEED
_GRAPH_SEED = seed

@constexpr
def get_seed():
"""
Get the graph-level seed.
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
random seed.

Returns:
Interger. The current graph-level seed.

Examples:
>>> C.get_seed(10)
"""
return _GRAPH_SEED


@@ -58,7 +84,6 @@ def normal(shape, mean, stddev, seed=0):
>>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32)
>>> C.set_seed(10)
>>> output = C.normal(shape, mean, stddev, seed=5)
"""
mean_dtype = F.dtype(mean)


Loading…
Cancel
Save