# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- """
- Test TFRecordDataset Ops
- """
- import numpy as np
- import pytest
-
- import mindspore.common.dtype as mstype
- import mindspore.dataset as ds
- from mindspore import log as logger
- from util import save_and_check_dict
-
FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
DATASET_ROOT = "../data/dataset/testTFTestAllTypes"
SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
               "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
               "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
               "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
GENERATE_GOLDEN = False


def test_tfrecord_shape():
    """Rank-0 (scalar) columns batched by 2 come out as rank-1 tensors."""
    logger.info("test_tfrecord_shape")
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
    ds1 = ds.TFRecordDataset(FILES, schema_file)
    ds1 = ds1.batch(2)
    for data in ds1.create_dict_iterator():
        logger.info(data)
    output_shape = ds1.output_shapes()
    assert len(output_shape[-1]) == 1


def test_tfrecord_read_all_dataset():
    """With no row count in the schema, every row in the file is read."""
    logger.info("test_tfrecord_read_all_dataset")
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
    ds1 = ds.TFRecordDataset(FILES, schema_file)
    assert ds1.get_dataset_size() == 12
    count = 0
    for _ in ds1.create_tuple_iterator():
        count += 1
    assert count == 12


def test_tfrecord_num_samples():
    """num_samples=8 overrides the 7-row count declared in the schema."""
    logger.info("test_tfrecord_num_samples")
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
    ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
    assert ds1.get_dataset_size() == 8
    count = 0
    for _ in ds1.create_dict_iterator():
        count += 1
    assert count == 8


def test_tfrecord_num_samples2():
    """Without num_samples, the 7-row count declared in the schema applies."""
    logger.info("test_tfrecord_num_samples2")
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
    ds1 = ds.TFRecordDataset(FILES, schema_file)
    assert ds1.get_dataset_size() == 7
    count = 0
    for _ in ds1.create_dict_iterator():
        count += 1
    assert count == 7


def test_tfrecord_shape2():
    """batch(2) adds a batch dimension, so the last column becomes rank 2."""
    logger.info("test_tfrecord_shape2")
    ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
    ds1 = ds1.batch(2)
    output_shape = ds1.output_shapes()
    assert len(output_shape[-1]) == 2


def test_tfrecord_files_basic():
    """Basic read with a schema file, checked against a golden result."""
    logger.info("test_tfrecord_files_basic")

    data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    filename = "tfrecord_files_basic.npz"
    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)


def test_tfrecord_no_schema():
    """Read without a schema, checked against a golden result."""
    logger.info("test_tfrecord_no_schema")

    data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
    filename = "tfrecord_no_schema.npz"
    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)


def test_tfrecord_pad():
    """Read with a schema that pads the bytes column to 10 bytes."""
    logger.info("test_tfrecord_pad")

    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
    data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES)
    filename = "tfrecord_pad_bytes10.npz"
    save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)


def test_tfrecord_read_files():
    """Read from explicit paths and glob patterns, with and without num_samples caps."""
    logger.info("test_tfrecord_read_files")
    pattern = DATASET_ROOT + "/test.data"
    data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    assert sum(1 for _ in data) == 12

    pattern = DATASET_ROOT + "/test2.data"
    data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    assert sum(1 for _ in data) == 12

    pattern = DATASET_ROOT + "/*.data"
    data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES)
    assert sum(1 for _ in data) == 24

    pattern = DATASET_ROOT + "/*.data"
    data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=3, shuffle=ds.Shuffle.FILES)
    assert sum(1 for _ in data) == 3

    data = ds.TFRecordDataset([DATASET_ROOT + "/test.data", DATASET_ROOT + "/test2.data"],
                              SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES)
    assert sum(1 for _ in data) == 24


def test_tfrecord_multi_files():
    """Read four files as a single dataset; expect 12 rows in total."""
    logger.info("test_tfrecord_multi_files")
    data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False)
    data1 = data1.repeat(1)
    num_iter = 0
    for _ in data1.create_dict_iterator():
        num_iter += 1

    assert num_iter == 12


def test_tfrecord_schema():
    """An in-code Schema yields the same data as the equivalent JSON schema file."""
    logger.info("test_tfrecord_schema")
    schema = ds.Schema()
    schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
    schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
    schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2])
    schema.add_column('col_binary', de_type=mstype.uint8, shape=[1])
    schema.add_column('col_float', de_type=mstype.float32, shape=[1])
    schema.add_column('col_sint16', de_type=mstype.int64, shape=[1])
    schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
    schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
    data1 = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)

    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)

    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)


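# For reference, a rough sketch of the JSON shape such a schema file takes
# (the field names below are an assumption for illustration, not a dump of the
# actual datasetSchema.json test asset):
#
#   {
#     "datasetType": "TF",
#     "numRows": 12,
#     "columns": {
#       "col_1d": {"type": "int64", "rank": 1, "shape": [2]},
#       ...
#     }
#   }

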
def test_tfrecord_shuffle():
    """With the same seed, Shuffle.GLOBAL matches Shuffle.FILES plus a shuffle op."""
    logger.info("test_tfrecord_shuffle")
    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)

    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)


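# A minimal reproducibility sketch building on the equivalence above (the
# helper below is illustrative and not called by the suite): with deterministic
# seeding, re-creating the pipeline under the same seed should replay the same
# global shuffle order.
def _shuffled_first_values(seed):
    """Return the first value of col_sint64 from each row, in iteration order."""
    ds.config.set_seed(seed)
    data = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
    return [item["col_sint64"][0] for item in data.create_dict_iterator()]
# Expectation under that assumption: two calls with the same seed agree, i.e.
# _shuffled_first_values(1) == _shuffled_first_values(1).

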
def test_tfrecord_shard():
    """Shards read disjoint files per epoch but cover all files across epochs."""
    logger.info("test_tfrecord_shard")
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def get_res(shard_id, num_repeats):
        data1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=shard_id, num_samples=3,
                                   shuffle=ds.Shuffle.FILES)
        data1 = data1.repeat(num_repeats)
        res = []
        for item in data1.create_dict_iterator():
            res.append(item["scalars"][0])
        return res

    # Get the results from the two workers separately; they must satisfy two criteria:
    # 1. The workers always read different files within the same epoch
    #    (e.g. worker1: f1 & f3, worker2: f2 & f4).
    # 2. Given enough epochs, each worker sees the entire dataset
    #    (e.g. worker1 reads f1 & f3 in epoch 1 and f2 & f4 in epoch 2).
    worker1_res = get_res(0, 16)
    worker2_res = get_res(1, 16)
    # Confirm each worker gets 3 x 16 = 48 rows
    assert len(worker1_res) == 48
    assert len(worker1_res) == len(worker2_res)
    # Check criterion 1
    for i, _ in enumerate(worker1_res):
        assert worker1_res[i] != worker2_res[i]
    # Check criterion 2
    assert set(worker2_res) == set(worker1_res)


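# Taken together, the two criteria above describe a disjoint per-epoch file
# split: with 4 files and 2 shards, each epoch hands 2 files to each worker,
# and the per-epoch reshuffle of file order rotates the assignment so that
# every worker eventually reads all 4 files.

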
def test_tfrecord_shard_equal_rows():
    """With shard_equal_rows=True, every shard reads the same number of rows."""
    logger.info("test_tfrecord_shard_equal_rows")
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def get_res(num_shards, shard_id, num_repeats):
        ds1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, shard_equal_rows=True)
        ds1 = ds1.repeat(num_repeats)
        res = []
        for data in ds1.create_dict_iterator():
            res.append(data["scalars"][0])
        return res

    worker1_res = get_res(3, 0, 2)
    worker2_res = get_res(3, 1, 2)
    worker3_res = get_res(3, 2, 2)
    # Check that the workers read different rows at each step
    for i, _ in enumerate(worker1_res):
        assert worker1_res[i] != worker2_res[i]
        assert worker2_res[i] != worker3_res[i]
    # Confirm each worker gets the same number of rows:
    # 40 rows over 3 shards is ceil(40 / 3) = 14 rows per epoch, x2 repeats = 28
    assert len(worker1_res) == 28
    assert len(worker1_res) == len(worker2_res)
    assert len(worker2_res) == len(worker3_res)

    # A single shard reads every row across the 4 files: 40 in total
    worker4_res = get_res(1, 0, 1)
    assert len(worker4_res) == 40


def test_tfrecord_no_schema_columns_list():
    """columns_list limits the output columns when no schema is given."""
    logger.info("test_tfrecord_no_schema_columns_list")
    data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
    row = data.create_dict_iterator().get_next()
    assert row["col_sint16"] == [-32768]

    # Columns outside columns_list must not appear in the row
    with pytest.raises(KeyError) as info:
        _ = row["col_sint32"]
    assert "col_sint32" in str(info.value)


def test_tfrecord_schema_columns_list():
    """columns_list limits the output columns even when a full schema is given."""
    logger.info("test_tfrecord_schema_columns_list")
    schema = ds.Schema()
    schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
    schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
    schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2])
    schema.add_column('col_binary', de_type=mstype.uint8, shape=[1])
    schema.add_column('col_float', de_type=mstype.float32, shape=[1])
    schema.add_column('col_sint16', de_type=mstype.int64, shape=[1])
    schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
    schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
    data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"])
    row = data.create_dict_iterator().get_next()
    assert row["col_sint16"] == [-32768]

    # Columns outside columns_list must not appear in the row
    with pytest.raises(KeyError) as info:
        _ = row["col_sint32"]
    assert "col_sint32" in str(info.value)


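# An alternative worth noting (a sketch assuming the generic Dataset.project
# operation; not exercised by this suite): column selection can also be applied
# after loading, which should match the no-schema columns_list case above.
#
#   data = ds.TFRecordDataset(FILES, shuffle=False).project(columns=["col_sint16"])

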
def test_tfrecord_invalid_files():
    """Unreadable files fail at read time; nonexistent paths fail at construction."""
    logger.info("test_tfrecord_invalid_files")
    valid_file = "../data/dataset/testTFTestAllTypes/test.data"
    invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
    files = [invalid_file, valid_file, SCHEMA_FILE]

    data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)

    # The error should name only the files that are not valid TFRecord files
    with pytest.raises(RuntimeError) as info:
        _ = data.create_dict_iterator().get_next()
    assert "cannot be opened" in str(info.value)
    assert "not valid tfrecord files" in str(info.value)
    assert valid_file not in str(info.value)
    assert invalid_file in str(info.value)
    assert SCHEMA_FILE in str(info.value)

    nonexistent_file = "this/file/does/not/exist"
    files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file]

    # A path that matches no file fails at construction, naming only that path
    with pytest.raises(ValueError) as info:
        data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    assert "did not match any files" in str(info.value)
    assert valid_file not in str(info.value)
    assert invalid_file not in str(info.value)
    assert SCHEMA_FILE not in str(info.value)
    assert nonexistent_file in str(info.value)


if __name__ == '__main__':
    test_tfrecord_shape()
    test_tfrecord_read_all_dataset()
    test_tfrecord_num_samples()
    test_tfrecord_num_samples2()
    test_tfrecord_shape2()
    test_tfrecord_files_basic()
    test_tfrecord_no_schema()
    test_tfrecord_pad()
    test_tfrecord_read_files()
    test_tfrecord_multi_files()
    test_tfrecord_schema()
    test_tfrecord_shuffle()
    test_tfrecord_shard()
    test_tfrecord_shard_equal_rows()
    test_tfrecord_no_schema_columns_list()
    test_tfrecord_schema_columns_list()
    test_tfrecord_invalid_files()