You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 33 kB

6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Built-in validators.
  16. """
  17. import inspect as ins
  18. import os
  19. from functools import wraps
  20. from multiprocessing import cpu_count
  21. from mindspore._c_expression import typing
  22. from . import samplers
  23. from . import datasets
# Largest value representable by a signed 32-bit integer; upper bound for
# most integer-valued dataset parameters.
INT32_MAX = 2147483647
# Column data type names accepted by check_valid_detype().
valid_detype = [
    "bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
    "uint32", "uint64", "float16", "float32", "float64"
]
  29. def check(method):
  30. """Check the function parameters and return the function ."""
  31. func_name = method.__name__
  32. # Required parameter
  33. req_param_int = []
  34. req_param_bool = []
  35. # Non-required parameter
  36. nreq_param_int = []
  37. nreq_param_bool = []
  38. if func_name in 'repeat':
  39. nreq_param_int = ['count', 'prefetch_size']
  40. if func_name in 'take':
  41. req_param_int = ['count']
  42. nreq_param_int = ['prefetch_size']
  43. elif func_name in 'shuffle':
  44. req_param_int = ['buffer_size']
  45. nreq_param_bool = ['reshuffle_each_iteration']
  46. nreq_param_int = ['prefetch_size', 'seed']
  47. elif func_name in 'batch':
  48. req_param_int = ['batch_size']
  49. nreq_param_int = ['num_parallel_workers', 'prefetch_size']
  50. nreq_param_bool = ['drop_remainder']
  51. elif func_name in ('zip', 'filter', 'cache', 'rename', 'project'):
  52. nreq_param_int = ['prefetch_size']
  53. elif func_name in ('map', '__init__'):
  54. nreq_param_int = ['num_parallel_workers', 'prefetch_size', 'seed']
  55. nreq_param_bool = ['block_reader']
  56. @wraps(method)
  57. def wrapper(*args, **kwargs):
  58. def _make_key():
  59. sig = ins.signature(method)
  60. params = sig.parameters
  61. keys = list(params.keys())
  62. param_dic = dict()
  63. for name, value in enumerate(args):
  64. param_dic[keys[name]] = value
  65. param_dic.update(zip(params.keys(), args))
  66. param_dic.update(kwargs)
  67. for name, value in params.items():
  68. if name not in param_dic:
  69. param_dic[name] = value.default
  70. return param_dic
  71. # check type
  72. def _check_param_type(arg, param_name, param_type=None):
  73. if param_type is not None and not isinstance(arg, param_type):
  74. raise ValueError(
  75. "The %s function %s type error!" % (func_name, param_name))
  76. # check range
  77. def _check_param_range(arg, param_name):
  78. if isinstance(arg, int) and param_name == "seed" and (
  79. arg < 0 or arg > 2147483647):
  80. raise ValueError(
  81. "The %s function %s exceeds the boundary!" % (
  82. func_name, param_name))
  83. if isinstance(arg, int) and param_name == "count" and ((arg <= 0 and arg != -1) or arg > 2147483647):
  84. raise ValueError(
  85. "The %s function %s exceeds the boundary!" % (
  86. func_name, param_name))
  87. if isinstance(arg, int) and param_name == "prefetch_size" and (
  88. arg <= 0 or arg > 1024):
  89. raise ValueError(
  90. "The %s function %s exceeds the boundary!" % (
  91. func_name, param_name))
  92. if isinstance(arg, int) and param_name == "num_parallel_workers" and (
  93. arg < 1 or arg > cpu_count()):
  94. raise ValueError(
  95. "The %s function %s exceeds the boundary(%s)!" % (
  96. func_name, param_name, cpu_count()))
  97. if isinstance(arg, int) and param_name != "seed" \
  98. and param_name != "count" and param_name != "prefetch_size" \
  99. and param_name != "num_parallel_workers" and (arg < 1 or arg > 2147483647):
  100. raise ValueError(
  101. "The %s function %s exceeds the boundary!" % (
  102. func_name, param_name))
  103. key = _make_key()
  104. # check integer
  105. for karg in req_param_int:
  106. _check_param_type(key[karg], karg, int)
  107. _check_param_range(key[karg], karg)
  108. for karg in nreq_param_int:
  109. if karg in key:
  110. if key[karg] is not None:
  111. _check_param_type(key[karg], karg, int)
  112. _check_param_range(key[karg], karg)
  113. # check bool
  114. for karg in req_param_bool:
  115. _check_param_type(key[karg], karg, bool)
  116. for karg in nreq_param_bool:
  117. if karg in key:
  118. if key[karg] is not None:
  119. _check_param_type(key[karg], karg, bool)
  120. if func_name in '__init__':
  121. if 'columns_list' in key.keys():
  122. columns_list = key['columns_list']
  123. if columns_list is not None:
  124. _check_param_type(columns_list, 'columns_list', list)
  125. if 'columns' in key.keys():
  126. columns = key['columns']
  127. if columns is not None:
  128. _check_param_type(columns, 'columns', list)
  129. if 'partitions' in key.keys():
  130. partitions = key['partitions']
  131. if partitions is not None:
  132. _check_param_type(partitions, 'partitions', list)
  133. if 'schema' in key.keys():
  134. schema = key['schema']
  135. if schema is not None:
  136. check_filename(schema)
  137. if not os.path.isfile(schema) or not os.access(schema, os.R_OK):
  138. raise ValueError(
  139. "The file %s does not exist or permission denied!" % schema)
  140. if 'dataset_dir' in key.keys():
  141. dataset_dir = key['dataset_dir']
  142. if dataset_dir is not None:
  143. if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
  144. raise ValueError(
  145. "The folder %s does not exist or permission denied!" % dataset_dir)
  146. if 'dataset_files' in key.keys():
  147. dataset_files = key['dataset_files']
  148. if not dataset_files:
  149. raise ValueError(
  150. "The dataset file does not exists!")
  151. if dataset_files is not None:
  152. _check_param_type(dataset_files, 'dataset_files', list)
  153. for file in dataset_files:
  154. if not os.path.isfile(file) or not os.access(file, os.R_OK):
  155. raise ValueError(
  156. "The file %s does not exist or permission denied!" % file)
  157. if 'dataset_file' in key.keys():
  158. dataset_file = key['dataset_file']
  159. if not dataset_file:
  160. raise ValueError(
  161. "The dataset file does not exists!")
  162. check_filename(dataset_file)
  163. if dataset_file is not None:
  164. if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
  165. raise ValueError(
  166. "The file %s does not exist or permission denied!" % dataset_file)
  167. return method(*args, **kwargs)
  168. return wrapper
  169. def check_valid_detype(type_):
  170. if type_ not in valid_detype:
  171. raise ValueError("Unknown column type")
  172. return True
  173. def check_filename(path):
  174. """
  175. check the filename in the path
  176. Args:
  177. path (str): the path
  178. Returns:
  179. Exception: when error
  180. """
  181. if not isinstance(path, str):
  182. raise ValueError("path: {} is not string".format(path))
  183. filename = os.path.basename(path)
  184. # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
  185. # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
  186. # '*', '(', '%', ')', '-', '=', '{', '?', '$'
  187. forbidden_symbols = set(r'\/:*?"<>|`&\';')
  188. if set(filename) & forbidden_symbols:
  189. raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'")
  190. if filename.startswith(' ') or filename.endswith(' '):
  191. raise ValueError("filename should not start/end with space")
  192. return True
  193. def make_param_dict(method, args, kwargs):
  194. """Return a dictionary of the method's args and kwargs."""
  195. sig = ins.signature(method)
  196. params = sig.parameters
  197. keys = list(params.keys())
  198. param_dict = dict()
  199. try:
  200. for name, value in enumerate(args):
  201. param_dict[keys[name]] = value
  202. except IndexError:
  203. raise TypeError("{0}() expected {1} arguments, but {2} were given".format(
  204. method.__name__, len(keys) - 1, len(args) - 1))
  205. param_dict.update(zip(params.keys(), args))
  206. param_dict.update(kwargs)
  207. for name, value in params.items():
  208. if name not in param_dict:
  209. param_dict[name] = value.default
  210. return param_dict
  211. def check_type(param, param_name, valid_type):
  212. if (not isinstance(param, valid_type)) or (valid_type == int and isinstance(param, bool)):
  213. raise TypeError("Wrong input type for {0}, should be {1}, got {2}".format(param_name, valid_type, type(param)))
  214. def check_param_type(param_list, param_dict, param_type):
  215. for param_name in param_list:
  216. if param_dict.get(param_name) is not None:
  217. if param_name == 'num_parallel_workers':
  218. check_num_parallel_workers(param_dict.get(param_name))
  219. if param_name == 'num_samples':
  220. check_num_samples(param_dict.get(param_name))
  221. else:
  222. check_type(param_dict.get(param_name), param_name, param_type)
  223. def check_positive_int32(param, param_name):
  224. check_interval_closed(param, param_name, [1, INT32_MAX])
  225. def check_interval_closed(param, param_name, valid_range):
  226. if param < valid_range[0] or param > valid_range[1]:
  227. raise ValueError("The value of {0} exceeds the closed interval range {1}.".format(param_name, valid_range))
  228. def check_num_parallel_workers(value):
  229. check_type(value, 'num_parallel_workers', int)
  230. if value < 1 or value > cpu_count():
  231. raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count()))
  232. def check_num_samples(value):
  233. check_type(value, 'num_samples', int)
  234. if value <= 0:
  235. raise ValueError("num_samples must be greater than 0!")
  236. def check_dataset_dir(dataset_dir):
  237. if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
  238. raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
  239. def check_dataset_file(dataset_file):
  240. check_filename(dataset_file)
  241. if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
  242. raise ValueError("The file {} does not exist or permission denied!".format(dataset_file))
  243. def check_sampler_shuffle_shard_options(param_dict):
  244. """check for valid shuffle, sampler, num_shards, and shard_id inputs."""
  245. shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
  246. num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
  247. if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
  248. raise ValueError("sampler is not a valid Sampler type.")
  249. if sampler is not None:
  250. if shuffle is not None:
  251. raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
  252. if num_shards is not None:
  253. raise RuntimeError("sampler and sharding cannot be specified at the same time.")
  254. if num_shards is not None:
  255. if shard_id is None:
  256. raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
  257. if shard_id < 0 or shard_id >= num_shards:
  258. raise ValueError("shard_id is invalid, shard_id={}".format(shard_id))
  259. if num_shards is None and shard_id is not None:
  260. raise RuntimeError("shard_id is specified but num_shards is not.")
  261. def check_imagefolderdatasetv2(method):
  262. """A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
  263. @wraps(method)
  264. def new_method(*args, **kwargs):
  265. param_dict = make_param_dict(method, args, kwargs)
  266. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  267. nreq_param_bool = ['shuffle', 'decode']
  268. nreq_param_list = ['extensions']
  269. nreq_param_dict = ['class_indexing']
  270. # check dataset_dir; required argument
  271. dataset_dir = param_dict.get('dataset_dir')
  272. if dataset_dir is None:
  273. raise ValueError("dataset_dir is not provided.")
  274. check_dataset_dir(dataset_dir)
  275. check_param_type(nreq_param_int, param_dict, int)
  276. check_param_type(nreq_param_bool, param_dict, bool)
  277. check_param_type(nreq_param_list, param_dict, list)
  278. check_param_type(nreq_param_dict, param_dict, dict)
  279. check_sampler_shuffle_shard_options(param_dict)
  280. return method(*args, **kwargs)
  281. return new_method
  282. def check_mnist_cifar_dataset(method):
  283. """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
  284. @wraps(method)
  285. def new_method(*args, **kwargs):
  286. param_dict = make_param_dict(method, args, kwargs)
  287. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  288. nreq_param_bool = ['shuffle']
  289. # check dataset_dir; required argument
  290. dataset_dir = param_dict.get('dataset_dir')
  291. if dataset_dir is None:
  292. raise ValueError("dataset_dir is not provided.")
  293. check_dataset_dir(dataset_dir)
  294. check_param_type(nreq_param_int, param_dict, int)
  295. check_param_type(nreq_param_bool, param_dict, bool)
  296. check_sampler_shuffle_shard_options(param_dict)
  297. return method(*args, **kwargs)
  298. return new_method
  299. def check_manifestdataset(method):
  300. """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset)."""
  301. @wraps(method)
  302. def new_method(*args, **kwargs):
  303. param_dict = make_param_dict(method, args, kwargs)
  304. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  305. nreq_param_bool = ['shuffle', 'decode']
  306. nreq_param_str = ['usage']
  307. nreq_param_dict = ['class_indexing']
  308. # check dataset_file; required argument
  309. dataset_file = param_dict.get('dataset_file')
  310. if dataset_file is None:
  311. raise ValueError("dataset_file is not provided.")
  312. check_dataset_file(dataset_file)
  313. check_param_type(nreq_param_int, param_dict, int)
  314. check_param_type(nreq_param_bool, param_dict, bool)
  315. check_param_type(nreq_param_str, param_dict, str)
  316. check_param_type(nreq_param_dict, param_dict, dict)
  317. check_sampler_shuffle_shard_options(param_dict)
  318. return method(*args, **kwargs)
  319. return new_method
  320. def check_tfrecorddataset(method):
  321. """A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset)."""
  322. @wraps(method)
  323. def new_method(*args, **kwargs):
  324. param_dict = make_param_dict(method, args, kwargs)
  325. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  326. nreq_param_list = ['columns_list']
  327. nreq_param_bool = ['shard_equal_rows']
  328. # check dataset_files; required argument
  329. dataset_files = param_dict.get('dataset_files')
  330. if dataset_files is None:
  331. raise ValueError("dataset_files is not provided.")
  332. if not isinstance(dataset_files, (str, list)):
  333. raise TypeError("dataset_files should be of type str or a list of strings.")
  334. check_param_type(nreq_param_int, param_dict, int)
  335. check_param_type(nreq_param_list, param_dict, list)
  336. check_param_type(nreq_param_bool, param_dict, bool)
  337. check_sampler_shuffle_shard_options(param_dict)
  338. return method(*args, **kwargs)
  339. return new_method
  340. def check_vocdataset(method):
  341. """A wrapper that wrap a parameter checker to the original Dataset(VOCDataset)."""
  342. @wraps(method)
  343. def new_method(*args, **kwargs):
  344. param_dict = make_param_dict(method, args, kwargs)
  345. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  346. nreq_param_bool = ['shuffle', 'decode']
  347. # check dataset_dir; required argument
  348. dataset_dir = param_dict.get('dataset_dir')
  349. if dataset_dir is None:
  350. raise ValueError("dataset_dir is not provided.")
  351. check_dataset_dir(dataset_dir)
  352. check_param_type(nreq_param_int, param_dict, int)
  353. check_param_type(nreq_param_bool, param_dict, bool)
  354. check_sampler_shuffle_shard_options(param_dict)
  355. return method(*args, **kwargs)
  356. return new_method
  357. def check_celebadataset(method):
  358. """A wrapper that wrap a parameter checker to the original Dataset(CelebADataset)."""
  359. @wraps(method)
  360. def new_method(*args, **kwargs):
  361. param_dict = make_param_dict(method, args, kwargs)
  362. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  363. nreq_param_bool = ['shuffle', 'decode']
  364. nreq_param_list = ['extensions']
  365. nreq_param_str = ['dataset_type']
  366. # check dataset_dir; required argument
  367. dataset_dir = param_dict.get('dataset_dir')
  368. if dataset_dir is None:
  369. raise ValueError("dataset_dir is not provided.")
  370. check_dataset_dir(dataset_dir)
  371. check_param_type(nreq_param_int, param_dict, int)
  372. check_param_type(nreq_param_bool, param_dict, bool)
  373. check_param_type(nreq_param_list, param_dict, list)
  374. check_param_type(nreq_param_str, param_dict, str)
  375. dataset_type = param_dict.get('dataset_type')
  376. if dataset_type is not None and dataset_type not in ('all', 'train', 'valid', 'test'):
  377. raise ValueError("dataset_type should be one of 'all', 'train', 'valid' or 'test'.")
  378. check_sampler_shuffle_shard_options(param_dict)
  379. sampler = param_dict.get('sampler')
  380. if sampler is not None and isinstance(sampler, samplers.PKSampler):
  381. raise ValueError("CelebADataset does not support PKSampler.")
  382. return method(*args, **kwargs)
  383. return new_method
  384. def check_minddataset(method):
  385. """A wrapper that wrap a parameter checker to the original Dataset(MindDataset)."""
  386. @wraps(method)
  387. def new_method(*args, **kwargs):
  388. param_dict = make_param_dict(method, args, kwargs)
  389. nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id']
  390. nreq_param_list = ['columns_list']
  391. nreq_param_bool = ['block_reader']
  392. # check dataset_file; required argument
  393. dataset_file = param_dict.get('dataset_file')
  394. if dataset_file is None:
  395. raise ValueError("dataset_file is not provided.")
  396. check_dataset_file(dataset_file)
  397. check_param_type(nreq_param_int, param_dict, int)
  398. check_param_type(nreq_param_list, param_dict, list)
  399. check_param_type(nreq_param_bool, param_dict, bool)
  400. num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
  401. if (num_shards is not None and shard_id is None) or (num_shards is None and shard_id is not None):
  402. raise ValueError("num_shards and shard_id need to be set or not set at the same time")
  403. return method(*args, **kwargs)
  404. return new_method
  405. def check_generatordataset(method):
  406. """A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset)."""
  407. @wraps(method)
  408. def new_method(*args, **kwargs):
  409. param_dict = make_param_dict(method, args, kwargs)
  410. # check generator_function; required argument
  411. source = param_dict.get('source')
  412. if source is None:
  413. raise ValueError("source is not provided.")
  414. if not callable(source):
  415. try:
  416. iter(source)
  417. except TypeError:
  418. raise TypeError("source should be callable, iterable or random accessible")
  419. # check column_names; required argument
  420. column_names = param_dict.get('column_names')
  421. if column_names is None:
  422. raise ValueError("column_names is not provided.")
  423. # check optional argument
  424. nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
  425. check_param_type(nreq_param_int, param_dict, int)
  426. nreq_param_list = ["column_types"]
  427. check_param_type(nreq_param_list, param_dict, list)
  428. num_shards = param_dict.get("num_shards")
  429. shard_id = param_dict.get("shard_id")
  430. if (num_shards is None) != (shard_id is None):
  431. # These two parameters appear together.
  432. raise ValueError("num_shards and shard_id need to be passed in together")
  433. if num_shards is not None:
  434. if shard_id >= num_shards:
  435. raise ValueError("shard_id should be less than num_shards")
  436. sampler = param_dict.get("sampler")
  437. if sampler is not None:
  438. if isinstance(sampler, samplers.PKSampler):
  439. raise ValueError("PKSampler is not supported by GeneratorDataset")
  440. if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  441. samplers.RandomSampler, samplers.SubsetRandomSampler,
  442. samplers.WeightedRandomSampler, samplers.Sampler)):
  443. try:
  444. iter(sampler)
  445. except TypeError:
  446. raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")
  447. return method(*args, **kwargs)
  448. return new_method
  449. def check_batch_size(batch_size):
  450. if not (isinstance(batch_size, int) or (callable(batch_size))):
  451. raise ValueError("batch_size should either be an int or a callable.")
  452. if callable(batch_size):
  453. sig = ins.signature(batch_size)
  454. if len(sig.parameters) != 1:
  455. raise ValueError("batch_size callable should take one parameter (BatchInfo).")
  456. def check_count(count):
  457. check_type(count, 'count', int)
  458. if (count <= 0 and count != -1) or count > INT32_MAX:
  459. raise ValueError("count should be either -1 or positive integer.")
  460. def check_columns(columns, name):
  461. if isinstance(columns, list):
  462. for column in columns:
  463. if not isinstance(column, str):
  464. raise TypeError("Each column in {0} should be of type str. Got {1}.".format(name, type(column)))
  465. elif not isinstance(columns, str):
  466. raise TypeError("{} should be either a list of strings or a single string.".format(name))
  467. def check_batch(method):
  468. """check the input arguments of batch."""
  469. @wraps(method)
  470. def new_method(*args, **kwargs):
  471. param_dict = make_param_dict(method, args, kwargs)
  472. nreq_param_int = ['num_parallel_workers']
  473. nreq_param_bool = ['drop_remainder']
  474. nreq_param_columns = ['input_columns']
  475. # check batch_size; required argument
  476. batch_size = param_dict.get("batch_size")
  477. if batch_size is None:
  478. raise ValueError("batch_size is not provided.")
  479. check_batch_size(batch_size)
  480. check_param_type(nreq_param_int, param_dict, int)
  481. check_param_type(nreq_param_bool, param_dict, bool)
  482. for param_name in nreq_param_columns:
  483. param = param_dict.get(param_name)
  484. if param is not None:
  485. check_columns(param, param_name)
  486. per_batch_map, input_columns = param_dict.get('per_batch_map'), param_dict.get('input_columns')
  487. if (per_batch_map is None) != (input_columns is None):
  488. # These two parameters appear together.
  489. raise ValueError("per_batch_map and input_columns need to be passed in together.")
  490. if input_columns is not None:
  491. if not input_columns: # Check whether input_columns is empty.
  492. raise ValueError("input_columns can not be empty")
  493. if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
  494. raise ValueError("the signature of per_batch_map should match with input columns")
  495. return method(*args, **kwargs)
  496. return new_method
  497. def check_sync_wait(method):
  498. """check the input arguments of sync_wait."""
  499. @wraps(method)
  500. def new_method(*args, **kwargs):
  501. param_dict = make_param_dict(method, args, kwargs)
  502. nreq_param_str = ['condition_name']
  503. nreq_param_int = ['step_size']
  504. check_param_type(nreq_param_int, param_dict, int)
  505. check_param_type(nreq_param_str, param_dict, str)
  506. return method(*args, **kwargs)
  507. return new_method
  508. def check_shuffle(method):
  509. """check the input arguments of shuffle."""
  510. @wraps(method)
  511. def new_method(*args, **kwargs):
  512. param_dict = make_param_dict(method, args, kwargs)
  513. # check buffer_size; required argument
  514. buffer_size = param_dict.get("buffer_size")
  515. if buffer_size is None:
  516. raise ValueError("buffer_size is not provided.")
  517. check_type(buffer_size, 'buffer_size', int)
  518. check_interval_closed(buffer_size, 'buffer_size', [2, INT32_MAX])
  519. return method(*args, **kwargs)
  520. return new_method
  521. def check_map(method):
  522. """check the input arguments of map."""
  523. @wraps(method)
  524. def new_method(*args, **kwargs):
  525. param_dict = make_param_dict(method, args, kwargs)
  526. nreq_param_list = ['columns_order']
  527. nreq_param_int = ['num_parallel_workers']
  528. nreq_param_columns = ['input_columns', 'output_columns']
  529. nreq_param_bool = ['python_multiprocessing']
  530. check_param_type(nreq_param_list, param_dict, list)
  531. check_param_type(nreq_param_int, param_dict, int)
  532. check_param_type(nreq_param_bool, param_dict, bool)
  533. for param_name in nreq_param_columns:
  534. param = param_dict.get(param_name)
  535. if param is not None:
  536. check_columns(param, param_name)
  537. return method(*args, **kwargs)
  538. return new_method
  539. def check_filter(method):
  540. """"check the input arguments of filter."""
  541. @wraps(method)
  542. def new_method(*args, **kwargs):
  543. param_dict = make_param_dict(method, args, kwargs)
  544. predicate = param_dict.get("predicate")
  545. if not callable(predicate):
  546. raise ValueError("Predicate should be a python function or a callable python object.")
  547. nreq_param_int = ['num_parallel_workers']
  548. check_param_type(nreq_param_int, param_dict, int)
  549. param_name = "input_columns"
  550. param = param_dict.get(param_name)
  551. if param is not None:
  552. check_columns(param, param_name)
  553. return method(*args, **kwargs)
  554. return new_method
  555. def check_repeat(method):
  556. """check the input arguments of repeat."""
  557. @wraps(method)
  558. def new_method(*args, **kwargs):
  559. param_dict = make_param_dict(method, args, kwargs)
  560. count = param_dict.get('count')
  561. if count is not None:
  562. check_count(count)
  563. return method(*args, **kwargs)
  564. return new_method
  565. def check_skip(method):
  566. """check the input arguments of skip."""
  567. @wraps(method)
  568. def new_method(*args, **kwargs):
  569. param_dict = make_param_dict(method, args, kwargs)
  570. count = param_dict.get('count')
  571. check_type(count, 'count', int)
  572. if count < 0:
  573. raise ValueError("Skip count must be positive integer or 0.")
  574. return method(*args, **kwargs)
  575. return new_method
  576. def check_take(method):
  577. """check the input arguments of take."""
  578. @wraps(method)
  579. def new_method(*args, **kwargs):
  580. param_dict = make_param_dict(method, args, kwargs)
  581. count = param_dict.get('count')
  582. check_count(count)
  583. return method(*args, **kwargs)
  584. return new_method
  585. def check_zip(method):
  586. """check the input arguments of zip."""
  587. @wraps(method)
  588. def new_method(*args, **kwargs):
  589. param_dict = make_param_dict(method, args, kwargs)
  590. # check datasets; required argument
  591. ds = param_dict.get("datasets")
  592. if ds is None:
  593. raise ValueError("datasets is not provided.")
  594. check_type(ds, 'datasets', tuple)
  595. return method(*args, **kwargs)
  596. return new_method
  597. def check_zip_dataset(method):
  598. """check the input arguments of zip method in `Dataset`."""
  599. @wraps(method)
  600. def new_method(*args, **kwargs):
  601. param_dict = make_param_dict(method, args, kwargs)
  602. # check datasets; required argument
  603. ds = param_dict.get("datasets")
  604. if ds is None:
  605. raise ValueError("datasets is not provided.")
  606. if not isinstance(ds, (tuple, datasets.Dataset)):
  607. raise ValueError("datasets is not tuple or of type Dataset.")
  608. return method(*args, **kwargs)
  609. return new_method
  610. def check_rename(method):
  611. """check the input arguments of rename."""
  612. @wraps(method)
  613. def new_method(*args, **kwargs):
  614. param_dict = make_param_dict(method, args, kwargs)
  615. req_param_columns = ['input_columns', 'output_columns']
  616. # check req_param_list; required arguments
  617. for param_name in req_param_columns:
  618. param = param_dict.get(param_name)
  619. if param is None:
  620. raise ValueError("{} is not provided.".format(param_name))
  621. check_columns(param, param_name)
  622. return method(*args, **kwargs)
  623. return new_method
  624. def check_project(method):
  625. """check the input arguments of project."""
  626. @wraps(method)
  627. def new_method(*args, **kwargs):
  628. param_dict = make_param_dict(method, args, kwargs)
  629. # check columns; required argument
  630. columns = param_dict.get("columns")
  631. if columns is None:
  632. raise ValueError("columns is not provided.")
  633. check_columns(columns, 'columns')
  634. return method(*args, **kwargs)
  635. return new_method
  636. def check_shape(shape, name):
  637. if isinstance(shape, list):
  638. for element in shape:
  639. if not isinstance(element, int):
  640. raise TypeError(
  641. "Each element in {0} should be of type int. Got {1}.".format(name, type(element)))
  642. else:
  643. raise TypeError("Expected int list.")
  644. def check_add_column(method):
  645. """check the input arguments of add_column."""
  646. @wraps(method)
  647. def new_method(*args, **kwargs):
  648. param_dict = make_param_dict(method, args, kwargs)
  649. # check name; required argument
  650. name = param_dict.get("name")
  651. if not isinstance(name, str) or not name:
  652. raise TypeError("Expected non-empty string.")
  653. # check type; required argument
  654. de_type = param_dict.get("de_type")
  655. if de_type is not None:
  656. if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
  657. raise ValueError("Unknown column type.")
  658. else:
  659. raise TypeError("Expected non-empty string.")
  660. # check shape
  661. shape = param_dict.get("shape")
  662. if shape is not None:
  663. check_shape(shape, "shape")
  664. return method(*args, **kwargs)
  665. return new_method
  666. def check_textfiledataset(method):
  667. """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""
  668. @wraps(method)
  669. def new_method(*args, **kwargs):
  670. param_dict = make_param_dict(method, args, kwargs)
  671. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  672. # check dataset_files; required argument
  673. dataset_files = param_dict.get('dataset_files')
  674. if dataset_files is None:
  675. raise ValueError("dataset_files is not provided.")
  676. if not isinstance(dataset_files, (str, list)):
  677. raise TypeError("dataset_files should be of type str or a list of strings.")
  678. check_param_type(nreq_param_int, param_dict, int)
  679. check_sampler_shuffle_shard_options(param_dict)
  680. return method(*args, **kwargs)
  681. return new_method