|
|
|
@@ -51,6 +51,7 @@ def test_bucket_batch_invalid_input(): |
|
|
|
bucket_batch_sizes = [1, 1, 1, 1] |
|
|
|
invalid_bucket_batch_sizes = ["1", "2", "3", "4"] |
|
|
|
negative_bucket_batch_sizes = [1, 2, 3, -4] |
|
|
|
zero_bucket_batch_sizes = [0, 1, 2, 3] |
|
|
|
|
|
|
|
with pytest.raises(TypeError) as info: |
|
|
|
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) |
|
|
|
@@ -82,7 +83,11 @@ def test_bucket_batch_invalid_input(): |
|
|
|
|
|
|
|
with pytest.raises(ValueError) as info: |
|
|
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes) |
|
|
|
assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value) |
|
|
|
assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value) |
|
|
|
|
|
|
|
with pytest.raises(ValueError) as info: |
|
|
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, zero_bucket_batch_sizes) |
|
|
|
assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value) |
|
|
|
|
|
|
|
with pytest.raises(ValueError) as info: |
|
|
|
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries) |
|
|
|
|