Browse Source

Check seed of non negative

tags/v1.1.0
peixu_ren 5 years ago
parent
commit
24b0741796
2 changed files with 4 additions and 9 deletions
  1. +3
    -4
      mindspore/common/seed.py
  2. +1
    -5
      mindspore/ops/composite/random_ops.py

+ 3
- 4
mindspore/common/seed.py View File

@@ -97,14 +97,13 @@ def _get_op_seed(op_seed, kernel_name):
seed (int): The op-seed to be updated. seed (int): The op-seed to be updated.
kernel_name (string): The random op kernel. kernel_name (string): The random op kernel.
""" """
if ((kernel_name, op_seed) not in _KERNEL_SEED) or (_KERNEL_SEED[(kernel_name, op_seed)] == -1):
if (kernel_name, op_seed) not in _KERNEL_SEED:
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed _KERNEL_SEED[(kernel_name, op_seed)] = op_seed
_KERNEL_SEED[(kernel_name, op_seed)] = 0
return _KERNEL_SEED[(kernel_name, op_seed)] return _KERNEL_SEED[(kernel_name, op_seed)]


def _reset_op_seed(): def _reset_op_seed():
""" """
Reset op seeds in the kernel's dictionary. Reset op seeds in the kernel's dictionary.
""" """
for key in _KERNEL_SEED:
_KERNEL_SEED[key] = -1
for (kernel_name, op_seed) in _KERNEL_SEED:
_KERNEL_SEED[(kernel_name, op_seed)] = op_seed

+ 1
- 5
mindspore/ops/composite/random_ops.py View File

@@ -54,6 +54,7 @@ def get_seed(op_seed, kernel_name):
if op_seed is None: if op_seed is None:
temp_seed = _get_op_seed(0, kernel_name) temp_seed = _get_op_seed(0, kernel_name)
else: else:
const_utils.check_int_non_negative("seed", op_seed, kernel_name)
temp_seed = _get_op_seed(op_seed, kernel_name) temp_seed = _get_op_seed(op_seed, kernel_name)
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
_update_seeds(op_seed, kernel_name) _update_seeds(op_seed, kernel_name)
@@ -88,7 +89,6 @@ def normal(shape, mean, stddev, seed=None):
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal") const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "normal")
const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal") const_utils.check_tensors_dtype_same(stddev_dtype, mstype.float32, "normal")
seed1, seed2 = get_seed(seed, "normal") seed1, seed2 = get_seed(seed, "normal")
const_utils.check_int_non_negative("seed", seed2, "normal")
stdnormal = P.StandardNormal(seed1, seed2) stdnormal = P.StandardNormal(seed1, seed2)
random_normal = stdnormal(shape) random_normal = stdnormal(shape)
value = random_normal * stddev + mean value = random_normal * stddev + mean
@@ -126,7 +126,6 @@ def laplace(shape, mean, lambda_param, seed=None):
const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "laplace") const_utils.check_tensors_dtype_same(mean_dtype, mstype.float32, "laplace")
const_utils.check_tensors_dtype_same(lambda_param_dtype, mstype.float32, "laplace") const_utils.check_tensors_dtype_same(lambda_param_dtype, mstype.float32, "laplace")
seed1, seed2 = get_seed(seed, "laplace") seed1, seed2 = get_seed(seed, "laplace")
const_utils.check_int_non_negative("seed", seed2, "laplace")
stdlaplace = P.StandardLaplace(seed1, seed2) stdlaplace = P.StandardLaplace(seed1, seed2)
rnd = stdlaplace(shape) rnd = stdlaplace(shape)
value = rnd * lambda_param + mean value = rnd * lambda_param + mean
@@ -177,7 +176,6 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform") const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform") const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
seed1, seed2 = get_seed(seed, "uniform") seed1, seed2 = get_seed(seed, "uniform")
const_utils.check_int_non_negative("seed", seed2, "uniform")
if const_utils.is_same_type(dtype, mstype.int32): if const_utils.is_same_type(dtype, mstype.int32):
random_uniform = P.UniformInt(seed1, seed2) random_uniform = P.UniformInt(seed1, seed2)
value = random_uniform(shape, minval, maxval) value = random_uniform(shape, minval, maxval)
@@ -210,7 +208,6 @@ def gamma(shape, alpha, beta, seed=None):
>>> output = C.gamma(shape, alpha, beta, seed=5) >>> output = C.gamma(shape, alpha, beta, seed=5)
""" """
seed1, seed2 = get_seed(seed, "gamma") seed1, seed2 = get_seed(seed, "gamma")
const_utils.check_int_non_negative("seed", seed2, "gamma")
random_gamma = P.Gamma(seed1, seed2) random_gamma = P.Gamma(seed1, seed2)
value = random_gamma(shape, alpha, beta) value = random_gamma(shape, alpha, beta)
return value return value
@@ -235,7 +232,6 @@ def poisson(shape, mean, seed=None):
>>> output = C.poisson(shape, mean, seed=5) >>> output = C.poisson(shape, mean, seed=5)
""" """
seed1, seed2 = get_seed(seed, "poisson") seed1, seed2 = get_seed(seed, "poisson")
const_utils.check_int_non_negative("seed", seed2, "poisson")
random_poisson = P.Poisson(seed1, seed2) random_poisson = P.Poisson(seed1, seed2)
value = random_poisson(shape, mean) value = random_poisson(shape, mean)
return value return value


Loading…
Cancel
Save