
!7204 refactor get_seed() interface

Merge pull request !7204 from yihuaijie/master
tags/v1.1.0
mindspore-ci-bot 5 years ago
commit 80472a44f9
7 changed files with 59 additions and 58 deletions
  1. +2 -2 mindspore/common/__init__.py
  2. +2 -2 mindspore/common/initializer.py
  3. +44 -6 mindspore/common/seed.py
  4. +4 -4 mindspore/nn/layer/basic.py
  5. +2 -5 mindspore/nn/probability/distribution/distribution.py
  6. +4 -37 mindspore/ops/composite/random_ops.py
  7. +1 -2 mindspore/parallel/_utils.py

+ 2 - 2 mindspore/common/__init__.py

@@ -18,7 +18,7 @@ from .api import ms_function
from .dtype import *
from .parameter import Parameter, ParameterTuple
from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor
from .seed import set_seed, get_seed, _truncate_seed, _update_seeds, _get_op_seed
from .seed import set_seed, _get_seed, get_global_seed


__all__ = dtype.__all__
@@ -27,5 +27,5 @@ __all__.extend([
'ms_function', # api
'Parameter', 'ParameterTuple', # parameter
"dtype",
"set_seed", "get_seed", '_truncate_seed', '_update_seeds', '_get_op_seed' # random seed
"set_seed", "_get_seed", "get_global_seed" # random seed
])

+ 2 - 2 mindspore/common/initializer.py

@@ -24,7 +24,7 @@ from mindspore import log as logger

from . import dtype as mstype
from .tensor import Tensor
from .seed import get_seed
from .seed import get_global_seed
from .._c_expression import random_normal

_INITIALIZER_ALIAS = dict()
@@ -89,7 +89,7 @@ class Initializer:
logger.error(msg)
raise ValueError(msg)

global_seed = get_seed()
global_seed = get_global_seed()
need_set_seed = ((slice_index is not None) and (global_seed is None))
seed_saved = np.random.get_state()[1][0]
if need_set_seed:
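
A minimal sketch of the fallback this hunk encodes, assuming the surrounding Initializer logic (the helper name _init_slice and the seed-restore step are my additions, inferred from seed_saved; they are not shown in the hunk): NumPy is re-seeded per slice only when the tensor is sliced and no global seed was set by the user.

import numpy as np
from mindspore.common.seed import get_global_seed

def _init_slice(init_fn, shape, slice_index=None):
    # Hypothetical helper, for illustration only.
    global_seed = get_global_seed()
    need_set_seed = (slice_index is not None) and (global_seed is None)
    seed_saved = np.random.get_state()[1][0]   # snapshot of the NumPy RNG state, as in the hunk
    if need_set_seed:
        np.random.seed(slice_index)            # make each slice independently reproducible
    data = init_fn(shape)
    if need_set_seed:
        np.random.seed(seed_saved)             # restore the previous seed afterwards
    return data

# e.g. _init_slice(lambda s: np.random.normal(size=s), (2, 2), slice_index=0)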


+ 44 - 6 mindspore/common/seed.py

@@ -15,6 +15,7 @@
"""Provide random seed api."""
import numpy as np
import mindspore.dataset as de
from mindspore._checkparam import Validator

# constants
_MAXINT32 = 2**31 - 1
@@ -48,8 +49,7 @@ def set_seed(seed):
"""
if not isinstance(seed, int):
raise TypeError("The seed must be type of int.")
if seed < 0:
raise ValueError("The seed must be greater or equal to 0.")
Validator.check_non_negative_int(seed, "seed", "global_seed")
np.random.seed(seed)
de.config.set_seed(seed)
_reset_op_seed()
@@ -57,7 +57,7 @@ def set_seed(seed):
_GLOBAL_SEED = seed


def get_seed():
def get_global_seed():
"""
Get global random seed.
"""
@@ -82,10 +82,7 @@ def _update_seeds(op_seed, kernel_name):
seed (int): The op-seed to be updated.
kernel_name (string): The random op kernel.
"""
global _GLOBAL_SEED
global _KERNEL_SEED
if _GLOBAL_SEED is not None:
_GLOBAL_SEED += keyConstant[1] + keyConstant[3] * (2**8)
if op_seed is not None:
_KERNEL_SEED[(kernel_name, op_seed)] = _KERNEL_SEED[(kernel_name, op_seed)] + (keyConstant[0] ^ keyConstant[2])

@@ -103,6 +100,47 @@ def _get_op_seed(op_seed, kernel_name):
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed
return _KERNEL_SEED[(kernel_name, op_seed)]


def _get_seed(op_seed, kernel_name):
"""
Get the graph-level seed.
The graph-level seed acts as a global fallback that different ops can use when no op-level seed is set.
If the op-level seed is 0, the graph-level seed is used; if the graph-level seed is also 0, the system
generates a random seed.

Note:
For each seed, whether op-seed or graph-seed, a random sequence is generated from it,
so the state of the seed with respect to this op has to be recorded.
A simple illustration:
If a random op is called twice within one program, the two results should be different:
print(C.uniform((1, 4), seed=1)) # generates 'A1'
print(C.uniform((1, 4), seed=1)) # generates 'A2'
If the same program runs again, it should repeat the results:
print(C.uniform((1, 4), seed=1)) # generates 'A1'
print(C.uniform((1, 4), seed=1)) # generates 'A2'

Returns:
Tuple of two ints, the resolved graph-level seed and op-level seed.

Examples:
>>> _get_seed(seed, 'normal')
"""
global_seed = get_global_seed()
if global_seed is None:
global_seed = 0
if op_seed is None:
op_seed = 0
# If neither the global seed nor the op seed is set, return (0, 0) to let the kernel choose a random seed.
if global_seed == 0 and op_seed == 0:
seeds = 0, 0
else:
Validator.check_non_negative_int(op_seed, "seed", kernel_name)
temp_seed = _get_op_seed(op_seed, kernel_name)
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
_update_seeds(op_seed, kernel_name)
return seeds


def _reset_op_seed():
"""
Reset op seeds in the kernel's dictionary.
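
Taken together, a hedged usage sketch of the refactored surface (only the get_global_seed() value is an asserted output; the resolved pair depends on the kernel bookkeeping above):

from mindspore.common import set_seed, get_global_seed, _get_seed

set_seed(3)                                      # fix the graph-level (global) seed
print(get_global_seed())                         # 3
graph_seed, op_seed = _get_seed(1, "uniform")    # resolved (graph-seed, op-seed) pair for this kernel
# With the global seed unset and op_seed 0 or None, _get_seed returns (0, 0)
# and the kernel chooses its own random seed.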


+ 4 - 4 mindspore/nn/layer/basic.py

@@ -17,7 +17,7 @@

import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.seed import get_seed
from mindspore.common.seed import _get_seed
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
@@ -89,9 +89,9 @@ class Dropout(Cell):
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
self.keep_prob = keep_prob
seed0 = get_seed()
self.seed0 = seed0 if seed0 is not None else 0
self.seed1 = 0
seed0, seed1 = _get_seed(0, "dropout")
self.seed0 = seed0
self.seed1 = seed1
self.dtype = dtype
self.get_shape = P.Shape()
self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
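
With this change Dropout takes both kernel seeds from the shared bookkeeping instead of reading only the global seed. A minimal sketch of the new path outside the Cell machinery (illustrative, not the full layer):

from mindspore.common.seed import _get_seed
from mindspore.ops import operations as P

seed0, seed1 = _get_seed(0, "dropout")           # (0, 0) when neither a global nor an op seed is set
dropout_gen_mask = P.DropoutGenMask(Seed0=seed0, Seed1=seed1)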


+ 2 - 5 mindspore/nn/probability/distribution/distribution.py

@@ -17,7 +17,6 @@ from mindspore import context
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore.common import get_seed
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
raise_not_implemented_util
from ._utils.utils import CheckTuple, CheckTensor
@@ -29,7 +28,7 @@ class Distribution(Cell):
Base class for all mathematical distributions.

Args:
seed (int): The seed is used in sampling. The global seed is used if it is None.
seed (int): The seed is used in sampling. 0 is used if it is None.
dtype (mindspore.dtype): The type of the event samples.
name (str): The name of the distribution.
param (dict): The parameters used to initialize the distribution.
@@ -59,9 +58,7 @@ class Distribution(Cell):
"""
super(Distribution, self).__init__()
if seed is None:
seed = get_seed()
if seed is None:
seed = 0
seed = 0
validator.check_value_type('name', name, [str], type(self).__name__)
validator.check_non_negative_int(seed, 'seed', name)
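
In other words, a distribution built without a seed now records 0 directly, and merging in the global seed is left to the random ops downstream. A minimal sketch of the new constructor behaviour (the helper name _resolve_seed is hypothetical):

from mindspore._checkparam import Validator as validator

def _resolve_seed(seed, name):
    # Mirrors the hunk above: None collapses straight to 0, no get_seed() fallback.
    if seed is None:
        seed = 0
    validator.check_non_negative_int(seed, 'seed', name)
    return seed

print(_resolve_seed(None, "Normal"))             # 0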



+ 4 - 37 mindspore/ops/composite/random_ops.py

@@ -14,51 +14,18 @@
# ============================================================================
"""Operations for random number generators."""

from mindspore._checkparam import Validator
from mindspore.ops.primitive import constexpr
from .. import operations as P
from .. import functional as F
from ..primitive import constexpr
from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
from ...common import get_seed as get_global_seed
from ...common import _truncate_seed, _update_seeds, _get_op_seed
from ...common import _get_seed

@constexpr
def get_seed(op_seed, kernel_name):
"""
Get the graph-level seed.
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
If op-level seed is 0, use graph-level seed; if graph-level seed is also 0, the system would generate a
random seed.

Note:
For each seed, either op-seed or graph-seed, a random sequence will be generated relating to this seed.
So, the state of the seed regarding to this op should be recorded.
A simple illustration should be:
If a random op is called twice within one program, the two results should be different:
print(C.uniform((1, 4), seed=1)) # generates 'A1'
print(C.uniform((1, 4), seed=1)) # generates 'A2'
If the same program runs again, it repeat the results:
print(C.uniform((1, 4), seed=1)) # generates 'A1'
print(C.uniform((1, 4), seed=1)) # generates 'A2'

Returns:
Interger. The current graph-level seed.
"Get the graph-level seed."
return _get_seed(op_seed, kernel_name)

Examples:
>>> C.get_seed(seed, 'normal')
"""
global_seed = get_global_seed()
if global_seed is None:
global_seed = 0
if op_seed is None:
temp_seed = _get_op_seed(0, kernel_name)
else:
Validator.check_non_negative_int(op_seed, "seed", kernel_name)
temp_seed = _get_op_seed(op_seed, kernel_name)
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
_update_seeds(op_seed, kernel_name)
return seeds

def normal(shape, mean, stddev, seed=None):
"""


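Callers of the composite random ops are unaffected: get_seed is now a thin @constexpr wrapper over common._get_seed. A hedged usage sketch following the normal() signature shown above (mean and stddev as tensors; output not reproduced here):

import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import composite as C

mean = Tensor(0.0, mstype.float32)
stddev = Tensor(1.0, mstype.float32)
out = C.normal((2, 3), mean, stddev, seed=5)     # seeds resolved internally via get_seed -> _get_seed
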
+ 1 - 2 mindspore/parallel/_utils.py

@@ -22,7 +22,6 @@ from mindspore.common.dtype import dtype_to_nptype
from mindspore.common import dtype as mstype
from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.common.seed import get_seed


def _get_parallel_mode():
@@ -140,7 +139,7 @@ def _get_parameter_broadcast():
parallel_mode = auto_parallel_context().get_parallel_mode()
parameter_broadcast = auto_parallel_context().get_parameter_broadcast()

if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed is None:
if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False:
logger.warning("It is recommended to use mindspore.common.set_seed() to share"
" parameters among devices.")

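Worth noting (my reading of the hunk, not stated in the commit message): the dropped clause compared the imported function object get_seed, rather than its return value, to None, so it was always False and the warning could never be logged; removing it together with the now-unused import makes the suggestion reachable.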

