Browse Source

!2766 get default value if num_parallel_workers is None

Merge pull request !2766 from yanghaitao/yht_serialize_num_parallel_worker
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
9377e432d2
2 changed files with 5 additions and 4 deletions
  1. +3
    -1
      mindspore/dataset/engine/serializer_deserializer.py
  2. +2
    -3
      tests/ut/python/dataset/test_config.py

+ 3
- 1
mindspore/dataset/engine/serializer_deserializer.py View File

@@ -22,7 +22,7 @@ import sys
from mindspore import log as logger
from . import datasets as de
from ..transforms.vision.utils import Inter, Border
from ..core.configuration import config

def serialize(dataset, json_filepath=None):
"""
@@ -164,6 +164,8 @@ def traverse(node):
node_repr[k] = v.to_json()
elif k in set(['schema', 'dataset_files', 'dataset_dir', 'schema_file_path']):
expand_path(node_repr, k, v)
elif k == "num_parallel_workers" and v is None:
node_repr[k] = config.get_num_parallel_workers()
else:
node_repr[k] = v



+ 2
- 3
tests/ut/python/dataset/test_config.py View File

@@ -84,12 +84,11 @@ def test_pipeline():
num_parallel_workers_original = ds.config.get_num_parallel_workers()

data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(2)
data1 = data1.map(input_columns=["image"], operations=[c_vision.Decode(True)])
ds.serialize(data1, "testpipeline.json")

data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(4)
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=num_parallel_workers_original,
shuffle=False)
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode(True)])
ds.serialize(data2, "testpipeline2.json")



Loading…
Cancel
Save