|
|
|
@@ -123,25 +123,39 @@ def check_valid_detype(type_): |
|
|
|
|
|
|
|
|
|
|
|
def check_columns(columns, name): |
|
|
|
""" |
|
|
|
Validate strings in column_names. |
|
|
|
|
|
|
|
Args: |
|
|
|
columns (list): list of column_names. |
|
|
|
name (str): name of columns. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Exception: when the value is not correct, otherwise nothing. |
|
|
|
""" |
|
|
|
type_check(columns, (list, str), name) |
|
|
|
if isinstance(columns, list): |
|
|
|
if not columns: |
|
|
|
raise ValueError("Column names should not be empty") |
|
|
|
col_names = ["col_{0}".format(i) for i in range(len(columns))] |
|
|
|
raise ValueError("{0} should not be empty".format(name)) |
|
|
|
for i, column_name in enumerate(columns): |
|
|
|
if not column_name: |
|
|
|
raise ValueError("{0}[{1}] should not be empty".format(name, i)) |
|
|
|
|
|
|
|
col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))] |
|
|
|
type_check_list(columns, (str,), col_names) |
|
|
|
|
|
|
|
|
|
|
|
def parse_user_args(method, *args, **kwargs): |
|
|
|
""" |
|
|
|
Parse user arguments in a function |
|
|
|
Parse user arguments in a function. |
|
|
|
|
|
|
|
Args: |
|
|
|
method (method): a callable function |
|
|
|
*args: user passed args |
|
|
|
**kwargs: user passed kwargs |
|
|
|
method (method): a callable function. |
|
|
|
*args: user passed args. |
|
|
|
**kwargs: user passed kwargs. |
|
|
|
|
|
|
|
Returns: |
|
|
|
user_filled_args (list): values of what the user passed in for the arguments, |
|
|
|
user_filled_args (list): values of what the user passed in for the arguments. |
|
|
|
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed. |
|
|
|
""" |
|
|
|
sig = inspect.signature(method) |
|
|
|
@@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs): |
|
|
|
|
|
|
|
def type_check_list(args, types, arg_names): |
|
|
|
""" |
|
|
|
Check the type of each parameter in the list |
|
|
|
Check the type of each parameter in the list. |
|
|
|
|
|
|
|
Args: |
|
|
|
args (list, tuple): a list or tuple of any variable |
|
|
|
types (tuple): tuple of all valid types for arg |
|
|
|
arg_names (list, tuple of str): the names of args |
|
|
|
args (list, tuple): a list or tuple of any variable. |
|
|
|
types (tuple): tuple of all valid types for arg. |
|
|
|
arg_names (list, tuple of str): the names of args. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Exception: when the type is not correct, otherwise nothing |
|
|
|
Exception: when the type is not correct, otherwise nothing. |
|
|
|
""" |
|
|
|
type_check(args, (list, tuple,), arg_names) |
|
|
|
if len(args) != len(arg_names): |
|
|
|
@@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names): |
|
|
|
|
|
|
|
def type_check(arg, types, arg_name): |
|
|
|
""" |
|
|
|
Check the type of the parameter |
|
|
|
Check the type of the parameter. |
|
|
|
|
|
|
|
Args: |
|
|
|
arg : any variable |
|
|
|
types (tuple): tuple of all valid types for arg |
|
|
|
arg_name (str): the name of arg |
|
|
|
arg : any variable. |
|
|
|
types (tuple): tuple of all valid types for arg. |
|
|
|
arg_name (str): the name of arg. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Exception: when the type is not correct, otherwise nothing |
|
|
|
Exception: when the type is not correct, otherwise nothing. |
|
|
|
""" |
|
|
|
# handle special case of booleans being a subclass of ints |
|
|
|
print_value = '\"\"' if repr(arg) == repr('') else arg |
|
|
|
@@ -201,13 +215,13 @@ def type_check(arg, types, arg_name): |
|
|
|
|
|
|
|
def check_filename(path): |
|
|
|
""" |
|
|
|
check the filename in the path |
|
|
|
check the filename in the path. |
|
|
|
|
|
|
|
Args: |
|
|
|
path (str): the path |
|
|
|
path (str): the path. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Exception: when error |
|
|
|
Exception: when error. |
|
|
|
""" |
|
|
|
if not isinstance(path, str): |
|
|
|
raise TypeError("path: {} is not string".format(path)) |
|
|
|
@@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict): |
|
|
|
""" |
|
|
|
Check for valid shuffle, sampler, num_shards, and shard_id inputs. |
|
|
|
Args: |
|
|
|
param_dict (dict): param_dict |
|
|
|
param_dict (dict): param_dict. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Exception: ValueError or RuntimeError if error |
|
|
|
Exception: ValueError or RuntimeError if error. |
|
|
|
""" |
|
|
|
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') |
|
|
|
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') |
|
|
|
@@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict): |
|
|
|
|
|
|
|
def check_padding_options(param_dict): |
|
|
|
""" |
|
|
|
Check for valid padded_sample and num_padded of padded samples |
|
|
|
Check for valid padded_sample and num_padded of padded samples. |
|
|
|
|
|
|
|
Args: |
|
|
|
param_dict (dict): param_dict |
|
|
|
param_dict (dict): param_dict. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Exception: ValueError or RuntimeError if error |
|
|
|
Exception: ValueError or RuntimeError if error. |
|
|
|
""" |
|
|
|
|
|
|
|
columns_list = param_dict.get('columns_list') |
|
|
|
@@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name): |
|
|
|
Check if the input parameter is list or numpy.ndarray. |
|
|
|
|
|
|
|
Args: |
|
|
|
param (list, nd.ndarray): param |
|
|
|
param_name (str): param_name |
|
|
|
param (list, nd.ndarray): param. |
|
|
|
param_name (str): param_name. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Exception: TypeError if error |
|
|
|
Exception: TypeError if error. |
|
|
|
""" |
|
|
|
|
|
|
|
type_check(param, (list, np.ndarray), param_name) |
|
|
|
|