Commits:
- Added test case to show that C image augmentations don't use the seed properly
- Added passing test cases
- Added working test cases for using the seed
- Added additional test cases to show seed use
- Added test case for seed
@@ -15,7 +15,7 @@
 """
 The configuration manager.
 """
+import random
 import mindspore._c_dataengine as cde

 INT32_MAX = 2147483647

@@ -32,6 +32,12 @@ class ConfigurationManager:
         """
         Set the seed to be used in any random generator. This is used to produce deterministic results.

+        Note:
+            This set_seed function also seeds the Python random library, so Python augmentations
+            that use randomness become deterministic. set_seed should be called again for every
+            iterator created, to reset the random seed. In our pipeline this does not guarantee
+            deterministic results with num_parallel_workers > 1.

         Args:
             seed(int): seed to be set

@@ -47,6 +53,7 @@ class ConfigurationManager:
         if seed < 0 or seed > UINT32_MAX:
             raise ValueError("Seed given is not within the required range")
         self.config.set_seed(seed)
+        random.seed(seed)

     def get_seed(self):
         """
@@ -13,14 +13,19 @@
 # limitations under the License.
 # ==============================================================================
 """
 Testing configuration manager
 """
 import filecmp
 import glob
+import numpy as np
 import os
+from mindspore import log as logger
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
+import mindspore.dataset.transforms.vision.py_transforms as py_vision

 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

@@ -46,9 +51,17 @@ def test_basic():
     assert ds.config.get_prefetch_size() == 4
     assert ds.config.get_seed() == 5

+
+def test_get_seed():
+    """
+    This gets the seed value without explicitly setting a default; expect an int.
+    """
+    assert isinstance(ds.config.get_seed(), int)
+
+
 def test_pipeline():
     """
-    Test that our configuration pipeline works when we set parameters at dataset interval
+    Test that our configuration pipeline works when we set parameters at different locations in dataset code
     """
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
     ds.config.set_num_parallel_workers(2)
@@ -60,12 +73,12 @@ def test_pipeline():
     data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)])
     ds.serialize(data2, "testpipeline2.json")

     # check that the generated output is different
     assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json'))
     # this test passes currently because our num_parallel_workers doesn't get updated.

     # remove the generated json files
     file_list = glob.glob('*.json')
     for f in file_list:
         try:

@@ -74,6 +87,209 @@ def test_pipeline():
             logger.info("Error while deleting: {}".format(f))
+
+
+def test_deterministic_run_fail():
+    """
+    Test RandomCrop with seed; expected to fail
+    """
+    logger.info("test_deterministic_run_fail")
+
+    # when we set the seed, all operations within our dataset should be deterministic
+    ds.config.set_seed(0)
+    ds.config.set_num_parallel_workers(1)
+
+    # First dataset
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    # Assuming we get the same seed on calling the constructor: if this op is re-used, the results
+    # won't be the same between the two datasets. For example, the RandomCrop constructor takes the
+    # seed (0) and outputs a deterministic series of numbers,
+    # e.g. "a" = [1, 2, 3, 4, 5, 6] <- pretend these are random
+    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
+    decode_op = vision.Decode()
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=random_crop_op)
+
+    # Second dataset
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    data2 = data2.map(input_columns=["image"], operations=decode_op)
+    # the same op instance is re-used here, so it keeps consuming the same random sequence
+    data2 = data2.map(input_columns=["image"], operations=random_crop_op)
+
+    try:
+        for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
+            np.testing.assert_equal(item1["image"], item2["image"])
+    except BaseException as e:
+        # the two datasets split the numbers out of the one sequence "a", so their crops differ
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Array" in str(e)
+
+
+def test_deterministic_run_pass():
+    """
+    Test deterministic run with the seed set
+    """
+    logger.info("test_deterministic_run_pass")
+
+    ds.config.set_seed(0)
+    ds.config.set_num_parallel_workers(1)
+
+    # First dataset
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    # We get the seed when the constructor is called
+    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
+    decode_op = vision.Decode()
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=random_crop_op)
+
+    # Second dataset
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    data2 = data2.map(input_columns=["image"], operations=decode_op)
+    # Since the seed is set in the constructor, the two ops output the same deterministic sequence.
+    # Assume the generated random sequence "a" = [1, 2, 3, 4, 5, 6] <- pretend these are random
+    random_crop_op2 = vision.RandomCrop([512, 512], [200, 200, 200, 200])
+    data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
+
+    try:
+        for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
+            np.testing.assert_equal(item1["image"], item2["image"])
+    except BaseException as e:
+        # the two datasets each consume numbers from their own copy of the sequence "a"
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Array" in str(e)
+
+
+def test_seed_undeterministic():
+    """
+    Test seed with num_parallel_workers in C; this test is expected to fail some of the time
+    """
+    logger.info("test_seed_undeterministic")
+    ds.config.set_seed(0)
+
+    # First dataset
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    # seed will be read in during the constructor call
+    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
+    decode_op = vision.Decode()
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=random_crop_op)
+
+    # Second dataset
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    data2 = data2.map(input_columns=["image"], operations=decode_op)
+    # since the seed is set in the constructor, the two ops should output the same sequence
+    random_crop_op2 = vision.RandomCrop([512, 512], [200, 200, 200, 200])
+    data2 = data2.map(input_columns=["image"], operations=random_crop_op2)
+
+    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
+        np.testing.assert_equal(item1["image"], item2["image"])
+
+
+def test_deterministic_run_distribution():
+    """
+    Test deterministic run with the seed set, using an op that samples from a probability distribution
+    """
+    logger.info("test_deterministic_run_distribution")
+
+    # when we set the seed, all operations within our dataset should be deterministic
+    ds.config.set_seed(0)
+    ds.config.set_num_parallel_workers(1)
+
+    # First dataset
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    random_flip_op = vision.RandomHorizontalFlip(0.1)
+    decode_op = vision.Decode()
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=random_flip_op)
+
+    # Second dataset
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    data2 = data2.map(input_columns=["image"], operations=decode_op)
+    # since the seed is set in the constructor, the two ops output the same deterministic sequence
+    random_flip_op2 = vision.RandomHorizontalFlip(0.1)
+    data2 = data2.map(input_columns=["image"], operations=random_flip_op2)
+
+    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
+        np.testing.assert_equal(item1["image"], item2["image"])
+
+
+def test_deterministic_python_seed():
+    """
+    Test deterministic execution with seed in Python
+    """
+    logger.info("test_deterministic_python_seed")
+    ds.config.set_seed(0)
+    ds.config.set_num_parallel_workers(1)
+
+    # First dataset
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    transforms = [
+        py_vision.Decode(),
+        py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
+        py_vision.ToTensor(),
+    ]
+    transform = py_vision.ComposeOp(transforms)
+    data1 = data1.map(input_columns=["image"], operations=transform())
+    data1_output = []
+    # config.set_seed() calls random.seed()
+    for data_one in data1.create_dict_iterator():
+        data1_output.append(data_one["image"])
+
+    # Second dataset
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    data2 = data2.map(input_columns=["image"], operations=transform())
+    # config.set_seed() calls random.seed(); this resets the seed for the next dataset iterator
+    ds.config.set_seed(0)
+    data2_output = []
+    for data_two in data2.create_dict_iterator():
+        data2_output.append(data_two["image"])
+
+    np.testing.assert_equal(data1_output, data2_output)
+
+
+def test_deterministic_python_seed_multi_thread():
+    """
+    Test deterministic execution with seed in Python; this fails when pyfuncs run in multiple workers
+    """
+    logger.info("test_deterministic_python_seed_multi_thread")
+    ds.config.set_seed(0)
+
+    # when we set the seed, all operations within our dataset should be deterministic
+    # First dataset
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    transforms = [
+        py_vision.Decode(),
+        py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
+        py_vision.ToTensor(),
+    ]
+    transform = py_vision.ComposeOp(transforms)
+    data1 = data1.map(input_columns=["image"], operations=transform(), python_multiprocessing=True)
+    data1_output = []
+    # config.set_seed() calls random.seed()
+    for data_one in data1.create_dict_iterator():
+        data1_output.append(data_one["image"])
+
+    # Second dataset
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    data2 = data2.map(input_columns=["image"], operations=transform(), python_multiprocessing=True)
+    # config.set_seed() calls random.seed()
+    ds.config.set_seed(0)
+    data2_output = []
+    for data_two in data2.create_dict_iterator():
+        data2_output.append(data_two["image"])
+
+    try:
+        np.testing.assert_equal(data1_output, data2_output)
+    except BaseException as e:
+        # expect the outputs not to match during multi-process execution
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Array" in str(e)

 if __name__ == '__main__':
     test_basic()
     test_pipeline()
+    test_deterministic_run_pass()
+    test_deterministic_run_distribution()
+    test_deterministic_run_fail()
+    test_deterministic_python_seed()
+    test_seed_undeterministic()
+    test_get_seed()
@@ -36,6 +36,7 @@ def test_textline_dataset_all_file():
     assert(count == 5)

 def test_textline_dataset_totext():
+    ds.config.set_num_parallel_workers(4)
     data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
     count = 0
     line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."]
@@ -37,7 +37,7 @@ def visualize(first, mse, second):
     plt.subplot(142)
     plt.imshow(second)
-    plt.title("py random_color_jitter image")
+    plt.title("py random_color_adjust image")

     plt.subplot(143)
     plt.imshow(first - second)

@@ -50,20 +50,20 @@ def diff_mse(in1, in2):
     return mse * 100


-def test_random_color_jitter_op_brightness():
+def test_random_color_adjust_op_brightness():
     """
     Test RandomColorAdjust op
     """
-    logger.info("test_random_color_jitter_op")
+    logger.info("test_random_color_adjust_op")

     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
     decode_op = c_vision.Decode()

-    random_jitter_op = c_vision.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))
+    random_adjust_op = c_vision.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))

     ctrans = [decode_op,
-              random_jitter_op,
+              random_adjust_op,
               ]

     data1 = data1.map(input_columns=["image"], operations=ctrans)

@@ -100,20 +100,20 @@ def test_random_color_jitter_op_brightness():
     # visualize(c_image, mse, py_image)


-def test_random_color_jitter_op_contrast():
+def test_random_color_adjust_op_contrast():
     """
     Test RandomColorAdjust op
     """
-    logger.info("test_random_color_jitter_op")
+    logger.info("test_random_color_adjust_op")

     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
     decode_op = c_vision.Decode()

-    random_jitter_op = c_vision.RandomColorAdjust((1, 1), (0.5, 0.5), (1, 1), (0, 0))
+    random_adjust_op = c_vision.RandomColorAdjust((1, 1), (0.5, 0.5), (1, 1), (0, 0))

     ctrans = [decode_op,
-              random_jitter_op
+              random_adjust_op
               ]

     data1 = data1.map(input_columns=["image"], operations=ctrans)

@@ -156,20 +156,20 @@ def test_random_color_jitter_op_contrast():
     # visualize(c_image, mse, py_image)


-def test_random_color_jitter_op_saturation():
+def test_random_color_adjust_op_saturation():
     """
     Test RandomColorAdjust op
     """
-    logger.info("test_random_color_jitter_op")
+    logger.info("test_random_color_adjust_op")

     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
     decode_op = c_vision.Decode()

-    random_jitter_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (0.5, 0.5), (0, 0))
+    random_adjust_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (0.5, 0.5), (0, 0))

     ctrans = [decode_op,
-              random_jitter_op
+              random_adjust_op
               ]

     data1 = data1.map(input_columns=["image"], operations=ctrans)

@@ -209,20 +209,20 @@ def test_random_color_jitter_op_saturation():
     # visualize(c_image, mse, py_image)


-def test_random_color_jitter_op_hue():
+def test_random_color_adjust_op_hue():
     """
     Test RandomColorAdjust op
     """
-    logger.info("test_random_color_jitter_op")
+    logger.info("test_random_color_adjust_op")

     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
     decode_op = c_vision.Decode()

-    random_jitter_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2))
+    random_adjust_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2))

     ctrans = [decode_op,
-              random_jitter_op,
+              random_adjust_op,
               ]

     data1 = data1.map(input_columns=["image"], operations=ctrans)

@@ -264,7 +264,7 @@ def test_random_color_jitter_op_hue():

 if __name__ == "__main__":
-    test_random_color_jitter_op_brightness()
-    test_random_color_jitter_op_contrast()
-    test_random_color_jitter_op_saturation()
-    test_random_color_jitter_op_hue()
+    test_random_color_adjust_op_brightness()
+    test_random_color_adjust_op_contrast()
+    test_random_color_adjust_op_saturation()
+    test_random_color_adjust_op_hue()
@@ -17,8 +17,8 @@ Testing RandomCropAndResize op in DE
 """
 import matplotlib.pyplot as plt
 import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
 import mindspore.dataset as ds

 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

@@ -45,9 +45,9 @@ def visualize(a, mse, original):

 def test_random_crop_op():
     """
-    Test RandomCropAndResize op
+    Test RandomCrop op
     """
-    logger.info("test_random_crop_and_resize_op")
+    logger.info("test_random_crop_op")

     # First dataset
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

@@ -67,3 +67,4 @@ def test_random_crop_op():

 if __name__ == "__main__":
     test_random_crop_op()
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import numpy as np
 import mindspore.dataset as ds
 from mindspore import log as logger

@@ -34,9 +35,9 @@ def test_rename():
     for i, item in enumerate(data.create_dict_iterator()):
         logger.info("item[mask] is {}".format(item["masks"]))
-        assert item["masks"].all() == item["input_ids"].all()
+        np.testing.assert_equal(item["masks"], item["input_ids"])
         logger.info("item[seg_ids] is {}".format(item["seg_ids"]))
-        assert item["segment_ids"].all() == item["seg_ids"].all()
+        np.testing.assert_equal(item["segment_ids"], item["seg_ids"])
         # need to consume the data in the buffer
         num_iter += 1
     logger.info("Number of data in data: {}".format(num_iter))
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import numpy as np
 from util import save_and_check

 import mindspore.dataset as ds

@@ -117,6 +118,27 @@ def test_shuffle_05():
     save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


+def test_shuffle_06():
+    """
+    Test shuffle with the seed set; both datasets should produce the same shuffle order
+    """
+    logger.info("test_shuffle_06")
+    # define parameters
+    buffer_size = 13
+    seed = 1
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
+    ds.config.set_seed(seed)
+    data1 = data1.shuffle(buffer_size=buffer_size)
+
+    data2 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(buffer_size=buffer_size)
+
+    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
+        np.testing.assert_equal(item1, item2)
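The property test_shuffle_06 checks, in miniature, with Python's random.shuffle as an illustrative stand-in for the pipeline's seeded shuffle:

import random

items = list(range(10))

random.seed(1)
first = items[:]
random.shuffle(first)

random.seed(1)                   # same seed before the second shuffle
second = items[:]
random.shuffle(second)

assert first == second           # identical seed, identical shuffle order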
 def test_shuffle_exception_01():
     """
     Test shuffle exception: buffer_size<0
     """

@@ -231,6 +253,7 @@ if __name__ == '__main__':
     test_shuffle_03()
     test_shuffle_04()
     test_shuffle_05()
+    test_shuffle_06()
     test_shuffle_exception_01()
     test_shuffle_exception_02()
     test_shuffle_exception_03()