Browse Source

!1455 Cleanup dataset UT: restore config support

Merge pull request !1455 from cathwong/ckw_dataset_config_util
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
0a8ef2fe72
4 changed files with 106 additions and 4 deletions
  1. +70
    -0
      tests/ut/python/dataset/test_config.py
  2. +5
    -1
      tests/ut/python/dataset/test_datasets_textfileop.py
  3. +6
    -1
      tests/ut/python/dataset/test_split.py
  4. +25
    -2
      tests/ut/python/dataset/util.py

+ 70
- 0
tests/ut/python/dataset/test_config.py View File

@@ -30,6 +30,14 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_basic():
"""
Test basic configuration functions
"""
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
prefetch_size_original = ds.config.get_prefetch_size()
seed_original = ds.config.get_seed()

ds.config.load('../data/dataset/declient.cfg')

# assert ds.config.get_rows_per_buffer() == 32
@@ -50,6 +58,11 @@ def test_basic():
assert ds.config.get_prefetch_size() == 4
assert ds.config.get_seed() == 5

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_prefetch_size(prefetch_size_original)
ds.config.set_seed(seed_original)


def test_get_seed():
"""
@@ -62,6 +75,9 @@ def test_pipeline():
"""
Test that our configuration pipeline works when we set parameters at different locations in dataset code
"""
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()

data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(2)
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
@@ -85,6 +101,9 @@ def test_pipeline():
except IOError:
logger.info("Error while deleting: {}".format(f))

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)


def test_deterministic_run_fail():
"""
@@ -92,6 +111,10 @@ def test_deterministic_run_fail():
"""
logger.info("test_deterministic_run_fail")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()

# when we set the seed all operations within our dataset should be deterministic
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
@@ -120,12 +143,21 @@ def test_deterministic_run_fail():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_deterministic_run_pass():
"""
Test deterministic run with with setting the seed
"""
logger.info("test_deterministic_run_pass")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)

@@ -152,13 +184,23 @@ def test_deterministic_run_pass():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_seed_undeterministic():
"""
Test seed with num parallel workers in c, this test is expected to fail some of the time
"""
logger.info("test_seed_undeterministic")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()

ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)

# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
@@ -178,6 +220,10 @@ def test_seed_undeterministic():
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_deterministic_run_distribution():
"""
@@ -185,6 +231,10 @@ def test_deterministic_run_distribution():
"""
logger.info("test_deterministic_run_distribution")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()

# when we set the seed all operations within our dataset should be deterministic
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
@@ -206,12 +256,21 @@ def test_deterministic_run_distribution():
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_deterministic_python_seed():
"""
Test deterministic execution with seed in python
"""
logger.info("deterministic_random_crop_op_python_2")

# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()

ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)

@@ -242,12 +301,20 @@ def test_deterministic_python_seed():

np.testing.assert_equal(data1_output, data2_output)

# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)


def test_deterministic_python_seed_multi_thread():
"""
Test deterministic execution with seed in python, this fails with multi-thread pyfunc run
"""
logger.info("deterministic_random_crop_op_python_2")

# Save original configuration values
seed_original = ds.config.get_seed()

ds.config.set_seed(0)
# when we set the seed all operations within our dataset should be deterministic
# First dataset
@@ -282,6 +349,9 @@ def test_deterministic_python_seed_multi_thread():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)

# Restore original configuration values
ds.config.set_seed(seed_original)


if __name__ == '__main__':
test_basic()


+ 5
- 1
tests/ut/python/dataset/test_datasets_textfileop.py View File

@@ -14,6 +14,8 @@
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers


DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
@@ -38,7 +40,7 @@ def test_textline_dataset_all_file():


def test_textline_dataset_totext():
ds.config.set_num_parallel_workers(4)
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
count = 0
line = ["This is a text file.", "Another file.",
@@ -48,6 +50,8 @@ def test_textline_dataset_totext():
assert (str == line[count])
count += 1
assert (count == 5)
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_textline_dataset_num_samples():


+ 6
- 1
tests/ut/python/dataset/test_split.py View File

@@ -14,6 +14,8 @@
# ==============================================================================
import pytest
import mindspore.dataset as ds
from util import config_get_set_num_parallel_workers


# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
# the label of each image is [0,0,0,1,1] each image can be uniquely identified
@@ -80,7 +82,7 @@ def test_unmappable_split():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]
ds.config.set_num_parallel_workers(4)
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1], randomize=False)

@@ -122,6 +124,9 @@ def test_unmappable_split():

assert s1_output == text_file_data[0:2]
assert s2_output == text_file_data[2:]
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_mappable_invalid_input():
d = ds.ManifestDataset(manifest_file)


+ 25
- 2
tests/ut/python/dataset/util.py View File

@@ -15,11 +15,11 @@

import hashlib
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import os

# import jsbeautifier
import mindspore.dataset as ds
from mindspore import log as logger

# These are the column names defined in the testTFTestAllTypes dataset
@@ -221,3 +221,26 @@ def visualize(image_original, image_transformed):
plt.title("Transformed image")

plt.show()


def config_get_set_seed(seed_new):
"""
Get and return the original configuration seed value.
Set the new configuration seed value.
"""
seed_original = ds.config.get_seed()
ds.config.set_seed(seed_new)
logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
return seed_original


def config_get_set_num_parallel_workers(num_parallel_workers_new):
"""
Get and return the original configuration num_parallel_workers value.
Set the new configuration num_parallel_workers value.
"""
num_parallel_workers_original = ds.config.get_num_parallel_workers()
ds.config.set_num_parallel_workers(num_parallel_workers_new)
logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
num_parallel_workers_new))
return num_parallel_workers_original

Loading…
Cancel
Save