|
|
|
@@ -398,6 +398,7 @@ def check_tfrecorddataset(method): |
|
|
|
|
|
|
|
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] |
|
|
|
nreq_param_list = ['columns_list'] |
|
|
|
nreq_param_bool = ['shard_equal_rows'] |
|
|
|
|
|
|
|
# check dataset_files; required argument |
|
|
|
dataset_files = param_dict.get('dataset_files') |
|
|
|
@@ -410,6 +411,10 @@ def check_tfrecorddataset(method): |
|
|
|
|
|
|
|
check_param_type(nreq_param_list, param_dict, list) |
|
|
|
|
|
|
|
check_param_type(nreq_param_bool, param_dict, bool) |
|
|
|
|
|
|
|
check_sampler_shuffle_shard_options(param_dict) |
|
|
|
|
|
|
|
return method(*args, **kwargs) |
|
|
|
|
|
|
|
return new_method |
|
|
|
|