diff --git a/mindspore/common/__init__.py b/mindspore/common/__init__.py index 474345ee86..9f5355d901 100644 --- a/mindspore/common/__init__.py +++ b/mindspore/common/__init__.py @@ -18,7 +18,7 @@ from .api import ms_function from .dtype import * from .parameter import Parameter, ParameterTuple from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor -from .seed import set_seed, get_seed, _truncate_seed, _update_seeds, _get_op_seed +from .seed import set_seed, _get_seed, get_global_seed __all__ = dtype.__all__ @@ -27,5 +27,5 @@ __all__.extend([ 'ms_function', # api 'Parameter', 'ParameterTuple', # parameter "dtype", - "set_seed", "get_seed", '_truncate_seed', '_update_seeds', '_get_op_seed' # random seed + "set_seed", "_get_seed", "get_global_seed" # random seed ]) diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py index 24faa74ac0..acc6f46d33 100644 --- a/mindspore/common/initializer.py +++ b/mindspore/common/initializer.py @@ -24,7 +24,7 @@ from mindspore import log as logger from . import dtype as mstype from .tensor import Tensor -from .seed import get_seed +from .seed import get_global_seed from .._c_expression import random_normal _INITIALIZER_ALIAS = dict() @@ -89,7 +89,7 @@ class Initializer: logger.error(msg) raise ValueError(msg) - global_seed = get_seed() + global_seed = get_global_seed() need_set_seed = ((slice_index is not None) and (global_seed is None)) seed_saved = np.random.get_state()[1][0] if need_set_seed: diff --git a/mindspore/common/seed.py b/mindspore/common/seed.py index 484e81015f..ff61508743 100644 --- a/mindspore/common/seed.py +++ b/mindspore/common/seed.py @@ -15,6 +15,7 @@ """Provide random seed api.""" import numpy as np import mindspore.dataset as de +from mindspore._checkparam import Validator # constants _MAXINT32 = 2**31 - 1 @@ -48,8 +49,7 @@ def set_seed(seed): """ if not isinstance(seed, int): raise TypeError("The seed must be type of int.") - if seed < 0: - raise ValueError("The seed must be greater or equal to 0.") + Validator.check_non_negative_int(seed, "seed", "global_seed") np.random.seed(seed) de.config.set_seed(seed) _reset_op_seed() @@ -57,7 +57,7 @@ def set_seed(seed): _GLOBAL_SEED = seed -def get_seed(): +def get_global_seed(): """ Get global random seed. """ @@ -82,10 +82,7 @@ def _update_seeds(op_seed, kernel_name): seed (int): The op-seed to be updated. kernel_name (string): The random op kernel. """ - global _GLOBAL_SEED global _KERNEL_SEED - if _GLOBAL_SEED is not None: - _GLOBAL_SEED += keyConstant[1] + keyConstant[3] * (2**8) if op_seed is not None: _KERNEL_SEED[(kernel_name, op_seed)] = _KERNEL_SEED[(kernel_name, op_seed)] + (keyConstant[0] ^ keyConstant[2]) @@ -103,6 +100,47 @@ def _get_op_seed(op_seed, kernel_name): _KERNEL_SEED[(kernel_name, op_seed)] = op_seed return _KERNEL_SEED[(kernel_name, op_seed)] + +def _get_seed(op_seed, kernel_name): + """ + Get the graph-level seed. + Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. + If op-level seed is 0, use graph-level seed; if graph-level seed is also 0, the system would generate a + random seed. + + Note: + For each seed, either op-seed or graph-seed, a random sequence will be generated relating to this seed. + So, the state of the seed regarding to this op should be recorded. + A simple illustration should be: + If a random op is called twice within one program, the two results should be different: + print(C.uniform((1, 4), seed=1)) # generates 'A1' + print(C.uniform((1, 4), seed=1)) # generates 'A2' + If the same program runs again, it repeat the results: + print(C.uniform((1, 4), seed=1)) # generates 'A1' + print(C.uniform((1, 4), seed=1)) # generates 'A2' + + Returns: + Interger. The current graph-level seed. + + Examples: + >>> _get_seed(seed, 'normal') + """ + global_seed = get_global_seed() + if global_seed is None: + global_seed = 0 + if op_seed is None: + op_seed = 0 + # eigther global seed or op seed is set, return (0, 0) to let kernel choose random seed. + if global_seed == 0 and op_seed == 0: + seeds = 0, 0 + else: + Validator.check_non_negative_int(op_seed, "seed", kernel_name) + temp_seed = _get_op_seed(op_seed, kernel_name) + seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) + _update_seeds(op_seed, kernel_name) + return seeds + + def _reset_op_seed(): """ Reset op seeds in the kernel's dictionary. diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 8cee69aab7..b7dd52718b 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -17,7 +17,7 @@ import numpy as np import mindspore.common.dtype as mstype -from mindspore.common.seed import get_seed +from mindspore.common.seed import _get_seed from mindspore.common.tensor import Tensor from mindspore.common.initializer import initializer from mindspore.ops import operations as P @@ -89,9 +89,9 @@ class Dropout(Cell): Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) self.keep_prob = keep_prob - seed0 = get_seed() - self.seed0 = seed0 if seed0 is not None else 0 - self.seed1 = 0 + seed0, seed1 = _get_seed(0, "dropout") + self.seed0 = seed0 + self.seed1 = seed1 self.dtype = dtype self.get_shape = P.Shape() self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1) diff --git a/mindspore/nn/probability/distribution/distribution.py b/mindspore/nn/probability/distribution/distribution.py index 2386f5e5b3..c121100ac2 100644 --- a/mindspore/nn/probability/distribution/distribution.py +++ b/mindspore/nn/probability/distribution/distribution.py @@ -17,7 +17,6 @@ from mindspore import context from mindspore.ops import operations as P from mindspore.nn.cell import Cell from mindspore._checkparam import Validator as validator -from mindspore.common import get_seed from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\ raise_not_implemented_util from ._utils.utils import CheckTuple, CheckTensor @@ -29,7 +28,7 @@ class Distribution(Cell): Base class for all mathematical distributions. Args: - seed (int): The seed is used in sampling. The global seed is used if it is None. + seed (int): The seed is used in sampling. 0 is used if it is None. dtype (mindspore.dtype): The type of the event samples. name (str): The name of the distribution. param (dict): The parameters used to initialize the distribution. @@ -59,9 +58,7 @@ class Distribution(Cell): """ super(Distribution, self).__init__() if seed is None: - seed = get_seed() - if seed is None: - seed = 0 + seed = 0 validator.check_value_type('name', name, [str], type(self).__name__) validator.check_non_negative_int(seed, 'seed', name) diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index a165065d5b..904852b4b9 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -14,51 +14,18 @@ # ============================================================================ """Operations for random number generators.""" -from mindspore._checkparam import Validator +from mindspore.ops.primitive import constexpr from .. import operations as P from .. import functional as F -from ..primitive import constexpr from .multitype_ops import _constexpr_utils as const_utils from ...common import dtype as mstype -from ...common import get_seed as get_global_seed -from ...common import _truncate_seed, _update_seeds, _get_op_seed +from ...common import _get_seed @constexpr def get_seed(op_seed, kernel_name): - """ - Get the graph-level seed. - Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. - If op-level seed is 0, use graph-level seed; if graph-level seed is also 0, the system would generate a - random seed. - - Note: - For each seed, either op-seed or graph-seed, a random sequence will be generated relating to this seed. - So, the state of the seed regarding to this op should be recorded. - A simple illustration should be: - If a random op is called twice within one program, the two results should be different: - print(C.uniform((1, 4), seed=1)) # generates 'A1' - print(C.uniform((1, 4), seed=1)) # generates 'A2' - If the same program runs again, it repeat the results: - print(C.uniform((1, 4), seed=1)) # generates 'A1' - print(C.uniform((1, 4), seed=1)) # generates 'A2' - - Returns: - Interger. The current graph-level seed. + "Get the graph-level seed." + return _get_seed(op_seed, kernel_name) - Examples: - >>> C.get_seed(seed, 'normal') - """ - global_seed = get_global_seed() - if global_seed is None: - global_seed = 0 - if op_seed is None: - temp_seed = _get_op_seed(0, kernel_name) - else: - Validator.check_non_negative_int(op_seed, "seed", kernel_name) - temp_seed = _get_op_seed(op_seed, kernel_name) - seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) - _update_seeds(op_seed, kernel_name) - return seeds def normal(shape, mean, stddev, seed=None): """ diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index bb0b603187..cd7b5d4597 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -22,7 +22,6 @@ from mindspore.common.dtype import dtype_to_nptype from mindspore.common import dtype as mstype from mindspore.communication.management import get_group_size, get_rank from mindspore.parallel._auto_parallel_context import auto_parallel_context -from mindspore.common.seed import get_seed def _get_parallel_mode(): @@ -140,7 +139,7 @@ def _get_parameter_broadcast(): parallel_mode = auto_parallel_context().get_parallel_mode() parameter_broadcast = auto_parallel_context().get_parameter_broadcast() - if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed is None: + if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False: logger.warning("You are suggested to use mindspore.common.set_seed() to share" " parameters among devices.")