|
|
|
@@ -53,6 +53,9 @@ def test_bucket_batch_invalid_input(): |
|
|
|
negative_bucket_batch_sizes = [1, 2, 3, -4] |
|
|
|
zero_bucket_batch_sizes = [0, 1, 2, 3] |
|
|
|
|
|
|
|
invalid_type_pad_to_bucket_boundary = "" |
|
|
|
invalid_type_drop_remainder = "" |
|
|
|
|
|
|
|
with pytest.raises(TypeError) as info: |
|
|
|
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) |
|
|
|
assert "column_names should be a list of str" in str(info.value) |
|
|
|
@@ -93,6 +96,16 @@ def test_bucket_batch_invalid_input(): |
|
|
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries) |
|
|
|
assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value) |
|
|
|
|
|
|
|
with pytest.raises(TypeError) as info: |
|
|
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, |
|
|
|
None, None, invalid_type_pad_to_bucket_boundary) |
|
|
|
assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value) |
|
|
|
|
|
|
|
with pytest.raises(TypeError) as info: |
|
|
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, |
|
|
|
None, None, False, invalid_type_drop_remainder) |
|
|
|
assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value) |
|
|
|
|
|
|
|
|
|
|
|
def test_bucket_batch_multi_bucket_no_padding(): |
|
|
|
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) |
|
|
|
|