Merge pull request !7204 from yihuaijie/mastertags/v1.1.0
| @@ -18,7 +18,7 @@ from .api import ms_function | |||
| from .dtype import * | |||
| from .parameter import Parameter, ParameterTuple | |||
| from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor | |||
| from .seed import set_seed, get_seed, _truncate_seed, _update_seeds, _get_op_seed | |||
| from .seed import set_seed, _get_seed, get_global_seed | |||
| __all__ = dtype.__all__ | |||
| @@ -27,5 +27,5 @@ __all__.extend([ | |||
| 'ms_function', # api | |||
| 'Parameter', 'ParameterTuple', # parameter | |||
| "dtype", | |||
| "set_seed", "get_seed", '_truncate_seed', '_update_seeds', '_get_op_seed' # random seed | |||
| "set_seed", "_get_seed", "get_global_seed" # random seed | |||
| ]) | |||
| @@ -24,7 +24,7 @@ from mindspore import log as logger | |||
| from . import dtype as mstype | |||
| from .tensor import Tensor | |||
| from .seed import get_seed | |||
| from .seed import get_global_seed | |||
| from .._c_expression import random_normal | |||
| _INITIALIZER_ALIAS = dict() | |||
| @@ -89,7 +89,7 @@ class Initializer: | |||
| logger.error(msg) | |||
| raise ValueError(msg) | |||
| global_seed = get_seed() | |||
| global_seed = get_global_seed() | |||
| need_set_seed = ((slice_index is not None) and (global_seed is None)) | |||
| seed_saved = np.random.get_state()[1][0] | |||
| if need_set_seed: | |||
| @@ -15,6 +15,7 @@ | |||
| """Provide random seed api.""" | |||
| import numpy as np | |||
| import mindspore.dataset as de | |||
| from mindspore._checkparam import Validator | |||
| # constants | |||
| _MAXINT32 = 2**31 - 1 | |||
| @@ -48,8 +49,7 @@ def set_seed(seed): | |||
| """ | |||
| if not isinstance(seed, int): | |||
| raise TypeError("The seed must be type of int.") | |||
| if seed < 0: | |||
| raise ValueError("The seed must be greater or equal to 0.") | |||
| Validator.check_non_negative_int(seed, "seed", "global_seed") | |||
| np.random.seed(seed) | |||
| de.config.set_seed(seed) | |||
| _reset_op_seed() | |||
| @@ -57,7 +57,7 @@ def set_seed(seed): | |||
| _GLOBAL_SEED = seed | |||
| def get_seed(): | |||
| def get_global_seed(): | |||
| """ | |||
| Get global random seed. | |||
| """ | |||
| @@ -82,10 +82,7 @@ def _update_seeds(op_seed, kernel_name): | |||
| seed (int): The op-seed to be updated. | |||
| kernel_name (string): The random op kernel. | |||
| """ | |||
| global _GLOBAL_SEED | |||
| global _KERNEL_SEED | |||
| if _GLOBAL_SEED is not None: | |||
| _GLOBAL_SEED += keyConstant[1] + keyConstant[3] * (2**8) | |||
| if op_seed is not None: | |||
| _KERNEL_SEED[(kernel_name, op_seed)] = _KERNEL_SEED[(kernel_name, op_seed)] + (keyConstant[0] ^ keyConstant[2]) | |||
| @@ -103,6 +100,47 @@ def _get_op_seed(op_seed, kernel_name): | |||
| _KERNEL_SEED[(kernel_name, op_seed)] = op_seed | |||
| return _KERNEL_SEED[(kernel_name, op_seed)] | |||
| def _get_seed(op_seed, kernel_name): | |||
| """ | |||
| Get the graph-level seed. | |||
| Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. | |||
| If op-level seed is 0, use graph-level seed; if graph-level seed is also 0, the system would generate a | |||
| random seed. | |||
| Note: | |||
| For each seed, either op-seed or graph-seed, a random sequence will be generated relating to this seed. | |||
| So, the state of the seed regarding to this op should be recorded. | |||
| A simple illustration should be: | |||
| If a random op is called twice within one program, the two results should be different: | |||
| print(C.uniform((1, 4), seed=1)) # generates 'A1' | |||
| print(C.uniform((1, 4), seed=1)) # generates 'A2' | |||
| If the same program runs again, it repeat the results: | |||
| print(C.uniform((1, 4), seed=1)) # generates 'A1' | |||
| print(C.uniform((1, 4), seed=1)) # generates 'A2' | |||
| Returns: | |||
| Interger. The current graph-level seed. | |||
| Examples: | |||
| >>> _get_seed(seed, 'normal') | |||
| """ | |||
| global_seed = get_global_seed() | |||
| if global_seed is None: | |||
| global_seed = 0 | |||
| if op_seed is None: | |||
| op_seed = 0 | |||
| # eigther global seed or op seed is set, return (0, 0) to let kernel choose random seed. | |||
| if global_seed == 0 and op_seed == 0: | |||
| seeds = 0, 0 | |||
| else: | |||
| Validator.check_non_negative_int(op_seed, "seed", kernel_name) | |||
| temp_seed = _get_op_seed(op_seed, kernel_name) | |||
| seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) | |||
| _update_seeds(op_seed, kernel_name) | |||
| return seeds | |||
| def _reset_op_seed(): | |||
| """ | |||
| Reset op seeds in the kernel's dictionary. | |||
| @@ -17,7 +17,7 @@ | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.seed import get_seed | |||
| from mindspore.common.seed import _get_seed | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.ops import operations as P | |||
| @@ -89,9 +89,9 @@ class Dropout(Cell): | |||
| Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) | |||
| Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) | |||
| self.keep_prob = keep_prob | |||
| seed0 = get_seed() | |||
| self.seed0 = seed0 if seed0 is not None else 0 | |||
| self.seed1 = 0 | |||
| seed0, seed1 = _get_seed(0, "dropout") | |||
| self.seed0 = seed0 | |||
| self.seed1 = seed1 | |||
| self.dtype = dtype | |||
| self.get_shape = P.Shape() | |||
| self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1) | |||
| @@ -17,7 +17,6 @@ from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore.common import get_seed | |||
| from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\ | |||
| raise_not_implemented_util | |||
| from ._utils.utils import CheckTuple, CheckTensor | |||
| @@ -29,7 +28,7 @@ class Distribution(Cell): | |||
| Base class for all mathematical distributions. | |||
| Args: | |||
| seed (int): The seed is used in sampling. The global seed is used if it is None. | |||
| seed (int): The seed is used in sampling. 0 is used if it is None. | |||
| dtype (mindspore.dtype): The type of the event samples. | |||
| name (str): The name of the distribution. | |||
| param (dict): The parameters used to initialize the distribution. | |||
| @@ -59,9 +58,7 @@ class Distribution(Cell): | |||
| """ | |||
| super(Distribution, self).__init__() | |||
| if seed is None: | |||
| seed = get_seed() | |||
| if seed is None: | |||
| seed = 0 | |||
| seed = 0 | |||
| validator.check_value_type('name', name, [str], type(self).__name__) | |||
| validator.check_non_negative_int(seed, 'seed', name) | |||
| @@ -14,51 +14,18 @@ | |||
| # ============================================================================ | |||
| """Operations for random number generators.""" | |||
| from mindspore._checkparam import Validator | |||
| from mindspore.ops.primitive import constexpr | |||
| from .. import operations as P | |||
| from .. import functional as F | |||
| from ..primitive import constexpr | |||
| from .multitype_ops import _constexpr_utils as const_utils | |||
| from ...common import dtype as mstype | |||
| from ...common import get_seed as get_global_seed | |||
| from ...common import _truncate_seed, _update_seeds, _get_op_seed | |||
| from ...common import _get_seed | |||
| @constexpr | |||
| def get_seed(op_seed, kernel_name): | |||
| """ | |||
| Get the graph-level seed. | |||
| Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. | |||
| If op-level seed is 0, use graph-level seed; if graph-level seed is also 0, the system would generate a | |||
| random seed. | |||
| Note: | |||
| For each seed, either op-seed or graph-seed, a random sequence will be generated relating to this seed. | |||
| So, the state of the seed regarding to this op should be recorded. | |||
| A simple illustration should be: | |||
| If a random op is called twice within one program, the two results should be different: | |||
| print(C.uniform((1, 4), seed=1)) # generates 'A1' | |||
| print(C.uniform((1, 4), seed=1)) # generates 'A2' | |||
| If the same program runs again, it repeat the results: | |||
| print(C.uniform((1, 4), seed=1)) # generates 'A1' | |||
| print(C.uniform((1, 4), seed=1)) # generates 'A2' | |||
| Returns: | |||
| Interger. The current graph-level seed. | |||
| "Get the graph-level seed." | |||
| return _get_seed(op_seed, kernel_name) | |||
| Examples: | |||
| >>> C.get_seed(seed, 'normal') | |||
| """ | |||
| global_seed = get_global_seed() | |||
| if global_seed is None: | |||
| global_seed = 0 | |||
| if op_seed is None: | |||
| temp_seed = _get_op_seed(0, kernel_name) | |||
| else: | |||
| Validator.check_non_negative_int(op_seed, "seed", kernel_name) | |||
| temp_seed = _get_op_seed(op_seed, kernel_name) | |||
| seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) | |||
| _update_seeds(op_seed, kernel_name) | |||
| return seeds | |||
| def normal(shape, mean, stddev, seed=None): | |||
| """ | |||
| @@ -22,7 +22,6 @@ from mindspore.common.dtype import dtype_to_nptype | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.communication.management import get_group_size, get_rank | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from mindspore.common.seed import get_seed | |||
| def _get_parallel_mode(): | |||
| @@ -140,7 +139,7 @@ def _get_parameter_broadcast(): | |||
| parallel_mode = auto_parallel_context().get_parallel_mode() | |||
| parameter_broadcast = auto_parallel_context().get_parameter_broadcast() | |||
| if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed is None: | |||
| if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False: | |||
| logger.warning("You are suggested to use mindspore.common.set_seed() to share" | |||
| " parameters among devices.") | |||