|
|
|
@@ -36,6 +36,11 @@ def generate_2_columns(n): |
|
|
|
yield (np.array([i]), np.array([j for j in range(i + 1)])) |
|
|
|
|
|
|
|
|
|
|
|
def generate_3_columns(n):
    """Yield n rows of three numpy columns.

    Row i is ([i], [i + 1], [0, 1, ..., i]): two fixed-shape scalar-like
    columns and one variable-length column that grows with i.

    Args:
        n (int): number of rows to generate.

    Yields:
        tuple of np.ndarray: (same_shape, same_shape2, variable_shape).
    """
    for i in range(n):
        # np.arange(i + 1) is the idiomatic equivalent of
        # np.array([j for j in range(i + 1)]) — same values, same default
        # integer dtype, no intermediate Python list.
        yield (np.array([i]), np.array([i + 1]), np.arange(i + 1))
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_invalid_input(): |
|
|
|
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) |
|
|
|
|
|
|
|
@@ -382,6 +387,48 @@ def test_bucket_batch_multi_column(): |
|
|
|
assert same_shape_output == same_shape_expected_output |
|
|
|
assert variable_shape_output == variable_shape_expected_output |
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_three_columns():
    """Bucket a 3-column dataset keyed on "same_shape2".

    Length function x[0] % 3 with boundaries [6, 12] places every row in
    the first bucket (batch size 5), so 10 rows become 2 batches.  With an
    empty pad_info the variable-length column is padded with zeros to the
    longest row in each batch, while the two fixed-shape columns are
    batched unchanged.
    """
    dataset = ds.GeneratorDataset((lambda: generate_3_columns(10)), ["same_shape", "same_shape2", "variable_shape"])

    # Inline the bucketing configuration: key column, boundaries [6, 12],
    # per-bucket batch sizes [5, 5, 1], length fn x[0] % 3, empty pad_info.
    dataset = dataset.bucket_batch_by_length(["same_shape2"], [6, 12],
                                             [5, 5, 1], (lambda x: x[0] % 3),
                                             {})

    same_shape_expected_output = [[[0], [1], [2], [3], [4]],
                                  [[5], [6], [7], [8], [9]]]
    same_shape2_expected_output = [[[1], [2], [3], [4], [5]],
                                   [[6], [7], [8], [9], [10]]]
    variable_shape_expected_output = [[[0, 0, 0, 0, 0],
                                       [0, 1, 0, 0, 0],
                                       [0, 1, 2, 0, 0],
                                       [0, 1, 2, 3, 0],
                                       [0, 1, 2, 3, 4]],
                                      [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0],
                                       [0, 1, 2, 3, 4, 5, 6, 0, 0, 0],
                                       [0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
                                       [0, 1, 2, 3, 4, 5, 6, 7, 8, 0],
                                       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]

    # Collect all three columns in one pass using a keyed accumulator.
    collected = {"same_shape": [], "same_shape2": [], "variable_shape": []}
    for row in dataset.create_dict_iterator(num_epochs=1):
        for column in collected:
            collected[column].append(row[column].tolist())

    assert collected["same_shape"] == same_shape_expected_output
    assert collected["same_shape2"] == same_shape2_expected_output
    assert collected["variable_shape"] == variable_shape_expected_output
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_get_dataset_size(): |
|
|
|
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) |
|
|
|
|
|
|
|
@@ -402,6 +449,25 @@ def test_bucket_batch_get_dataset_size(): |
|
|
|
assert data_size == num_rows |
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_invalid_column():
    """Bucketing on a column that does not exist must raise RuntimeError.

    The missing column is only detected when the pipeline executes, not
    when bucket_batch_by_length is attached, so the iterator must be
    drained inside pytest.raises to trigger the error.
    """
    dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])

    column_names = ["invalid_column"]
    bucket_boundaries = [1, 2, 3]
    bucket_batch_sizes = [3, 3, 2, 2]
    element_length_function = (lambda x: x[0] % 4)

    dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
                                             bucket_batch_sizes, element_length_function)

    with pytest.raises(RuntimeError) as info:
        # Iterate purely for the side effect of executing the pipeline;
        # the row counter previously kept here was dead code (never
        # asserted or read) and has been removed.
        for _ in dataset.create_dict_iterator(num_epochs=1):
            pass

    assert "BucketBatchByLength: Couldn't find the specified column in the dataset" in str(info.value)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_bucket_batch_invalid_input() |
|
|
|
test_bucket_batch_multi_bucket_no_padding() |
|
|
|
@@ -413,4 +479,6 @@ if __name__ == '__main__': |
|
|
|
test_bucket_batch_drop_remainder() |
|
|
|
test_bucket_batch_default_length_function() |
|
|
|
test_bucket_batch_multi_column() |
|
|
|
test_bucket_batch_three_columns() |
|
|
|
test_bucket_batch_get_dataset_size() |
|
|
|
test_bucket_batch_invalid_column() |