|
|
|
@@ -70,16 +70,16 @@ def test_numpyslices_sampler_chain(): |
|
|
|
# Use 1 statement to add child sampler |
|
|
|
np_data = [1, 2, 3, 4] |
|
|
|
sampler = ds.SequentialSampler(start_index=1, num_samples=2) |
|
|
|
sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) |
|
|
|
sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) |
|
|
|
data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) |
|
|
|
|
|
|
|
# Verify dataset size |
|
|
|
data1_size = data1.get_dataset_size() |
|
|
|
logger.info("dataset size is: {}".format(data1_size)) |
|
|
|
assert data1_size == 4 |
|
|
|
assert data1_size == 1 |
|
|
|
|
|
|
|
# Verify number of rows |
|
|
|
assert sum([1 for _ in data1]) == 4 |
|
|
|
assert sum([1 for _ in data1]) == 1 |
|
|
|
|
|
|
|
# Verify dataset contents |
|
|
|
res = [] |
|
|
|
|