|
|
|
@@ -390,7 +390,6 @@ def filter_func_Partial_0(col1, col2, col3, col4): |
|
|
|
|
|
|
|
# test with row_data_buffer > 1 |
|
|
|
def test_filter_by_generator_Partial0(): |
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg') |
|
|
|
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"]) |
|
|
|
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"]) |
|
|
|
dataset_zip = ds.zip((dataset1, dataset2)) |
|
|
|
@@ -404,7 +403,6 @@ def test_filter_by_generator_Partial0(): |
|
|
|
|
|
|
|
# test with row_data_buffer > 1 |
|
|
|
def test_filter_by_generator_Partial1(): |
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg') |
|
|
|
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"]) |
|
|
|
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"]) |
|
|
|
dataset_zip = ds.zip((dataset1, dataset2)) |
|
|
|
@@ -419,7 +417,6 @@ def test_filter_by_generator_Partial1(): |
|
|
|
|
|
|
|
# test with row_data_buffer > 1 |
|
|
|
def test_filter_by_generator_Partial2(): |
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg') |
|
|
|
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"]) |
|
|
|
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"]) |
|
|
|
|
|
|
|
@@ -454,7 +451,6 @@ def generator_big(maxid=20): |
|
|
|
|
|
|
|
# test with row_data_buffer > 1 |
|
|
|
def test_filter_by_generator_Partial(): |
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg') |
|
|
|
dataset = ds.GeneratorDataset(source=generator_mc(99), column_names=["col1", "col2"]) |
|
|
|
dataset_s = dataset.shuffle(4) |
|
|
|
dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1) |
|
|
|
@@ -473,7 +469,6 @@ def filter_func_cifar(col1, col2): |
|
|
|
# test with cifar10 |
|
|
|
def test_filte_case_dataset_cifar10(): |
|
|
|
DATA_DIR_10 = "../data/dataset/testCifar10Data" |
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg') |
|
|
|
dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False) |
|
|
|
dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1) |
|
|
|
for item in dataset_f1.create_dict_iterator(): |
|
|
|
|