Browse Source

!6113 Unify minddata seed to set_seed

Merge pull request !6113 from xiefangqi/md_unify_seed
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
10aec24510
4 changed files with 5 additions and 7 deletions
  1. +2
    -0
      mindspore/common/seed.py
  2. +3
    -5
      mindspore/dataset/engine/validators.py
  3. +0
    -1
      model_zoo/official/cv/googlenet/src/dataset.py
  4. +0
    -1
      model_zoo/official/cv/vgg16/src/dataset.py

+ 2
- 0
mindspore/common/seed.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""Provide random seed api."""
import numpy as np
import mindspore.dataset as de

# set global RNG seed
_GLOBAL_SEED = None
@@ -43,6 +44,7 @@ def set_seed(seed):
if seed < 0:
raise ValueError("The seed must be greater or equal to 0.")
np.random.seed(seed)
de.config.set_seed(seed)
global _GLOBAL_SEED
_GLOBAL_SEED = seed



+ 3
- 5
mindspore/dataset/engine/validators.py View File

@@ -23,7 +23,6 @@ from functools import wraps

import numpy as np
from mindspore._c_expression import typing
from mindspore.dataset.callback import DSCallback
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
@@ -31,8 +30,6 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis

from . import datasets
from . import samplers
# from . import cache_client
from .. import callback


def check_imagefolderdataset(method):
@@ -566,6 +563,7 @@ def check_map(method):

@wraps(method)
def new_method(self, *args, **kwargs):
from mindspore.dataset.callback import DSCallback
[_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache,
callbacks], _ = \
parse_user_args(method, *args, **kwargs)
@@ -581,9 +579,9 @@ def check_map(method):

if callbacks is not None:
if isinstance(callbacks, (list, tuple)):
type_check_list(callbacks, (callback.DSCallback,), "callbacks")
type_check_list(callbacks, (DSCallback,), "callbacks")
else:
type_check(callbacks, (callback.DSCallback,), "callbacks")
type_check(callbacks, (DSCallback,), "callbacks")

for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, column_order]):
if param is not None:


+ 0
- 1
model_zoo/official/cv/googlenet/src/dataset.py View File

@@ -26,7 +26,6 @@ from src.config import cifar_cfg, imagenet_cfg

def create_dataset_cifar10(data_home, repeat_num=1, training=True):
"""Data operations."""
ds.config.set_seed(1)
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin")


+ 0
- 1
model_zoo/official/cv/vgg16/src/dataset.py View File

@@ -28,7 +28,6 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True

def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True):
"""Data operations."""
de.config.set_seed(1)
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin")


Loading…
Cancel
Save