Merge pull request !3067 from cathwong/ckw_dataset_ut_cleanup6tags/v0.6.0-beta
| @@ -51,7 +51,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) { | |||
| auto my_tree = std::make_shared<ExecutionTree>(); | |||
| // Creating TFReaderOp | |||
| std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data"; | |||
| std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data"; | |||
| std::shared_ptr<TFReaderOp> my_tfreader_op; | |||
| rc = TFReaderOp::Builder() | |||
| .SetDatasetFilesList({dataset_path}) | |||
| @@ -58,7 +58,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) { | |||
| auto my_tree = std::make_shared<ExecutionTree>(); | |||
| // Creating TFReaderOp | |||
| std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data"; | |||
| std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data"; | |||
| std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data"; | |||
| std::shared_ptr<TFReaderOp> my_tfreader_op; | |||
| rc = TFReaderOp::Builder() | |||
| @@ -142,7 +142,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) { | |||
| MS_LOG(INFO) << "UT test TestZipRepeat."; | |||
| auto my_tree = std::make_shared<ExecutionTree>(); | |||
| std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data"; | |||
| std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data"; | |||
| std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data"; | |||
| std::shared_ptr<TFReaderOp> my_tfreader_op; | |||
| rc = TFReaderOp::Builder() | |||
| @@ -1,11 +0,0 @@ | |||
| { | |||
| "datasetType": "TF", | |||
| "numRows": 3, | |||
| "columns": { | |||
| "label": { | |||
| "type": "int64", | |||
| "rank": 1, | |||
| "t_impl": "flex" | |||
| } | |||
| } | |||
| } | |||
| @@ -1,11 +0,0 @@ | |||
| { | |||
| "datasetType": "TF", | |||
| "numRows": 3, | |||
| "columns": { | |||
| "image": { | |||
| "type": "uint8", | |||
| "rank": 1, | |||
| "t_impl": "cvmat" | |||
| } | |||
| } | |||
| } | |||
| @@ -1,204 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.c_transforms as data_trans | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| def test_case_repeat(): | |||
| """ | |||
| a simple repeat operation. | |||
| """ | |||
| logger.info("Test Simple Repeat") | |||
| # define parameters | |||
| repeat_count = 2 | |||
| # apply dataset operations | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| data1 = data1.repeat(repeat_count) | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(): # each data is a dictionary | |||
| # in this example, each dictionary has keys "image" and "label" | |||
| logger.info("image is: {}".format(item["image"])) | |||
| logger.info("label is: {}".format(item["label"])) | |||
| num_iter += 1 | |||
| logger.info("Number of data in data1: {}".format(num_iter)) | |||
| def test_case_shuffle(): | |||
| """ | |||
| a simple shuffle operation. | |||
| """ | |||
| logger.info("Test Simple Shuffle") | |||
| # define parameters | |||
| buffer_size = 8 | |||
| seed = 10 | |||
| # apply dataset operations | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| ds.config.set_seed(seed) | |||
| data1 = data1.shuffle(buffer_size=buffer_size) | |||
| for item in data1.create_dict_iterator(): | |||
| logger.info("image is: {}".format(item["image"])) | |||
| logger.info("label is: {}".format(item["label"])) | |||
| def test_case_0(): | |||
| """ | |||
| Test Repeat then Shuffle | |||
| """ | |||
| logger.info("Test Repeat then Shuffle") | |||
| # define parameters | |||
| repeat_count = 2 | |||
| buffer_size = 7 | |||
| seed = 9 | |||
| # apply dataset operations | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| data1 = data1.repeat(repeat_count) | |||
| ds.config.set_seed(seed) | |||
| data1 = data1.shuffle(buffer_size=buffer_size) | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(): # each data is a dictionary | |||
| # in this example, each dictionary has keys "image" and "label" | |||
| logger.info("image is: {}".format(item["image"])) | |||
| logger.info("label is: {}".format(item["label"])) | |||
| num_iter += 1 | |||
| logger.info("Number of data in data1: {}".format(num_iter)) | |||
| def test_case_0_reverse(): | |||
| """ | |||
| Test Shuffle then Repeat | |||
| """ | |||
| logger.info("Test Shuffle then Repeat") | |||
| # define parameters | |||
| repeat_count = 2 | |||
| buffer_size = 10 | |||
| seed = 9 | |||
| # apply dataset operations | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| ds.config.set_seed(seed) | |||
| data1 = data1.shuffle(buffer_size=buffer_size) | |||
| data1 = data1.repeat(repeat_count) | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(): # each data is a dictionary | |||
| # in this example, each dictionary has keys "image" and "label" | |||
| logger.info("image is: {}".format(item["image"])) | |||
| logger.info("label is: {}".format(item["label"])) | |||
| num_iter += 1 | |||
| logger.info("Number of data in data1: {}".format(num_iter)) | |||
| def test_case_3(): | |||
| """ | |||
| Test Map | |||
| """ | |||
| logger.info("Test Map Rescale and Resize, then Shuffle") | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| # define data augmentation parameters | |||
| rescale = 1.0 / 255.0 | |||
| shift = 0.0 | |||
| resize_height, resize_width = 224, 224 | |||
| # define map operations | |||
| decode_op = vision.Decode() | |||
| rescale_op = vision.Rescale(rescale, shift) | |||
| # resize_op = vision.Resize(resize_height, resize_width, | |||
| # InterpolationMode.DE_INTER_LINEAR) # Bilinear mode | |||
| resize_op = vision.Resize((resize_height, resize_width)) | |||
| # apply map operations on images | |||
| data1 = data1.map(input_columns=["image"], operations=decode_op) | |||
| data1 = data1.map(input_columns=["image"], operations=rescale_op) | |||
| data1 = data1.map(input_columns=["image"], operations=resize_op) | |||
| # # apply ont-hot encoding on labels | |||
| num_classes = 4 | |||
| one_hot_encode = data_trans.OneHot(num_classes) # num_classes is input argument | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_encode) | |||
| # | |||
| # # apply Datasets | |||
| buffer_size = 100 | |||
| seed = 10 | |||
| batch_size = 2 | |||
| ds.config.set_seed(seed) | |||
| data1 = data1.shuffle(buffer_size=buffer_size) # 10000 as in imageNet train script | |||
| data1 = data1.batch(batch_size, drop_remainder=True) | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(): # each data is a dictionary | |||
| # in this example, each dictionary has keys "image" and "label" | |||
| logger.info("image is: {}".format(item["image"])) | |||
| logger.info("label is: {}".format(item["label"])) | |||
| num_iter += 1 | |||
| logger.info("Number of data in data1: {}".format(num_iter)) | |||
| if __name__ == '__main__': | |||
| logger.info('===========now test Repeat============') | |||
| # logger.info('Simple Repeat') | |||
| test_case_repeat() | |||
| logger.info('\n') | |||
| logger.info('===========now test Shuffle===========') | |||
| # logger.info('Simple Shuffle') | |||
| test_case_shuffle() | |||
| logger.info('\n') | |||
| # Note: cannot work with different shapes, hence not for image | |||
| # logger.info('===========now test Batch=============') | |||
| # # logger.info('Simple Batch') | |||
| # test_case_batch() | |||
| # logger.info('\n') | |||
| logger.info('===========now test case 0============') | |||
| # logger.info('Repeat then Shuffle') | |||
| test_case_0() | |||
| logger.info('\n') | |||
| logger.info('===========now test case 0 reverse============') | |||
| # # logger.info('Shuffle then Repeat') | |||
| test_case_0_reverse() | |||
| logger.info('\n') | |||
| # logger.info('===========now test case 1============') | |||
| # # logger.info('Repeat with Batch') | |||
| # test_case_1() | |||
| # logger.info('\n') | |||
| # logger.info('===========now test case 2============') | |||
| # # logger.info('Batch with Shuffle') | |||
| # test_case_2() | |||
| # logger.info('\n') | |||
| # for image augmentation only | |||
| logger.info('===========now test case 3============') | |||
| logger.info('Map then Shuffle') | |||
| test_case_3() | |||
| logger.info('\n') | |||
| @@ -1,40 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] | |||
| SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" | |||
| def test_tf_file_normal(): | |||
| # apply dataset operations | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| data1 = data1.repeat(1) | |||
| num_iter = 0 | |||
| for _ in data1.create_dict_iterator(): # each data is a dictionary | |||
| num_iter += 1 | |||
| logger.info("Number of data in data1: {}".format(num_iter)) | |||
| assert num_iter == 12 | |||
| if __name__ == '__main__': | |||
| logger.info('=======test normal=======') | |||
| test_tf_file_normal() | |||
| @@ -13,12 +13,13 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing the one_hot op in DE | |||
| Testing the OneHot Op | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.c_transforms as data_trans | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| from util import diff_mse | |||
| @@ -37,15 +38,15 @@ def one_hot(index, depth): | |||
| def test_one_hot(): | |||
| """ | |||
| Test one_hot | |||
| Test OneHot Tensor Operator | |||
| """ | |||
| logger.info("Test one_hot") | |||
| logger.info("test_one_hot") | |||
| depth = 10 | |||
| # First dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| one_hot_op = data_trans.OneHot(depth) | |||
| one_hot_op = data_trans.OneHot(num_classes=depth) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_op, columns_order=["label"]) | |||
| # Second dataset | |||
| @@ -58,8 +59,54 @@ def test_one_hot(): | |||
| label2 = one_hot(item2["label"][0], depth) | |||
| mse = diff_mse(label1, label2) | |||
| logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse)) | |||
| assert mse == 0 | |||
| num_iter += 1 | |||
| assert num_iter == 3 | |||
| def test_one_hot_post_aug(): | |||
| """ | |||
| Test One Hot Encoding after Multiple Data Augmentation Operators | |||
| """ | |||
| logger.info("test_one_hot_post_aug") | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) | |||
| # Define data augmentation parameters | |||
| rescale = 1.0 / 255.0 | |||
| shift = 0.0 | |||
| resize_height, resize_width = 224, 224 | |||
| # Define map operations | |||
| decode_op = c_vision.Decode() | |||
| rescale_op = c_vision.Rescale(rescale, shift) | |||
| resize_op = c_vision.Resize((resize_height, resize_width)) | |||
| # Apply map operations on images | |||
| data1 = data1.map(input_columns=["image"], operations=decode_op) | |||
| data1 = data1.map(input_columns=["image"], operations=rescale_op) | |||
| data1 = data1.map(input_columns=["image"], operations=resize_op) | |||
| # Apply one-hot encoding on labels | |||
| depth = 4 | |||
| one_hot_encode = data_trans.OneHot(depth) | |||
| data1 = data1.map(input_columns=["label"], operations=one_hot_encode) | |||
| # Apply datasets ops | |||
| buffer_size = 100 | |||
| seed = 10 | |||
| batch_size = 2 | |||
| ds.config.set_seed(seed) | |||
| data1 = data1.shuffle(buffer_size=buffer_size) | |||
| data1 = data1.batch(batch_size, drop_remainder=True) | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(): | |||
| logger.info("image is: {}".format(item["image"])) | |||
| logger.info("label is: {}".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 1 | |||
| if __name__ == "__main__": | |||
| test_one_hot() | |||
| test_one_hot_post_aug() | |||
| @@ -12,25 +12,24 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Test Repeat Op | |||
| """ | |||
| import numpy as np | |||
| from util import save_and_check | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| from util import save_and_check_dict | |||
| DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"] | |||
| SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json" | |||
| COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", | |||
| "col_sint16", "col_sint32", "col_sint64"] | |||
| GENERATE_GOLDEN = False | |||
| IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| GENERATE_GOLDEN = False | |||
| def test_tf_repeat_01(): | |||
| """ | |||
| @@ -39,14 +38,13 @@ def test_tf_repeat_01(): | |||
| logger.info("Test Simple Repeat") | |||
| # define parameters | |||
| repeat_count = 2 | |||
| parameters = {"params": {'repeat_count': repeat_count}} | |||
| # apply dataset operations | |||
| data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False) | |||
| data1 = data1.repeat(repeat_count) | |||
| filename = "repeat_result.npz" | |||
| save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) | |||
| save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) | |||
| def test_tf_repeat_02(): | |||
| @@ -99,14 +97,13 @@ def test_tf_repeat_04(): | |||
| logger.info("Test Simple Repeat Column List") | |||
| # define parameters | |||
| repeat_count = 2 | |||
| parameters = {"params": {'repeat_count': repeat_count}} | |||
| columns_list = ["col_sint64", "col_sint32"] | |||
| # apply dataset operations | |||
| data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False) | |||
| data1 = data1.repeat(repeat_count) | |||
| filename = "repeat_list_result.npz" | |||
| save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) | |||
| save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) | |||
| def generator(): | |||
| @@ -115,6 +112,7 @@ def generator(): | |||
| def test_nested_repeat1(): | |||
| logger.info("test_nested_repeat1") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.repeat(2) | |||
| data = data.repeat(3) | |||
| @@ -126,6 +124,7 @@ def test_nested_repeat1(): | |||
| def test_nested_repeat2(): | |||
| logger.info("test_nested_repeat2") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.repeat(1) | |||
| data = data.repeat(1) | |||
| @@ -137,6 +136,7 @@ def test_nested_repeat2(): | |||
| def test_nested_repeat3(): | |||
| logger.info("test_nested_repeat3") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.repeat(1) | |||
| data = data.repeat(2) | |||
| @@ -148,6 +148,7 @@ def test_nested_repeat3(): | |||
| def test_nested_repeat4(): | |||
| logger.info("test_nested_repeat4") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.repeat(2) | |||
| data = data.repeat(1) | |||
| @@ -159,6 +160,7 @@ def test_nested_repeat4(): | |||
| def test_nested_repeat5(): | |||
| logger.info("test_nested_repeat5") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.batch(3) | |||
| data = data.repeat(2) | |||
| @@ -171,6 +173,7 @@ def test_nested_repeat5(): | |||
| def test_nested_repeat6(): | |||
| logger.info("test_nested_repeat6") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.repeat(2) | |||
| data = data.batch(3) | |||
| @@ -183,6 +186,7 @@ def test_nested_repeat6(): | |||
| def test_nested_repeat7(): | |||
| logger.info("test_nested_repeat7") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.repeat(2) | |||
| data = data.repeat(3) | |||
| @@ -195,6 +199,7 @@ def test_nested_repeat7(): | |||
| def test_nested_repeat8(): | |||
| logger.info("test_nested_repeat8") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.batch(2, drop_remainder=False) | |||
| data = data.repeat(2) | |||
| @@ -210,6 +215,7 @@ def test_nested_repeat8(): | |||
| def test_nested_repeat9(): | |||
| logger.info("test_nested_repeat9") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.repeat() | |||
| data = data.repeat(3) | |||
| @@ -221,6 +227,7 @@ def test_nested_repeat9(): | |||
| def test_nested_repeat10(): | |||
| logger.info("test_nested_repeat10") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.repeat(3) | |||
| data = data.repeat() | |||
| @@ -232,6 +239,7 @@ def test_nested_repeat10(): | |||
| def test_nested_repeat11(): | |||
| logger.info("test_nested_repeat11") | |||
| data = ds.GeneratorDataset(generator, ["data"]) | |||
| data = data.repeat(2) | |||
| data = data.repeat(3) | |||
| @@ -12,21 +12,30 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Test TFRecordDataset Ops | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| from util import save_and_check | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| from util import save_and_check_dict | |||
| FILES = ["../data/dataset/testTFTestAllTypes/test.data"] | |||
| DATASET_ROOT = "../data/dataset/testTFTestAllTypes/" | |||
| SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json" | |||
| DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] | |||
| SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" | |||
| GENERATE_GOLDEN = False | |||
| def test_case_tf_shape(): | |||
| def test_tfrecord_shape(): | |||
| logger.info("test_tfrecord_shape") | |||
| schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json" | |||
| ds1 = ds.TFRecordDataset(FILES, schema_file) | |||
| ds1 = ds1.batch(2) | |||
| @@ -36,7 +45,8 @@ def test_case_tf_shape(): | |||
| assert len(output_shape[-1]) == 1 | |||
| def test_case_tf_read_all_dataset(): | |||
| def test_tfrecord_read_all_dataset(): | |||
| logger.info("test_tfrecord_read_all_dataset") | |||
| schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json" | |||
| ds1 = ds.TFRecordDataset(FILES, schema_file) | |||
| assert ds1.get_dataset_size() == 12 | |||
| @@ -46,7 +56,8 @@ def test_case_tf_read_all_dataset(): | |||
| assert count == 12 | |||
| def test_case_num_samples(): | |||
| def test_tfrecord_num_samples(): | |||
| logger.info("test_tfrecord_num_samples") | |||
| schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" | |||
| ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8) | |||
| assert ds1.get_dataset_size() == 8 | |||
| @@ -56,7 +67,8 @@ def test_case_num_samples(): | |||
| assert count == 8 | |||
| def test_case_num_samples2(): | |||
| def test_tfrecord_num_samples2(): | |||
| logger.info("test_tfrecord_num_samples2") | |||
| schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" | |||
| ds1 = ds.TFRecordDataset(FILES, schema_file) | |||
| assert ds1.get_dataset_size() == 7 | |||
| @@ -66,42 +78,41 @@ def test_case_num_samples2(): | |||
| assert count == 7 | |||
| def test_case_tf_shape_2(): | |||
| def test_tfrecord_shape2(): | |||
| logger.info("test_tfrecord_shape2") | |||
| ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE) | |||
| ds1 = ds1.batch(2) | |||
| output_shape = ds1.output_shapes() | |||
| assert len(output_shape[-1]) == 2 | |||
| def test_case_tf_file(): | |||
| logger.info("reading data from: {}".format(FILES[0])) | |||
| parameters = {"params": {}} | |||
| def test_tfrecord_files_basic(): | |||
| logger.info("test_tfrecord_files_basic") | |||
| data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) | |||
| filename = "tfreader_result.npz" | |||
| save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN) | |||
| filename = "tfrecord_files_basic.npz" | |||
| save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| def test_case_tf_file_no_schema(): | |||
| logger.info("reading data from: {}".format(FILES[0])) | |||
| parameters = {"params": {}} | |||
| def test_tfrecord_no_schema(): | |||
| logger.info("test_tfrecord_no_schema") | |||
| data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES) | |||
| filename = "tf_file_no_schema.npz" | |||
| save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN) | |||
| filename = "tfrecord_no_schema.npz" | |||
| save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| def test_case_tf_file_pad(): | |||
| logger.info("reading data from: {}".format(FILES[0])) | |||
| parameters = {"params": {}} | |||
| def test_tfrecord_pad(): | |||
| logger.info("test_tfrecord_pad") | |||
| schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json" | |||
| data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES) | |||
| filename = "tf_file_padBytes10.npz" | |||
| save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN) | |||
| filename = "tfrecord_pad_bytes10.npz" | |||
| save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) | |||
| def test_tf_files(): | |||
| def test_tfrecord_read_files(): | |||
| logger.info("test_tfrecord_read_files") | |||
| pattern = DATASET_ROOT + "/test.data" | |||
| data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) | |||
| assert sum([1 for _ in data]) == 12 | |||
| @@ -123,7 +134,19 @@ def test_tf_files(): | |||
| assert sum([1 for _ in data]) == 24 | |||
| def test_tf_record_schema(): | |||
| def test_tfrecord_multi_files(): | |||
| logger.info("test_tfrecord_multi_files") | |||
| data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False) | |||
| data1 = data1.repeat(1) | |||
| num_iter = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert num_iter == 12 | |||
| def test_tfrecord_schema(): | |||
| logger.info("test_tfrecord_schema") | |||
| schema = ds.Schema() | |||
| schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) | |||
| schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2]) | |||
| @@ -142,7 +165,8 @@ def test_tf_record_schema(): | |||
| assert np.array_equal(t1, t2) | |||
| def test_tf_record_shuffle(): | |||
| def test_tfrecord_shuffle(): | |||
| logger.info("test_tfrecord_shuffle") | |||
| ds.config.set_seed(1) | |||
| data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL) | |||
| data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES) | |||
| @@ -153,7 +177,8 @@ def test_tf_record_shuffle(): | |||
| assert np.array_equal(t1, t2) | |||
| def test_tf_record_shard(): | |||
| def test_tfrecord_shard(): | |||
| logger.info("test_tfrecord_shard") | |||
| tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", | |||
| "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] | |||
| @@ -181,7 +206,8 @@ def test_tf_record_shard(): | |||
| assert set(worker2_res) == set(worker1_res) | |||
| def test_tf_shard_equal_rows(): | |||
| def test_tfrecord_shard_equal_rows(): | |||
| logger.info("test_tfrecord_shard_equal_rows") | |||
| tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", | |||
| "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] | |||
| @@ -209,7 +235,8 @@ def test_tf_shard_equal_rows(): | |||
| assert len(worker4_res) == 40 | |||
| def test_case_tf_file_no_schema_columns_list(): | |||
| def test_tfrecord_no_schema_columns_list(): | |||
| logger.info("test_tfrecord_no_schema_columns_list") | |||
| data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"]) | |||
| row = data.create_dict_iterator().get_next() | |||
| assert row["col_sint16"] == [-32768] | |||
| @@ -219,7 +246,8 @@ def test_case_tf_file_no_schema_columns_list(): | |||
| assert "col_sint32" in str(info.value) | |||
| def test_tf_record_schema_columns_list(): | |||
| def test_tfrecord_schema_columns_list(): | |||
| logger.info("test_tfrecord_schema_columns_list") | |||
| schema = ds.Schema() | |||
| schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) | |||
| schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2]) | |||
| @@ -238,7 +266,8 @@ def test_tf_record_schema_columns_list(): | |||
| assert "col_sint32" in str(info.value) | |||
| def test_case_invalid_files(): | |||
| def test_tfrecord_invalid_files(): | |||
| logger.info("test_tfrecord_invalid_files") | |||
| valid_file = "../data/dataset/testTFTestAllTypes/test.data" | |||
| invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt" | |||
| files = [invalid_file, valid_file, SCHEMA_FILE] | |||
| @@ -266,19 +295,20 @@ def test_case_invalid_files(): | |||
| if __name__ == '__main__': | |||
| test_case_tf_shape() | |||
| test_case_tf_read_all_dataset() | |||
| test_case_num_samples() | |||
| test_case_num_samples2() | |||
| test_case_tf_shape_2() | |||
| test_case_tf_file() | |||
| test_case_tf_file_no_schema() | |||
| test_case_tf_file_pad() | |||
| test_tf_files() | |||
| test_tf_record_schema() | |||
| test_tf_record_shuffle() | |||
| test_tf_record_shard() | |||
| test_tf_shard_equal_rows() | |||
| test_case_tf_file_no_schema_columns_list() | |||
| test_tf_record_schema_columns_list() | |||
| test_case_invalid_files() | |||
| test_tfrecord_shape() | |||
| test_tfrecord_read_all_dataset() | |||
| test_tfrecord_num_samples() | |||
| test_tfrecord_num_samples2() | |||
| test_tfrecord_shape2() | |||
| test_tfrecord_files_basic() | |||
| test_tfrecord_no_schema() | |||
| test_tfrecord_pad() | |||
| test_tfrecord_read_files() | |||
| test_tfrecord_multi_files() | |||
| test_tfrecord_schema() | |||
| test_tfrecord_shuffle() | |||
| test_tfrecord_shard() | |||
| test_tfrecord_shard_equal_rows() | |||
| test_tfrecord_no_schema_columns_list() | |||
| test_tfrecord_schema_columns_list() | |||
| test_tfrecord_invalid_files() | |||