|
|
|
@@ -31,7 +31,7 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis |
|
|
|
|
|
|
|
from . import datasets |
|
|
|
from . import samplers |
|
|
|
from . import cache_client |
|
|
|
# from . import cache_client |
|
|
|
from .. import callback |
|
|
|
|
|
|
|
|
|
|
|
@@ -56,6 +56,9 @@ def check_imagefolderdataset(method): |
|
|
|
validate_dataset_param_value(nreq_param_dict, param_dict, dict) |
|
|
|
check_sampler_shuffle_shard_options(param_dict) |
|
|
|
|
|
|
|
cache = param_dict.get('cache') |
|
|
|
check_cache_option(cache) |
|
|
|
|
|
|
|
return method(self, *args, **kwargs) |
|
|
|
|
|
|
|
return new_method |
|
|
|
@@ -136,6 +139,9 @@ def check_tfrecorddataset(method): |
|
|
|
|
|
|
|
check_sampler_shuffle_shard_options(param_dict) |
|
|
|
|
|
|
|
cache = param_dict.get('cache') |
|
|
|
check_cache_option(cache) |
|
|
|
|
|
|
|
return method(self, *args, **kwargs) |
|
|
|
|
|
|
|
return new_method |
|
|
|
@@ -389,6 +395,9 @@ def check_random_dataset(method): |
|
|
|
|
|
|
|
check_sampler_shuffle_shard_options(param_dict) |
|
|
|
|
|
|
|
cache = param_dict.get('cache') |
|
|
|
check_cache_option(cache) |
|
|
|
|
|
|
|
return method(self, *args, **kwargs) |
|
|
|
|
|
|
|
return new_method |
|
|
|
@@ -572,8 +581,7 @@ def check_map(method): |
|
|
|
if num_parallel_workers is not None: |
|
|
|
check_num_parallel_workers(num_parallel_workers) |
|
|
|
type_check(python_multiprocessing, (bool,), "python_multiprocessing") |
|
|
|
if cache is not None: |
|
|
|
type_check(cache, (cache_client.DatasetCache,), "cache") |
|
|
|
check_cache_option(cache) |
|
|
|
|
|
|
|
if callbacks is not None: |
|
|
|
if isinstance(callbacks, (list, tuple)): |
|
|
|
@@ -1215,3 +1223,11 @@ def check_paddeddataset(method): |
|
|
|
return method(self, *args, **kwargs) |
|
|
|
|
|
|
|
return new_method |
|
|
|
|
|
|
|
|
|
|
|
def check_cache_option(cache): |
|
|
|
"""Sanity check for cache parameter""" |
|
|
|
if cache is not None: |
|
|
|
# temporary disable cache feature in the current release |
|
|
|
# type_check(cache, (cache_client.DatasetCache,), "cache") |
|
|
|
raise ValueError("Caching is disabled in the current release") |