Browse Source

added check for invalid type for boolean args

tags/v0.6.0-beta
peilin-wang 5 years ago
parent
commit
29aa589972
2 changed files with 20 additions and 0 deletions
  1. +7
    -0
      mindspore/dataset/engine/validators.py
  2. +13
    -0
      tests/ut/python/dataset/test_bucket_batch_by_length.py

+ 7
- 0
mindspore/dataset/engine/validators.py View File

@@ -606,8 +606,15 @@ def check_bucket_batch_by_length(method):
nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
check_param_type(nreq_param_list, param_dict, list)

nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
check_param_type(nbool_param_list, param_dict, bool)

# check column_names: must be list of string.
column_names = param_dict.get("column_names")

if not column_names:
raise ValueError("column_names cannot be empty")

all_string = all(isinstance(item, str) for item in column_names)
if not all_string:
raise TypeError("column_names should be a list of str.")


+ 13
- 0
tests/ut/python/dataset/test_bucket_batch_by_length.py View File

@@ -53,6 +53,9 @@ def test_bucket_batch_invalid_input():
negative_bucket_batch_sizes = [1, 2, 3, -4]
zero_bucket_batch_sizes = [0, 1, 2, 3]

invalid_type_pad_to_bucket_boundary = ""
invalid_type_drop_remainder = ""

with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
assert "column_names should be a list of str" in str(info.value)
@@ -93,6 +96,16 @@ def test_bucket_batch_invalid_input():
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)
assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value)

with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
None, None, invalid_type_pad_to_bucket_boundary)
assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value)

with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
None, None, False, invalid_type_drop_remainder)
assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value)


def test_bucket_batch_multi_bucket_no_padding():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])


Loading…
Cancel
Save