Browse Source

!3067 Cleanup dataset UT: Remove unneeded tf data files and tests

Merge pull request !3067 from cathwong/ckw_dataset_ut_cleanup6
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
ba0143402c
19 changed files with 148 additions and 329 deletions
  1. +1
    -1
      tests/ut/cpp/dataset/rename_op_test.cc
  2. +2
    -2
      tests/ut/cpp/dataset/zip_op_test.cc
  3. BIN
      tests/ut/data/dataset/golden/repeat_list_result.npz
  4. BIN
      tests/ut/data/dataset/golden/repeat_result.npz
  5. BIN
      tests/ut/data/dataset/golden/tf_file_no_schema.npz
  6. BIN
      tests/ut/data/dataset/golden/tf_file_padBytes10.npz
  7. BIN
      tests/ut/data/dataset/golden/tfreader_result.npz
  8. BIN
      tests/ut/data/dataset/golden/tfrecord_files_basic.npz
  9. BIN
      tests/ut/data/dataset/golden/tfrecord_no_schema.npz
  10. BIN
      tests/ut/data/dataset/golden/tfrecord_pad_bytes10.npz
  11. +0
    -11
      tests/ut/data/dataset/test_tf_file_3_images_1/datasetSchema.json
  12. BIN
      tests/ut/data/dataset/test_tf_file_3_images_1/train-0000-of-0001.data
  13. +0
    -11
      tests/ut/data/dataset/test_tf_file_3_images_2/datasetSchema.json
  14. BIN
      tests/ut/data/dataset/test_tf_file_3_images_2/train-0000-of-0001.data
  15. +0
    -204
      tests/ut/python/dataset/test_datasets_imagenet.py
  16. +0
    -40
      tests/ut/python/dataset/test_datasets_imagenet_distribution.py
  17. +51
    -4
      tests/ut/python/dataset/test_onehot_op.py
  18. +19
    -11
      tests/ut/python/dataset/test_repeat.py
  19. +75
    -45
      tests/ut/python/dataset/test_tfreader_op.py

+ 1
- 1
tests/ut/cpp/dataset/rename_op_test.cc View File

@@ -51,7 +51,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) {
auto my_tree = std::make_shared<ExecutionTree>();
// Creating TFReaderOp

std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
.SetDatasetFilesList({dataset_path})


+ 2
- 2
tests/ut/cpp/dataset/zip_op_test.cc View File

@@ -58,7 +58,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
auto my_tree = std::make_shared<ExecutionTree>();
// Creating TFReaderOp

std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()
@@ -142,7 +142,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
MS_LOG(INFO) << "UT test TestZipRepeat.";
auto my_tree = std::make_shared<ExecutionTree>();

std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data";
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
rc = TFReaderOp::Builder()


BIN
tests/ut/data/dataset/golden/repeat_list_result.npz View File


BIN
tests/ut/data/dataset/golden/repeat_result.npz View File


BIN
tests/ut/data/dataset/golden/tf_file_no_schema.npz View File


BIN
tests/ut/data/dataset/golden/tf_file_padBytes10.npz View File


BIN
tests/ut/data/dataset/golden/tfreader_result.npz View File


BIN
tests/ut/data/dataset/golden/tfrecord_files_basic.npz View File


BIN
tests/ut/data/dataset/golden/tfrecord_no_schema.npz View File


BIN
tests/ut/data/dataset/golden/tfrecord_pad_bytes10.npz View File


+ 0
- 11
tests/ut/data/dataset/test_tf_file_3_images_1/datasetSchema.json View File

@@ -1,11 +0,0 @@
{
"datasetType": "TF",
"numRows": 3,
"columns": {
"label": {
"type": "int64",
"rank": 1,
"t_impl": "flex"
}
}
}

BIN
tests/ut/data/dataset/test_tf_file_3_images_1/train-0000-of-0001.data View File


+ 0
- 11
tests/ut/data/dataset/test_tf_file_3_images_2/datasetSchema.json View File

@@ -1,11 +0,0 @@
{
"datasetType": "TF",
"numRows": 3,
"columns": {
"image": {
"type": "uint8",
"rank": 1,
"t_impl": "cvmat"
}
}
}

BIN
tests/ut/data/dataset/test_tf_file_3_images_2/train-0000-of-0001.data View File


+ 0
- 204
tests/ut/python/dataset/test_datasets_imagenet.py View File

@@ -1,204 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_case_repeat():
"""
a simple repeat operation.
"""
logger.info("Test Simple Repeat")
# define parameters
repeat_count = 2

# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.repeat(repeat_count)

num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
num_iter += 1

logger.info("Number of data in data1: {}".format(num_iter))


def test_case_shuffle():
"""
a simple shuffle operation.
"""
logger.info("Test Simple Shuffle")
# define parameters
buffer_size = 8
seed = 10

# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)

for item in data1.create_dict_iterator():
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))


def test_case_0():
"""
Test Repeat then Shuffle
"""
logger.info("Test Repeat then Shuffle")
# define parameters
repeat_count = 2
buffer_size = 7
seed = 9

# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.repeat(repeat_count)
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)

num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
num_iter += 1

logger.info("Number of data in data1: {}".format(num_iter))


def test_case_0_reverse():
"""
Test Shuffle then Repeat
"""
logger.info("Test Shuffle then Repeat")
# define parameters
repeat_count = 2
buffer_size = 10
seed = 9

# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)
data1 = data1.repeat(repeat_count)

num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
num_iter += 1

logger.info("Number of data in data1: {}".format(num_iter))


def test_case_3():
"""
Test Map
"""
logger.info("Test Map Rescale and Resize, then Shuffle")
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
# define data augmentation parameters
rescale = 1.0 / 255.0
shift = 0.0
resize_height, resize_width = 224, 224

# define map operations
decode_op = vision.Decode()
rescale_op = vision.Rescale(rescale, shift)
# resize_op = vision.Resize(resize_height, resize_width,
# InterpolationMode.DE_INTER_LINEAR) # Bilinear mode
resize_op = vision.Resize((resize_height, resize_width))

# apply map operations on images
data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=rescale_op)
data1 = data1.map(input_columns=["image"], operations=resize_op)

# # apply ont-hot encoding on labels
num_classes = 4
one_hot_encode = data_trans.OneHot(num_classes) # num_classes is input argument
data1 = data1.map(input_columns=["label"], operations=one_hot_encode)
#
# # apply Datasets
buffer_size = 100
seed = 10
batch_size = 2
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size) # 10000 as in imageNet train script
data1 = data1.batch(batch_size, drop_remainder=True)

num_iter = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
num_iter += 1

logger.info("Number of data in data1: {}".format(num_iter))


if __name__ == '__main__':
logger.info('===========now test Repeat============')
# logger.info('Simple Repeat')
test_case_repeat()
logger.info('\n')

logger.info('===========now test Shuffle===========')
# logger.info('Simple Shuffle')
test_case_shuffle()
logger.info('\n')

# Note: cannot work with different shapes, hence not for image
# logger.info('===========now test Batch=============')
# # logger.info('Simple Batch')
# test_case_batch()
# logger.info('\n')

logger.info('===========now test case 0============')
# logger.info('Repeat then Shuffle')
test_case_0()
logger.info('\n')

logger.info('===========now test case 0 reverse============')
# # logger.info('Shuffle then Repeat')
test_case_0_reverse()
logger.info('\n')

# logger.info('===========now test case 1============')
# # logger.info('Repeat with Batch')
# test_case_1()
# logger.info('\n')

# logger.info('===========now test case 2============')
# # logger.info('Batch with Shuffle')
# test_case_2()
# logger.info('\n')

# for image augmentation only
logger.info('===========now test case 3============')
logger.info('Map then Shuffle')
test_case_3()
logger.info('\n')

+ 0
- 40
tests/ut/python/dataset/test_datasets_imagenet_distribution.py View File

@@ -1,40 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]

SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"


def test_tf_file_normal():
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
data1 = data1.repeat(1)
num_iter = 0
for _ in data1.create_dict_iterator(): # each data is a dictionary
num_iter += 1

logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 12


if __name__ == '__main__':
logger.info('=======test normal=======')
test_tf_file_normal()

+ 51
- 4
tests/ut/python/dataset/test_onehot_op.py View File

@@ -13,12 +13,13 @@
# limitations under the License.
# ==============================================================================
"""
Testing the one_hot op in DE
Testing the OneHot Op
"""
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
from util import diff_mse

@@ -37,15 +38,15 @@ def one_hot(index, depth):

def test_one_hot():
"""
Test one_hot
Test OneHot Tensor Operator
"""
logger.info("Test one_hot")
logger.info("test_one_hot")

depth = 10

# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
one_hot_op = data_trans.OneHot(depth)
one_hot_op = data_trans.OneHot(num_classes=depth)
data1 = data1.map(input_columns=["label"], operations=one_hot_op, columns_order=["label"])

# Second dataset
@@ -58,8 +59,54 @@ def test_one_hot():
label2 = one_hot(item2["label"][0], depth)
mse = diff_mse(label1, label2)
logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse))
assert mse == 0
num_iter += 1
assert num_iter == 3

def test_one_hot_post_aug():
"""
Test One Hot Encoding after Multiple Data Augmentation Operators
"""
logger.info("test_one_hot_post_aug")
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

# Define data augmentation parameters
rescale = 1.0 / 255.0
shift = 0.0
resize_height, resize_width = 224, 224

# Define map operations
decode_op = c_vision.Decode()
rescale_op = c_vision.Rescale(rescale, shift)
resize_op = c_vision.Resize((resize_height, resize_width))

# Apply map operations on images
data1 = data1.map(input_columns=["image"], operations=decode_op)
data1 = data1.map(input_columns=["image"], operations=rescale_op)
data1 = data1.map(input_columns=["image"], operations=resize_op)

# Apply one-hot encoding on labels
depth = 4
one_hot_encode = data_trans.OneHot(depth)
data1 = data1.map(input_columns=["label"], operations=one_hot_encode)

# Apply datasets ops
buffer_size = 100
seed = 10
batch_size = 2
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)
data1 = data1.batch(batch_size, drop_remainder=True)

num_iter = 0
for item in data1.create_dict_iterator():
logger.info("image is: {}".format(item["image"]))
logger.info("label is: {}".format(item["label"]))
num_iter += 1

assert num_iter == 1


if __name__ == "__main__":
test_one_hot()
test_one_hot_post_aug()

+ 19
- 11
tests/ut/python/dataset/test_repeat.py View File

@@ -12,25 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Repeat Op
"""
import numpy as np
from util import save_and_check

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from util import save_and_check_dict

DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"]
GENERATE_GOLDEN = False

IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

GENERATE_GOLDEN = False


def test_tf_repeat_01():
"""
@@ -39,14 +38,13 @@ def test_tf_repeat_01():
logger.info("Test Simple Repeat")
# define parameters
repeat_count = 2
parameters = {"params": {'repeat_count': repeat_count}}

# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
data1 = data1.repeat(repeat_count)

filename = "repeat_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_tf_repeat_02():
@@ -99,14 +97,13 @@ def test_tf_repeat_04():
logger.info("Test Simple Repeat Column List")
# define parameters
repeat_count = 2
parameters = {"params": {'repeat_count': repeat_count}}
columns_list = ["col_sint64", "col_sint32"]
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False)
data1 = data1.repeat(repeat_count)

filename = "repeat_list_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def generator():
@@ -115,6 +112,7 @@ def generator():


def test_nested_repeat1():
logger.info("test_nested_repeat1")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(3)
@@ -126,6 +124,7 @@ def test_nested_repeat1():


def test_nested_repeat2():
logger.info("test_nested_repeat2")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(1)
data = data.repeat(1)
@@ -137,6 +136,7 @@ def test_nested_repeat2():


def test_nested_repeat3():
logger.info("test_nested_repeat3")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(1)
data = data.repeat(2)
@@ -148,6 +148,7 @@ def test_nested_repeat3():


def test_nested_repeat4():
logger.info("test_nested_repeat4")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(1)
@@ -159,6 +160,7 @@ def test_nested_repeat4():


def test_nested_repeat5():
logger.info("test_nested_repeat5")
data = ds.GeneratorDataset(generator, ["data"])
data = data.batch(3)
data = data.repeat(2)
@@ -171,6 +173,7 @@ def test_nested_repeat5():


def test_nested_repeat6():
logger.info("test_nested_repeat6")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.batch(3)
@@ -183,6 +186,7 @@ def test_nested_repeat6():


def test_nested_repeat7():
logger.info("test_nested_repeat7")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(3)
@@ -195,6 +199,7 @@ def test_nested_repeat7():


def test_nested_repeat8():
logger.info("test_nested_repeat8")
data = ds.GeneratorDataset(generator, ["data"])
data = data.batch(2, drop_remainder=False)
data = data.repeat(2)
@@ -210,6 +215,7 @@ def test_nested_repeat8():


def test_nested_repeat9():
logger.info("test_nested_repeat9")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat()
data = data.repeat(3)
@@ -221,6 +227,7 @@ def test_nested_repeat9():


def test_nested_repeat10():
logger.info("test_nested_repeat10")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(3)
data = data.repeat()
@@ -232,6 +239,7 @@ def test_nested_repeat10():


def test_nested_repeat11():
logger.info("test_nested_repeat11")
data = ds.GeneratorDataset(generator, ["data"])
data = data.repeat(2)
data = data.repeat(3)


+ 75
- 45
tests/ut/python/dataset/test_tfreader_op.py View File

@@ -12,21 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test TFRecordDataset Ops
"""
import numpy as np
import pytest
from util import save_and_check

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check_dict

FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
"../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
GENERATE_GOLDEN = False


def test_case_tf_shape():
def test_tfrecord_shape():
logger.info("test_tfrecord_shape")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
ds1 = ds.TFRecordDataset(FILES, schema_file)
ds1 = ds1.batch(2)
@@ -36,7 +45,8 @@ def test_case_tf_shape():
assert len(output_shape[-1]) == 1


def test_case_tf_read_all_dataset():
def test_tfrecord_read_all_dataset():
logger.info("test_tfrecord_read_all_dataset")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
ds1 = ds.TFRecordDataset(FILES, schema_file)
assert ds1.get_dataset_size() == 12
@@ -46,7 +56,8 @@ def test_case_tf_read_all_dataset():
assert count == 12


def test_case_num_samples():
def test_tfrecord_num_samples():
logger.info("test_tfrecord_num_samples")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
assert ds1.get_dataset_size() == 8
@@ -56,7 +67,8 @@ def test_case_num_samples():
assert count == 8


def test_case_num_samples2():
def test_tfrecord_num_samples2():
logger.info("test_tfrecord_num_samples2")
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
ds1 = ds.TFRecordDataset(FILES, schema_file)
assert ds1.get_dataset_size() == 7
@@ -66,42 +78,41 @@ def test_case_num_samples2():
assert count == 7


def test_case_tf_shape_2():
def test_tfrecord_shape2():
logger.info("test_tfrecord_shape2")
ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
ds1 = ds1.batch(2)
output_shape = ds1.output_shapes()
assert len(output_shape[-1]) == 2


def test_case_tf_file():
logger.info("reading data from: {}".format(FILES[0]))
parameters = {"params": {}}
def test_tfrecord_files_basic():
logger.info("test_tfrecord_files_basic")

data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
filename = "tfreader_result.npz"
save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
filename = "tfrecord_files_basic.npz"
save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)


def test_case_tf_file_no_schema():
logger.info("reading data from: {}".format(FILES[0]))
parameters = {"params": {}}
def test_tfrecord_no_schema():
logger.info("test_tfrecord_no_schema")

data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
filename = "tf_file_no_schema.npz"
save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
filename = "tfrecord_no_schema.npz"
save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)


def test_case_tf_file_pad():
logger.info("reading data from: {}".format(FILES[0]))
parameters = {"params": {}}
def test_tfrecord_pad():
logger.info("test_tfrecord_pad")

schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES)
filename = "tf_file_padBytes10.npz"
save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)
filename = "tfrecord_pad_bytes10.npz"
save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)


def test_tf_files():
def test_tfrecord_read_files():
logger.info("test_tfrecord_read_files")
pattern = DATASET_ROOT + "/test.data"
data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
assert sum([1 for _ in data]) == 12
@@ -123,7 +134,19 @@ def test_tf_files():
assert sum([1 for _ in data]) == 24


def test_tf_record_schema():
def test_tfrecord_multi_files():
logger.info("test_tfrecord_multi_files")
data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False)
data1 = data1.repeat(1)
num_iter = 0
for _ in data1.create_dict_iterator():
num_iter += 1

assert num_iter == 12


def test_tfrecord_schema():
logger.info("test_tfrecord_schema")
schema = ds.Schema()
schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -142,7 +165,8 @@ def test_tf_record_schema():
assert np.array_equal(t1, t2)


def test_tf_record_shuffle():
def test_tfrecord_shuffle():
logger.info("test_tfrecord_shuffle")
ds.config.set_seed(1)
data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
@@ -153,7 +177,8 @@ def test_tf_record_shuffle():
assert np.array_equal(t1, t2)


def test_tf_record_shard():
def test_tfrecord_shard():
logger.info("test_tfrecord_shard")
tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
"../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

@@ -181,7 +206,8 @@ def test_tf_record_shard():
assert set(worker2_res) == set(worker1_res)


def test_tf_shard_equal_rows():
def test_tfrecord_shard_equal_rows():
logger.info("test_tfrecord_shard_equal_rows")
tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
"../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

@@ -209,7 +235,8 @@ def test_tf_shard_equal_rows():
assert len(worker4_res) == 40


def test_case_tf_file_no_schema_columns_list():
def test_tfrecord_no_schema_columns_list():
logger.info("test_tfrecord_no_schema_columns_list")
data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
row = data.create_dict_iterator().get_next()
assert row["col_sint16"] == [-32768]
@@ -219,7 +246,8 @@ def test_case_tf_file_no_schema_columns_list():
assert "col_sint32" in str(info.value)


def test_tf_record_schema_columns_list():
def test_tfrecord_schema_columns_list():
logger.info("test_tfrecord_schema_columns_list")
schema = ds.Schema()
schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
@@ -238,7 +266,8 @@ def test_tf_record_schema_columns_list():
assert "col_sint32" in str(info.value)


def test_case_invalid_files():
def test_tfrecord_invalid_files():
logger.info("test_tfrecord_invalid_files")
valid_file = "../data/dataset/testTFTestAllTypes/test.data"
invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
files = [invalid_file, valid_file, SCHEMA_FILE]
@@ -266,19 +295,20 @@ def test_case_invalid_files():


if __name__ == '__main__':
test_case_tf_shape()
test_case_tf_read_all_dataset()
test_case_num_samples()
test_case_num_samples2()
test_case_tf_shape_2()
test_case_tf_file()
test_case_tf_file_no_schema()
test_case_tf_file_pad()
test_tf_files()
test_tf_record_schema()
test_tf_record_shuffle()
test_tf_record_shard()
test_tf_shard_equal_rows()
test_case_tf_file_no_schema_columns_list()
test_tf_record_schema_columns_list()
test_case_invalid_files()
test_tfrecord_shape()
test_tfrecord_read_all_dataset()
test_tfrecord_num_samples()
test_tfrecord_num_samples2()
test_tfrecord_shape2()
test_tfrecord_files_basic()
test_tfrecord_no_schema()
test_tfrecord_pad()
test_tfrecord_read_files()
test_tfrecord_multi_files()
test_tfrecord_schema()
test_tfrecord_shuffle()
test_tfrecord_shard()
test_tfrecord_shard_equal_rows()
test_tfrecord_no_schema_columns_list()
test_tfrecord_schema_columns_list()
test_tfrecord_invalid_files()

Loading…
Cancel
Save