Browse Source

!2980 Prevent empty column names

Merge pull request !2980 from nhussain/empty_column_b
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
bccfa48509
4 changed files with 75 additions and 36 deletions
  1. mindspore/dataset/core/validator_helpers.py (+42, -28)
  2. mindspore/dataset/engine/validators.py (+1, -6)
  3. tests/ut/python/dataset/test_bucket_batch_by_length.py (+1, -1)
  4. tests/ut/python/dataset/test_dataset_numpy_slices.py (+31, -1)

+ 42
- 28
mindspore/dataset/core/validator_helpers.py View File

@@ -123,25 +123,39 @@ def check_valid_detype(type_):




def check_columns(columns, name): def check_columns(columns, name):
"""
Validate strings in column_names.

Args:
columns (list): list of column_names.
name (str): name of columns.

Returns:
Exception: when the value is not correct, otherwise nothing.
"""
type_check(columns, (list, str), name) type_check(columns, (list, str), name)
if isinstance(columns, list): if isinstance(columns, list):
if not columns: if not columns:
raise ValueError("Column names should not be empty")
col_names = ["col_{0}".format(i) for i in range(len(columns))]
raise ValueError("{0} should not be empty".format(name))
for i, column_name in enumerate(columns):
if not column_name:
raise ValueError("{0}[{1}] should not be empty".format(name, i))

col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
type_check_list(columns, (str,), col_names) type_check_list(columns, (str,), col_names)




def parse_user_args(method, *args, **kwargs): def parse_user_args(method, *args, **kwargs):
""" """
Parse user arguments in a function
Parse user arguments in a function.


Args: Args:
method (method): a callable function
*args: user passed args
**kwargs: user passed kwargs
method (method): a callable function.
*args: user passed args.
**kwargs: user passed kwargs.


Returns: Returns:
user_filled_args (list): values of what the user passed in for the arguments,
user_filled_args (list): values of what the user passed in for the arguments.
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed. ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
""" """
sig = inspect.signature(method) sig = inspect.signature(method)
@@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs):


def type_check_list(args, types, arg_names): def type_check_list(args, types, arg_names):
""" """
Check the type of each parameter in the list
Check the type of each parameter in the list.


Args: Args:
args (list, tuple): a list or tuple of any variable
types (tuple): tuple of all valid types for arg
arg_names (list, tuple of str): the names of args
args (list, tuple): a list or tuple of any variable.
types (tuple): tuple of all valid types for arg.
arg_names (list, tuple of str): the names of args.


Returns: Returns:
Exception: when the type is not correct, otherwise nothing
Exception: when the type is not correct, otherwise nothing.
""" """
type_check(args, (list, tuple,), arg_names) type_check(args, (list, tuple,), arg_names)
if len(args) != len(arg_names): if len(args) != len(arg_names):
@@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names):


def type_check(arg, types, arg_name): def type_check(arg, types, arg_name):
""" """
Check the type of the parameter
Check the type of the parameter.


Args: Args:
arg : any variable
types (tuple): tuple of all valid types for arg
arg_name (str): the name of arg
arg : any variable.
types (tuple): tuple of all valid types for arg.
arg_name (str): the name of arg.


Returns: Returns:
Exception: when the type is not correct, otherwise nothing
Exception: when the type is not correct, otherwise nothing.
""" """
# handle special case of booleans being a subclass of ints # handle special case of booleans being a subclass of ints
print_value = '\"\"' if repr(arg) == repr('') else arg print_value = '\"\"' if repr(arg) == repr('') else arg
@@ -201,13 +215,13 @@ def type_check(arg, types, arg_name):


def check_filename(path): def check_filename(path):
""" """
check the filename in the path
check the filename in the path.


Args: Args:
path (str): the path
path (str): the path.


Returns: Returns:
Exception: when error
Exception: when error.
""" """
if not isinstance(path, str): if not isinstance(path, str):
raise TypeError("path: {} is not string".format(path)) raise TypeError("path: {} is not string".format(path))
@@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict):
""" """
Check for valid shuffle, sampler, num_shards, and shard_id inputs. Check for valid shuffle, sampler, num_shards, and shard_id inputs.
Args: Args:
param_dict (dict): param_dict
param_dict (dict): param_dict.


Returns: Returns:
Exception: ValueError or RuntimeError if error
Exception: ValueError or RuntimeError if error.
""" """
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
@@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict):


def check_padding_options(param_dict): def check_padding_options(param_dict):
""" """
Check for valid padded_sample and num_padded of padded samples
Check for valid padded_sample and num_padded of padded samples.


Args: Args:
param_dict (dict): param_dict
param_dict (dict): param_dict.


Returns: Returns:
Exception: ValueError or RuntimeError if error
Exception: ValueError or RuntimeError if error.
""" """


columns_list = param_dict.get('columns_list') columns_list = param_dict.get('columns_list')
@@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name):
Check if the input parameter is list or numpy.ndarray. Check if the input parameter is list or numpy.ndarray.


Args: Args:
param (list, nd.ndarray): param
param_name (str): param_name
param (list, nd.ndarray): param.
param_name (str): param_name.


Returns: Returns:
Exception: TypeError if error
Exception: TypeError if error.
""" """


type_check(param, (list, np.ndarray), param_name) type_check(param, (list, np.ndarray), param_name)


+ 1
- 6
mindspore/dataset/engine/validators.py View File

@@ -380,12 +380,7 @@ def check_bucket_batch_by_length(method):
type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list) type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)


# check column_names: must be list of string. # check column_names: must be list of string.
if not column_names:
raise ValueError("column_names cannot be empty")

all_string = all(isinstance(item, str) for item in column_names)
if not all_string:
raise TypeError("column_names should be a list of str.")
check_columns(column_names, "column_names")


if element_length_function is None and len(column_names) != 1: if element_length_function is None and len(column_names) != 1:
raise ValueError("If element_length_function is not specified, exactly one column name should be passed.") raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")


+ 1
- 1
tests/ut/python/dataset/test_bucket_batch_by_length.py View File

@@ -59,7 +59,7 @@ def test_bucket_batch_invalid_input():


with pytest.raises(TypeError) as info: with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
assert "column_names should be a list of str" in str(info.value)
assert "Argument column_names[0] with value 1 is not of type (<class 'str'>,)." in str(info.value)


with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes) _ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes)


+ 31
- 1
tests/ut/python/dataset/test_dataset_numpy_slices.py View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import numpy as np import numpy as np
import pytest
import mindspore.dataset as de import mindspore.dataset as de
from mindspore import log as logger from mindspore import log as logger
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
@@ -173,7 +174,6 @@ def test_numpy_slices_distributed_sampler():




def test_numpy_slices_sequential_sampler(): def test_numpy_slices_sequential_sampler():

logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.") logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")


np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
@@ -183,6 +183,33 @@ def test_numpy_slices_sequential_sampler():
assert np.equal(data[0], np_data[i % 8]).all() assert np.equal(data[0], np_data[i % 8]).all()




def test_numpy_slices_invalid_column_names_type():
logger.info("Test incorrect column_names input")
np_data = [1, 2, 3]

with pytest.raises(TypeError) as err:
de.NumpySlicesDataset(np_data, column_names=[1], shuffle=False)
assert "Argument column_names[0] with value 1 is not of type (<class 'str'>,)." in str(err.value)


def test_numpy_slices_invalid_column_names_string():
logger.info("Test incorrect column_names input")
np_data = [1, 2, 3]

with pytest.raises(ValueError) as err:
de.NumpySlicesDataset(np_data, column_names=[""], shuffle=False)
assert "column_names[0] should not be empty" in str(err.value)


def test_numpy_slices_invalid_empty_column_names():
logger.info("Test incorrect column_names input")
np_data = [1, 2, 3]

with pytest.raises(ValueError) as err:
de.NumpySlicesDataset(np_data, column_names=[], shuffle=False)
assert "column_names should not be empty" in str(err.value)


if __name__ == "__main__": if __name__ == "__main__":
test_numpy_slices_list_1() test_numpy_slices_list_1()
test_numpy_slices_list_2() test_numpy_slices_list_2()
@@ -197,3 +224,6 @@ if __name__ == "__main__":
test_numpy_slices_num_samplers() test_numpy_slices_num_samplers()
test_numpy_slices_distributed_sampler() test_numpy_slices_distributed_sampler()
test_numpy_slices_sequential_sampler() test_numpy_slices_sequential_sampler()
test_numpy_slices_invalid_column_names_type()
test_numpy_slices_invalid_column_names_string()
test_numpy_slices_invalid_empty_column_names()

Loading…
Cancel
Save