Browse Source

refactor seed interfaces

tags/v1.1.0
Yi Huaijie 5 years ago
parent
commit
b28a6ff88e
6 changed files with 98 additions and 21 deletions
  1. +2
    -2
      mindspore/common/__init__.py
  2. +2
    -2
      mindspore/common/initializer.py
  3. +88
    -12
      mindspore/common/seed.py
  4. +2
    -2
      mindspore/nn/layer/basic.py
  5. +2
    -2
      mindspore/ops/composite/random_ops.py
  6. +2
    -1
      mindspore/parallel/_utils.py

+ 2
- 2
mindspore/common/__init__.py View File

@@ -18,7 +18,7 @@ from .api import ms_function
from .dtype import * from .dtype import *
from .parameter import Parameter, ParameterTuple from .parameter import Parameter, ParameterTuple
from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor
from .seed import set_seed, _get_seed, get_global_seed
from .seed import set_seed, get_seed




__all__ = dtype.__all__ __all__ = dtype.__all__
@@ -27,5 +27,5 @@ __all__.extend([
'ms_function', # api 'ms_function', # api
'Parameter', 'ParameterTuple', # parameter 'Parameter', 'ParameterTuple', # parameter
"dtype", "dtype",
"set_seed", "_get_seed", "get_global_seed" # random seed
"set_seed", "get_seed" # random seed
]) ])

+ 2
- 2
mindspore/common/initializer.py View File

@@ -24,7 +24,7 @@ from mindspore import log as logger


from . import dtype as mstype from . import dtype as mstype
from .tensor import Tensor from .tensor import Tensor
from .seed import get_global_seed
from .seed import get_seed
from .._c_expression import random_normal from .._c_expression import random_normal


_INITIALIZER_ALIAS = dict() _INITIALIZER_ALIAS = dict()
@@ -89,7 +89,7 @@ class Initializer:
logger.error(msg) logger.error(msg)
raise ValueError(msg) raise ValueError(msg)


global_seed = get_global_seed()
global_seed = get_seed()
need_set_seed = ((slice_index is not None) and (global_seed is None)) need_set_seed = ((slice_index is not None) and (global_seed is None))
seed_saved = np.random.get_state()[1][0] seed_saved = np.random.get_state()[1][0]
if need_set_seed: if need_set_seed:


+ 88
- 12
mindspore/common/seed.py View File

@@ -25,6 +25,15 @@ keyConstant = [3528531795, 2654435769, 3449720151, 3144134277]
_GLOBAL_SEED = None _GLOBAL_SEED = None
_KERNEL_SEED = {} _KERNEL_SEED = {}



def _reset_op_seed():
    """
    Reset op seeds in the kernel's dictionary.

    Every key of ``_KERNEL_SEED`` is a ``(kernel_name, op_seed)`` tuple. The value
    stored under each key is set back to the ``op_seed`` component of that key.
    The dictionary is updated in place; nothing is returned.
    """
    # Iterate over a snapshot of the keys only. Unpacking ``.items()`` would bind the
    # whole key tuple to ``kernel_name`` and the stored value to ``op_seed``, and the
    # assignment below would then insert new keys while the dict is being iterated
    # (RuntimeError: dictionary changed size during iteration).
    for kernel_name, op_seed in list(_KERNEL_SEED):
        _KERNEL_SEED[(kernel_name, op_seed)] = op_seed


def set_seed(seed): def set_seed(seed):
""" """
Set global random seed. Set global random seed.
@@ -46,6 +55,81 @@ def set_seed(seed):
Raises: Raises:
ValueError: If seed is invalid (< 0). ValueError: If seed is invalid (< 0).
TypeError: If seed isn't a int. TypeError: If seed isn't a int.

Examples:
1. If global seed is not set, numpy.random and initializer will choose a random seed:
>>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
>>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A2
>>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W1
>>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W2
Rerun the program will get different results:
>>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A3
>>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A4
>>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W3
>>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W4

2. If global seed is set, numpy.random and initializer will use it:
>>> set_seed(1234)
>>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
>>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A2
>>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W1
>>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W2
Rerun the program will get the same results:
>>> set_seed(1234)
>>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
>>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A2
>>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W1
>>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W2

3. If neither global seed nor op seed is set, mindspore.ops.composite.random_ops and
mindspore.nn.probability.distribution will choose a random seed:
>>> c1 = C.uniform((1, 4)) # C1
>>> c2 = C.uniform((1, 4)) # C2
Rerun the program will get different results:
>>> c1 = C.uniform((1, 4)) # C3
>>> c2 = C.uniform((1, 4)) # C4

4. If global seed is set, but op seed is not set, mindspore.ops.composite.random_ops and
mindspore.nn.probability.distribution will calculate a seed according to global seed and
default op seed. Each call will change the default op seed, thus each call get different
results.
>>> set_seed(1234)
>>> c1 = C.uniform((1, 4)) # C1
>>> c2 = C.uniform((1, 4)) # C2
Rerun the program will get the same results:
>>> set_seed(1234)
>>> c1 = C.uniform((1, 4)) # C1
>>> c2 = C.uniform((1, 4)) # C2

5. If both global seed and op seed are set, mindspore.ops.composite.random_ops and
mindspore.nn.probability.distribution will calculate a seed according to global seed and
op seed counter. Each call will change the op seed counter, thus each call get different
results.
>>> set_seed(1234)
>>> c1 = C.uniform((1, 4), seed=2) # C1
>>> c2 = C.uniform((1, 4), seed=2) # C2
Rerun the program will get the same results:
>>> set_seed(1234)
>>> c1 = C.uniform((1, 4), seed=2) # C1
>>> c2 = C.uniform((1, 4), seed=2) # C2

6. If op seed is set but global seed is not set, 0 will be used as global seed. Then
mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution act as in
condition 5.
>>> c1 = C.uniform((1, 4), seed=2) # C1
>>> c2 = C.uniform((1, 4), seed=2) # C2
Rerun the program will get the same results:
>>> c1 = C.uniform((1, 4), seed=2) # C1
>>> c2 = C.uniform((1, 4), seed=2) # C2

7. Recall set_seed() in the program will reset numpy seed and op seed counter of
mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution.
>>> set_seed(1234)
>>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
>>> c1 = C.uniform((1, 4), seed=2) # C1
>>> set_seed(1234)
>>> np_2 = np.random.normal(0, 1, [1]).astype(np.float32) # still get A1
>>> c2 = C.uniform((1, 4), seed=2) # still get C1
""" """
if not isinstance(seed, int): if not isinstance(seed, int):
raise TypeError("The seed must be type of int.") raise TypeError("The seed must be type of int.")
@@ -57,7 +141,7 @@ def set_seed(seed):
_GLOBAL_SEED = seed _GLOBAL_SEED = seed




def get_global_seed():
def get_seed():
""" """
Get global random seed. Get global random seed.
""" """
@@ -101,7 +185,7 @@ def _get_op_seed(op_seed, kernel_name):
return _KERNEL_SEED[(kernel_name, op_seed)] return _KERNEL_SEED[(kernel_name, op_seed)]




def _get_seed(op_seed, kernel_name):
def _get_graph_seed(op_seed, kernel_name):
""" """
Get the graph-level seed. Get the graph-level seed.
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
@@ -125,12 +209,12 @@ def _get_seed(op_seed, kernel_name):
Examples: Examples:
>>> _get_graph_seed(seed, 'normal') >>> _get_graph_seed(seed, 'normal')
""" """
global_seed = get_global_seed()
global_seed = get_seed()
if global_seed is None: if global_seed is None:
global_seed = 0 global_seed = 0
if op_seed is None: if op_seed is None:
op_seed = 0 op_seed = 0
# eigther global seed or op seed is set, return (0, 0) to let kernel choose random seed.
# neither global seed nor op seed is set, return (0, 0) to let kernel choose random seed.
if global_seed == 0 and op_seed == 0: if global_seed == 0 and op_seed == 0:
seeds = 0, 0 seeds = 0, 0
else: else:
@@ -139,11 +223,3 @@ def _get_seed(op_seed, kernel_name):
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
_update_seeds(op_seed, kernel_name) _update_seeds(op_seed, kernel_name)
return seeds return seeds


def _reset_op_seed():
"""
Reset op seeds in the kernel's dictionary.
"""
for (kernel_name, op_seed) in _KERNEL_SEED:
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed

+ 2
- 2
mindspore/nn/layer/basic.py View File

@@ -17,7 +17,7 @@


import numpy as np import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.common.seed import _get_seed
from mindspore.common.seed import _get_graph_seed
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.ops import operations as P from mindspore.ops import operations as P
@@ -89,7 +89,7 @@ class Dropout(Cell):
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
self.keep_prob = keep_prob self.keep_prob = keep_prob
seed0, seed1 = _get_seed(0, "dropout")
seed0, seed1 = _get_graph_seed(0, "dropout")
self.seed0 = seed0 self.seed0 = seed0
self.seed1 = seed1 self.seed1 = seed1
self.dtype = dtype self.dtype = dtype


+ 2
- 2
mindspore/ops/composite/random_ops.py View File

@@ -19,12 +19,12 @@ from .. import operations as P
from .. import functional as F from .. import functional as F
from .multitype_ops import _constexpr_utils as const_utils from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common import _get_seed
from ...common.seed import _get_graph_seed


@constexpr @constexpr
def get_seed(op_seed, kernel_name): def get_seed(op_seed, kernel_name):
"Get the graph-level seed." "Get the graph-level seed."
return _get_seed(op_seed, kernel_name)
return _get_graph_seed(op_seed, kernel_name)




def normal(shape, mean, stddev, seed=None): def normal(shape, mean, stddev, seed=None):


+ 2
- 1
mindspore/parallel/_utils.py View File

@@ -22,6 +22,7 @@ from mindspore.common.dtype import dtype_to_nptype
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.communication.management import get_group_size, get_rank from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.common.seed import get_seed




def _get_parallel_mode(): def _get_parallel_mode():
@@ -139,7 +140,7 @@ def _get_parameter_broadcast():
parallel_mode = auto_parallel_context().get_parallel_mode() parallel_mode = auto_parallel_context().get_parallel_mode()
parameter_broadcast = auto_parallel_context().get_parameter_broadcast() parameter_broadcast = auto_parallel_context().get_parameter_broadcast()


if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False:
if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None:
logger.warning("You are suggested to use mindspore.common.set_seed() to share" logger.warning("You are suggested to use mindspore.common.set_seed() to share"
" parameters among devices.") " parameters among devices.")




Loading…
Cancel
Save