| @@ -398,6 +398,7 @@ def check_tfrecorddataset(method): | |||||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | ||||
| nreq_param_list = ['columns_list'] | nreq_param_list = ['columns_list'] | ||||
| nreq_param_bool = ['shard_equal_rows'] | |||||
| # check dataset_files; required argument | # check dataset_files; required argument | ||||
| dataset_files = param_dict.get('dataset_files') | dataset_files = param_dict.get('dataset_files') | ||||
| @@ -410,6 +411,10 @@ def check_tfrecorddataset(method): | |||||
| check_param_type(nreq_param_list, param_dict, list) | check_param_type(nreq_param_list, param_dict, list) | ||||
| check_param_type(nreq_param_bool, param_dict, bool) | |||||
| check_sampler_shuffle_shard_options(param_dict) | |||||
| return method(*args, **kwargs) | return method(*args, **kwargs) | ||||
| return new_method | return new_method | ||||