Browse Source

!2737 makes 0 an invaild bucket size

Merge pull request !2737 from Peilin/zero-input-bucket-check
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
d08a89ab87
2 changed files with 8 additions and 3 deletions
  1. +2
    -2
      mindspore/dataset/engine/validators.py
  2. +6
    -1
      tests/ut/python/dataset/test_bucket_batch_by_length.py

+ 2
- 2
mindspore/dataset/engine/validators.py View File

@@ -643,9 +643,9 @@ def check_bucket_batch_by_length(method):
if not all_int: if not all_int:
raise TypeError("bucket_batch_sizes should be a list of int.") raise TypeError("bucket_batch_sizes should be a list of int.")


all_non_negative = all(item >= 0 for item in bucket_batch_sizes)
all_non_negative = all(item > 0 for item in bucket_batch_sizes)
if not all_non_negative: if not all_non_negative:
raise ValueError("bucket_batch_sizes cannot contain any negative numbers.")
raise ValueError("bucket_batch_sizes should be a list of positive numbers.")


if param_dict.get('pad_info') is not None: if param_dict.get('pad_info') is not None:
check_type(param_dict["pad_info"], "pad_info", dict) check_type(param_dict["pad_info"], "pad_info", dict)


+ 6
- 1
tests/ut/python/dataset/test_bucket_batch_by_length.py View File

@@ -51,6 +51,7 @@ def test_bucket_batch_invalid_input():
bucket_batch_sizes = [1, 1, 1, 1] bucket_batch_sizes = [1, 1, 1, 1]
invalid_bucket_batch_sizes = ["1", "2", "3", "4"] invalid_bucket_batch_sizes = ["1", "2", "3", "4"]
negative_bucket_batch_sizes = [1, 2, 3, -4] negative_bucket_batch_sizes = [1, 2, 3, -4]
zero_bucket_batch_sizes = [0, 1, 2, 3]


with pytest.raises(TypeError) as info: with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
@@ -82,7 +83,11 @@ def test_bucket_batch_invalid_input():


with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes) _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes)
assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value)
assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value)

with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, zero_bucket_batch_sizes)
assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value)


with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries) _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)


Loading…
Cancel
Save