|
|
|
@@ -84,12 +84,11 @@ def test_pipeline(): |
|
|
|
num_parallel_workers_original = ds.config.get_num_parallel_workers() |
|
|
|
|
|
|
|
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) |
|
|
|
ds.config.set_num_parallel_workers(2) |
|
|
|
data1 = data1.map(input_columns=["image"], operations=[c_vision.Decode(True)]) |
|
|
|
ds.serialize(data1, "testpipeline.json") |
|
|
|
|
|
|
|
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) |
|
|
|
ds.config.set_num_parallel_workers(4) |
|
|
|
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=num_parallel_workers_original, |
|
|
|
shuffle=False) |
|
|
|
data2 = data2.map(input_columns=["image"], operations=[c_vision.Decode(True)]) |
|
|
|
ds.serialize(data2, "testpipeline2.json") |
|
|
|
|
|
|
|
|