|
|
|
@@ -80,6 +80,22 @@ def test_slice_slice_obj_3s(): |
|
|
|
slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3)) |
|
|
|
|
|
|
|
|
|
|
|
def test_slice_multiple_rows(): |
|
|
|
dataset = [[1, 2], [3, 4, 5], [1], [1, 2, 3, 4, 5, 6, 7]] |
|
|
|
|
|
|
|
def gen(): |
|
|
|
for row in dataset: |
|
|
|
yield (np.array(row),) |
|
|
|
|
|
|
|
data = ds.GeneratorDataset(gen, column_names=["col"]) |
|
|
|
indexing = slice(0, 4) |
|
|
|
data = data.map(operations=ops.Slice(indexing)) |
|
|
|
for i, d in enumerate(data): |
|
|
|
array = np.array(dataset[i]) |
|
|
|
array = array[indexing] |
|
|
|
np.testing.assert_array_equal(array, d[0]) |
|
|
|
|
|
|
|
|
|
|
|
def test_slice_slice_obj_3s_double(): |
|
|
|
slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1)) |
|
|
|
slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1)) |
|
|
|
@@ -217,3 +233,4 @@ if __name__ == "__main__": |
|
|
|
test_slice_slice_obj_1s_str() |
|
|
|
test_slice_slice_obj_neg_str() |
|
|
|
test_slice_exceptions_str() |
|
|
|
test_slice_multiple_rows() |