You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

validators.py 50 kB

6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Built-in validators.
  16. """
  17. import inspect as ins
  18. import os
  19. from functools import wraps
  20. from multiprocessing import cpu_count
  21. import numpy as np
  22. from mindspore._c_expression import typing
  23. from . import datasets
  24. from . import samplers
  25. INT32_MAX = 2147483647
  26. valid_detype = [
  27. "bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
  28. "uint32", "uint64", "float16", "float32", "float64", "string"
  29. ]
  30. def check_valid_detype(type_):
  31. if type_ not in valid_detype:
  32. raise ValueError("Unknown column type")
  33. return True
  34. def check_filename(path):
  35. """
  36. check the filename in the path
  37. Args:
  38. path (str): the path
  39. Returns:
  40. Exception: when error
  41. """
  42. if not isinstance(path, str):
  43. raise TypeError("path: {} is not string".format(path))
  44. filename = os.path.basename(path)
  45. # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
  46. # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
  47. # '*', '(', '%', ')', '-', '=', '{', '?', '$'
  48. forbidden_symbols = set(r'\/:*?"<>|`&\';')
  49. if set(filename) & forbidden_symbols:
  50. raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'")
  51. if filename.startswith(' ') or filename.endswith(' '):
  52. raise ValueError("filename should not start/end with space")
  53. return True
  54. def make_param_dict(method, args, kwargs):
  55. """Return a dictionary of the method's args and kwargs."""
  56. sig = ins.signature(method)
  57. params = sig.parameters
  58. keys = list(params.keys())
  59. param_dict = dict()
  60. try:
  61. for name, value in enumerate(args):
  62. param_dict[keys[name]] = value
  63. except IndexError:
  64. raise TypeError("{0}() expected {1} arguments, but {2} were given".format(
  65. method.__name__, len(keys) - 1, len(args) - 1))
  66. param_dict.update(zip(params.keys(), args))
  67. param_dict.update(kwargs)
  68. for name, value in params.items():
  69. if name not in param_dict:
  70. param_dict[name] = value.default
  71. return param_dict
  72. def check_type(param, param_name, valid_type):
  73. if (not isinstance(param, valid_type)) or (valid_type == int and isinstance(param, bool)):
  74. raise TypeError("Wrong input type for {0}, should be {1}, got {2}".format(param_name, valid_type, type(param)))
  75. def check_param_type(param_list, param_dict, param_type):
  76. for param_name in param_list:
  77. if param_dict.get(param_name) is not None:
  78. if param_name == 'num_parallel_workers':
  79. check_num_parallel_workers(param_dict.get(param_name))
  80. if param_name == 'num_samples':
  81. check_num_samples(param_dict.get(param_name))
  82. else:
  83. check_type(param_dict.get(param_name), param_name, param_type)
  84. def check_positive_int32(param, param_name):
  85. check_interval_closed(param, param_name, [1, INT32_MAX])
  86. def check_interval_closed(param, param_name, valid_range):
  87. if param < valid_range[0] or param > valid_range[1]:
  88. raise ValueError("The value of {0} exceeds the closed interval range {1}.".format(param_name, valid_range))
  89. def check_num_parallel_workers(value):
  90. check_type(value, 'num_parallel_workers', int)
  91. if value < 1 or value > cpu_count():
  92. raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count()))
  93. def check_num_samples(value):
  94. check_type(value, 'num_samples', int)
  95. if value < 0:
  96. raise ValueError("num_samples cannot be less than 0!")
  97. def check_dataset_dir(dataset_dir):
  98. if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
  99. raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
  100. def check_dataset_file(dataset_file):
  101. check_filename(dataset_file)
  102. if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
  103. raise ValueError("The file {} does not exist or permission denied!".format(dataset_file))
  104. def check_sampler_shuffle_shard_options(param_dict):
  105. """check for valid shuffle, sampler, num_shards, and shard_id inputs."""
  106. shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
  107. num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
  108. if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
  109. raise TypeError("sampler is not a valid Sampler type.")
  110. if sampler is not None:
  111. if shuffle is not None:
  112. raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
  113. if num_shards is not None:
  114. raise RuntimeError("sampler and sharding cannot be specified at the same time.")
  115. if num_shards is not None:
  116. check_positive_int32(num_shards, "num_shards")
  117. if shard_id is None:
  118. raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
  119. if shard_id < 0 or shard_id >= num_shards:
  120. raise ValueError("shard_id is invalid, shard_id={}".format(shard_id))
  121. if num_shards is None and shard_id is not None:
  122. raise RuntimeError("shard_id is specified but num_shards is not.")
  123. def check_padding_options(param_dict):
  124. """ check for valid padded_sample and num_padded of padded samples"""
  125. columns_list = param_dict.get('columns_list')
  126. block_reader = param_dict.get('block_reader')
  127. padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded')
  128. if padded_sample is not None:
  129. if num_padded is None:
  130. raise RuntimeError("padded_sample is specified and requires num_padded as well.")
  131. if num_padded < 0:
  132. raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
  133. if columns_list is None:
  134. raise RuntimeError("padded_sample is specified and requires columns_list as well.")
  135. for column in columns_list:
  136. if column not in padded_sample:
  137. raise ValueError("padded_sample cannot match columns_list.")
  138. if block_reader:
  139. raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.")
  140. if padded_sample is None and num_padded is not None:
  141. raise RuntimeError("num_padded is specified but padded_sample is not.")
  142. def check_imagefolderdatasetv2(method):
  143. """A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
  144. @wraps(method)
  145. def new_method(*args, **kwargs):
  146. param_dict = make_param_dict(method, args, kwargs)
  147. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  148. nreq_param_bool = ['shuffle', 'decode']
  149. nreq_param_list = ['extensions']
  150. nreq_param_dict = ['class_indexing']
  151. # check dataset_dir; required argument
  152. dataset_dir = param_dict.get('dataset_dir')
  153. if dataset_dir is None:
  154. raise ValueError("dataset_dir is not provided.")
  155. check_dataset_dir(dataset_dir)
  156. check_param_type(nreq_param_int, param_dict, int)
  157. check_param_type(nreq_param_bool, param_dict, bool)
  158. check_param_type(nreq_param_list, param_dict, list)
  159. check_param_type(nreq_param_dict, param_dict, dict)
  160. check_sampler_shuffle_shard_options(param_dict)
  161. return method(*args, **kwargs)
  162. return new_method
  163. def check_mnist_cifar_dataset(method):
  164. """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
  165. @wraps(method)
  166. def new_method(*args, **kwargs):
  167. param_dict = make_param_dict(method, args, kwargs)
  168. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  169. nreq_param_bool = ['shuffle']
  170. # check dataset_dir; required argument
  171. dataset_dir = param_dict.get('dataset_dir')
  172. if dataset_dir is None:
  173. raise ValueError("dataset_dir is not provided.")
  174. check_dataset_dir(dataset_dir)
  175. check_param_type(nreq_param_int, param_dict, int)
  176. check_param_type(nreq_param_bool, param_dict, bool)
  177. check_sampler_shuffle_shard_options(param_dict)
  178. return method(*args, **kwargs)
  179. return new_method
  180. def check_manifestdataset(method):
  181. """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset)."""
  182. @wraps(method)
  183. def new_method(*args, **kwargs):
  184. param_dict = make_param_dict(method, args, kwargs)
  185. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  186. nreq_param_bool = ['shuffle', 'decode']
  187. nreq_param_str = ['usage']
  188. nreq_param_dict = ['class_indexing']
  189. # check dataset_file; required argument
  190. dataset_file = param_dict.get('dataset_file')
  191. if dataset_file is None:
  192. raise ValueError("dataset_file is not provided.")
  193. check_dataset_file(dataset_file)
  194. check_param_type(nreq_param_int, param_dict, int)
  195. check_param_type(nreq_param_bool, param_dict, bool)
  196. check_param_type(nreq_param_str, param_dict, str)
  197. check_param_type(nreq_param_dict, param_dict, dict)
  198. check_sampler_shuffle_shard_options(param_dict)
  199. return method(*args, **kwargs)
  200. return new_method
  201. def check_tfrecorddataset(method):
  202. """A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset)."""
  203. @wraps(method)
  204. def new_method(*args, **kwargs):
  205. param_dict = make_param_dict(method, args, kwargs)
  206. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  207. nreq_param_list = ['columns_list']
  208. nreq_param_bool = ['shard_equal_rows']
  209. # check dataset_files; required argument
  210. dataset_files = param_dict.get('dataset_files')
  211. if dataset_files is None:
  212. raise ValueError("dataset_files is not provided.")
  213. if not isinstance(dataset_files, (str, list)):
  214. raise TypeError("dataset_files should be of type str or a list of strings.")
  215. check_param_type(nreq_param_int, param_dict, int)
  216. check_param_type(nreq_param_list, param_dict, list)
  217. check_param_type(nreq_param_bool, param_dict, bool)
  218. check_sampler_shuffle_shard_options(param_dict)
  219. return method(*args, **kwargs)
  220. return new_method
  221. def check_vocdataset(method):
  222. """A wrapper that wrap a parameter checker to the original Dataset(VOCDataset)."""
  223. @wraps(method)
  224. def new_method(*args, **kwargs):
  225. param_dict = make_param_dict(method, args, kwargs)
  226. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  227. nreq_param_bool = ['shuffle', 'decode']
  228. nreq_param_dict = ['class_indexing']
  229. # check dataset_dir; required argument
  230. dataset_dir = param_dict.get('dataset_dir')
  231. if dataset_dir is None:
  232. raise ValueError("dataset_dir is not provided.")
  233. check_dataset_dir(dataset_dir)
  234. # check task; required argument
  235. task = param_dict.get('task')
  236. if task is None:
  237. raise ValueError("task is not provided.")
  238. if not isinstance(task, str):
  239. raise TypeError("task is not str type.")
  240. # check mode; required argument
  241. mode = param_dict.get('mode')
  242. if mode is None:
  243. raise ValueError("mode is not provided.")
  244. if not isinstance(mode, str):
  245. raise TypeError("mode is not str type.")
  246. imagesets_file = ""
  247. if task == "Segmentation":
  248. imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", mode + ".txt")
  249. if param_dict.get('class_indexing') is not None:
  250. raise ValueError("class_indexing is invalid in Segmentation task")
  251. elif task == "Detection":
  252. imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", mode + ".txt")
  253. else:
  254. raise ValueError("Invalid task : " + task)
  255. check_dataset_file(imagesets_file)
  256. check_param_type(nreq_param_int, param_dict, int)
  257. check_param_type(nreq_param_bool, param_dict, bool)
  258. check_param_type(nreq_param_dict, param_dict, dict)
  259. check_sampler_shuffle_shard_options(param_dict)
  260. return method(*args, **kwargs)
  261. return new_method
  262. def check_cocodataset(method):
  263. """A wrapper that wrap a parameter checker to the original Dataset(CocoDataset)."""
  264. @wraps(method)
  265. def new_method(*args, **kwargs):
  266. param_dict = make_param_dict(method, args, kwargs)
  267. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  268. nreq_param_bool = ['shuffle', 'decode']
  269. # check dataset_dir; required argument
  270. dataset_dir = param_dict.get('dataset_dir')
  271. if dataset_dir is None:
  272. raise ValueError("dataset_dir is not provided.")
  273. check_dataset_dir(dataset_dir)
  274. # check annotation_file; required argument
  275. annotation_file = param_dict.get('annotation_file')
  276. if annotation_file is None:
  277. raise ValueError("annotation_file is not provided.")
  278. check_dataset_file(annotation_file)
  279. # check task; required argument
  280. task = param_dict.get('task')
  281. if task is None:
  282. raise ValueError("task is not provided.")
  283. if not isinstance(task, str):
  284. raise TypeError("task is not str type.")
  285. if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
  286. raise ValueError("Invalid task type")
  287. check_param_type(nreq_param_int, param_dict, int)
  288. check_param_type(nreq_param_bool, param_dict, bool)
  289. sampler = param_dict.get('sampler')
  290. if sampler is not None and isinstance(sampler, samplers.PKSampler):
  291. raise ValueError("CocoDataset doesn't support PKSampler")
  292. check_sampler_shuffle_shard_options(param_dict)
  293. return method(*args, **kwargs)
  294. return new_method
  295. def check_celebadataset(method):
  296. """A wrapper that wrap a parameter checker to the original Dataset(CelebADataset)."""
  297. @wraps(method)
  298. def new_method(*args, **kwargs):
  299. param_dict = make_param_dict(method, args, kwargs)
  300. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  301. nreq_param_bool = ['shuffle', 'decode']
  302. nreq_param_list = ['extensions']
  303. nreq_param_str = ['dataset_type']
  304. # check dataset_dir; required argument
  305. dataset_dir = param_dict.get('dataset_dir')
  306. if dataset_dir is None:
  307. raise ValueError("dataset_dir is not provided.")
  308. check_dataset_dir(dataset_dir)
  309. check_param_type(nreq_param_int, param_dict, int)
  310. check_param_type(nreq_param_bool, param_dict, bool)
  311. check_param_type(nreq_param_list, param_dict, list)
  312. check_param_type(nreq_param_str, param_dict, str)
  313. dataset_type = param_dict.get('dataset_type')
  314. if dataset_type is not None and dataset_type not in ('all', 'train', 'valid', 'test'):
  315. raise ValueError("dataset_type should be one of 'all', 'train', 'valid' or 'test'.")
  316. check_sampler_shuffle_shard_options(param_dict)
  317. sampler = param_dict.get('sampler')
  318. if sampler is not None and isinstance(sampler, samplers.PKSampler):
  319. raise ValueError("CelebADataset does not support PKSampler.")
  320. return method(*args, **kwargs)
  321. return new_method
  322. def check_minddataset(method):
  323. """A wrapper that wrap a parameter checker to the original Dataset(MindDataset)."""
  324. @wraps(method)
  325. def new_method(*args, **kwargs):
  326. param_dict = make_param_dict(method, args, kwargs)
  327. nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded']
  328. nreq_param_list = ['columns_list']
  329. nreq_param_bool = ['block_reader']
  330. nreq_param_dict = ['padded_sample']
  331. # check dataset_file; required argument
  332. dataset_file = param_dict.get('dataset_file')
  333. if dataset_file is None:
  334. raise ValueError("dataset_file is not provided.")
  335. if isinstance(dataset_file, list):
  336. for f in dataset_file:
  337. check_dataset_file(f)
  338. else:
  339. check_dataset_file(dataset_file)
  340. check_param_type(nreq_param_int, param_dict, int)
  341. check_param_type(nreq_param_list, param_dict, list)
  342. check_param_type(nreq_param_bool, param_dict, bool)
  343. check_param_type(nreq_param_dict, param_dict, dict)
  344. check_sampler_shuffle_shard_options(param_dict)
  345. check_padding_options(param_dict)
  346. return method(*args, **kwargs)
  347. return new_method
  348. def check_generatordataset(method):
  349. """A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset)."""
  350. @wraps(method)
  351. def new_method(*args, **kwargs):
  352. param_dict = make_param_dict(method, args, kwargs)
  353. # check generator_function; required argument
  354. source = param_dict.get('source')
  355. if source is None:
  356. raise ValueError("source is not provided.")
  357. if not callable(source):
  358. try:
  359. iter(source)
  360. except TypeError:
  361. raise TypeError("source should be callable, iterable or random accessible")
  362. # check column_names or schema; required argument
  363. column_names = param_dict.get('column_names')
  364. if column_names is not None:
  365. check_columns(column_names, "column_names")
  366. schema = param_dict.get('schema')
  367. if column_names is None and schema is None:
  368. raise ValueError("Neither columns_names not schema are provided.")
  369. if schema is not None:
  370. if not isinstance(schema, datasets.Schema) and not isinstance(schema, str):
  371. raise ValueError("schema should be a path to schema file or a schema object.")
  372. # check optional argument
  373. nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
  374. check_param_type(nreq_param_int, param_dict, int)
  375. nreq_param_list = ["column_types"]
  376. check_param_type(nreq_param_list, param_dict, list)
  377. nreq_param_bool = ["shuffle"]
  378. check_param_type(nreq_param_bool, param_dict, bool)
  379. num_shards = param_dict.get("num_shards")
  380. shard_id = param_dict.get("shard_id")
  381. if (num_shards is None) != (shard_id is None):
  382. # These two parameters appear together.
  383. raise ValueError("num_shards and shard_id need to be passed in together")
  384. if num_shards is not None:
  385. check_positive_int32(num_shards, "num_shards")
  386. if shard_id >= num_shards:
  387. raise ValueError("shard_id should be less than num_shards")
  388. sampler = param_dict.get("sampler")
  389. if sampler is not None:
  390. if isinstance(sampler, samplers.PKSampler):
  391. raise ValueError("PKSampler is not supported by GeneratorDataset")
  392. if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  393. samplers.RandomSampler, samplers.SubsetRandomSampler,
  394. samplers.WeightedRandomSampler, samplers.Sampler)):
  395. try:
  396. iter(sampler)
  397. except TypeError:
  398. raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")
  399. if sampler is not None and not hasattr(source, "__getitem__"):
  400. raise ValueError("sampler is not supported if source does not have attribute '__getitem__'")
  401. if num_shards is not None and not hasattr(source, "__getitem__"):
  402. raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'")
  403. return method(*args, **kwargs)
  404. return new_method
  405. def check_batch_size(batch_size):
  406. if not (isinstance(batch_size, int) or (callable(batch_size))):
  407. raise TypeError("batch_size should either be an int or a callable.")
  408. if callable(batch_size):
  409. sig = ins.signature(batch_size)
  410. if len(sig.parameters) != 1:
  411. raise ValueError("batch_size callable should take one parameter (BatchInfo).")
  412. def check_count(count):
  413. check_type(count, 'count', int)
  414. if (count <= 0 and count != -1) or count > INT32_MAX:
  415. raise ValueError("count should be either -1 or positive integer.")
  416. def check_columns(columns, name):
  417. if isinstance(columns, list):
  418. for column in columns:
  419. if not isinstance(column, str):
  420. raise TypeError("Each column in {0} should be of type str. Got {1}.".format(name, type(column)))
  421. elif not isinstance(columns, str):
  422. raise TypeError("{} should be either a list of strings or a single string.".format(name))
  423. def check_pad_info(key, val):
  424. """check the key and value pair of pad_info in batch"""
  425. check_type(key, "key in pad_info", str)
  426. if val is not None:
  427. assert len(val) == 2, "value of pad_info should be a tuple of size 2"
  428. check_type(val, "value in pad_info", tuple)
  429. if val[0] is not None:
  430. check_type(val[0], "pad_shape", list)
  431. for dim in val[0]:
  432. if dim is not None:
  433. check_type(dim, "dim in pad_shape", int)
  434. assert dim > 0, "pad shape should be positive integers"
  435. if val[1] is not None:
  436. check_type(val[1], "pad_value", (int, float, str, bytes))
  437. def check_bucket_batch_by_length(method):
  438. """check the input arguments of bucket_batch_by_length."""
  439. @wraps(method)
  440. def new_method(*args, **kwargs):
  441. param_dict = make_param_dict(method, args, kwargs)
  442. nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
  443. check_param_type(nreq_param_list, param_dict, list)
  444. nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
  445. check_param_type(nbool_param_list, param_dict, bool)
  446. # check column_names: must be list of string.
  447. column_names = param_dict.get("column_names")
  448. if not column_names:
  449. raise ValueError("column_names cannot be empty")
  450. all_string = all(isinstance(item, str) for item in column_names)
  451. if not all_string:
  452. raise TypeError("column_names should be a list of str.")
  453. element_length_function = param_dict.get("element_length_function")
  454. if element_length_function is None and len(column_names) != 1:
  455. raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")
  456. # check bucket_boundaries: must be list of int, positive and strictly increasing
  457. bucket_boundaries = param_dict.get('bucket_boundaries')
  458. if not bucket_boundaries:
  459. raise ValueError("bucket_boundaries cannot be empty.")
  460. all_int = all(isinstance(item, int) for item in bucket_boundaries)
  461. if not all_int:
  462. raise TypeError("bucket_boundaries should be a list of int.")
  463. all_non_negative = all(item >= 0 for item in bucket_boundaries)
  464. if not all_non_negative:
  465. raise ValueError("bucket_boundaries cannot contain any negative numbers.")
  466. for i in range(len(bucket_boundaries) - 1):
  467. if not bucket_boundaries[i + 1] > bucket_boundaries[i]:
  468. raise ValueError("bucket_boundaries should be strictly increasing.")
  469. # check bucket_batch_sizes: must be list of int and positive
  470. bucket_batch_sizes = param_dict.get('bucket_batch_sizes')
  471. if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
  472. raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")
  473. all_int = all(isinstance(item, int) for item in bucket_batch_sizes)
  474. if not all_int:
  475. raise TypeError("bucket_batch_sizes should be a list of int.")
  476. all_non_negative = all(item > 0 for item in bucket_batch_sizes)
  477. if not all_non_negative:
  478. raise ValueError("bucket_batch_sizes should be a list of positive numbers.")
  479. if param_dict.get('pad_info') is not None:
  480. check_type(param_dict["pad_info"], "pad_info", dict)
  481. for k, v in param_dict.get('pad_info').items():
  482. check_pad_info(k, v)
  483. return method(*args, **kwargs)
  484. return new_method
  485. def check_batch(method):
  486. """check the input arguments of batch."""
  487. @wraps(method)
  488. def new_method(*args, **kwargs):
  489. param_dict = make_param_dict(method, args, kwargs)
  490. nreq_param_int = ['num_parallel_workers']
  491. nreq_param_bool = ['drop_remainder']
  492. nreq_param_columns = ['input_columns']
  493. # check batch_size; required argument
  494. batch_size = param_dict.get("batch_size")
  495. if batch_size is None:
  496. raise ValueError("batch_size is not provided.")
  497. check_batch_size(batch_size)
  498. check_param_type(nreq_param_int, param_dict, int)
  499. check_param_type(nreq_param_bool, param_dict, bool)
  500. if (param_dict.get('pad_info') is not None) and (param_dict.get('per_batch_map') is not None):
  501. raise ValueError("pad_info and per_batch_map can't both be set")
  502. if param_dict.get('pad_info') is not None:
  503. check_type(param_dict["pad_info"], "pad_info", dict)
  504. for k, v in param_dict.get('pad_info').items():
  505. check_pad_info(k, v)
  506. for param_name in nreq_param_columns:
  507. param = param_dict.get(param_name)
  508. if param is not None:
  509. check_columns(param, param_name)
  510. per_batch_map, input_columns = param_dict.get('per_batch_map'), param_dict.get('input_columns')
  511. if (per_batch_map is None) != (input_columns is None):
  512. # These two parameters appear together.
  513. raise ValueError("per_batch_map and input_columns need to be passed in together.")
  514. if input_columns is not None:
  515. if not input_columns: # Check whether input_columns is empty.
  516. raise ValueError("input_columns can not be empty")
  517. if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
  518. raise ValueError("the signature of per_batch_map should match with input columns")
  519. return method(*args, **kwargs)
  520. return new_method
  521. def check_sync_wait(method):
  522. """check the input arguments of sync_wait."""
  523. @wraps(method)
  524. def new_method(*args, **kwargs):
  525. param_dict = make_param_dict(method, args, kwargs)
  526. nreq_param_str = ['condition_name']
  527. nreq_param_int = ['step_size']
  528. check_param_type(nreq_param_int, param_dict, int)
  529. check_param_type(nreq_param_str, param_dict, str)
  530. return method(*args, **kwargs)
  531. return new_method
  532. def check_shuffle(method):
  533. """check the input arguments of shuffle."""
  534. @wraps(method)
  535. def new_method(*args, **kwargs):
  536. param_dict = make_param_dict(method, args, kwargs)
  537. # check buffer_size; required argument
  538. buffer_size = param_dict.get("buffer_size")
  539. if buffer_size is None:
  540. raise ValueError("buffer_size is not provided.")
  541. check_type(buffer_size, 'buffer_size', int)
  542. check_interval_closed(buffer_size, 'buffer_size', [2, INT32_MAX])
  543. return method(*args, **kwargs)
  544. return new_method
  545. def check_map(method):
  546. """check the input arguments of map."""
  547. @wraps(method)
  548. def new_method(*args, **kwargs):
  549. param_dict = make_param_dict(method, args, kwargs)
  550. nreq_param_list = ['columns_order']
  551. nreq_param_int = ['num_parallel_workers']
  552. nreq_param_columns = ['input_columns', 'output_columns']
  553. nreq_param_bool = ['python_multiprocessing']
  554. check_param_type(nreq_param_list, param_dict, list)
  555. check_param_type(nreq_param_int, param_dict, int)
  556. check_param_type(nreq_param_bool, param_dict, bool)
  557. for param_name in nreq_param_columns:
  558. param = param_dict.get(param_name)
  559. if param is not None:
  560. check_columns(param, param_name)
  561. return method(*args, **kwargs)
  562. return new_method
  563. def check_filter(method):
  564. """"check the input arguments of filter."""
  565. @wraps(method)
  566. def new_method(*args, **kwargs):
  567. param_dict = make_param_dict(method, args, kwargs)
  568. predicate = param_dict.get("predicate")
  569. if not callable(predicate):
  570. raise TypeError("Predicate should be a python function or a callable python object.")
  571. nreq_param_int = ['num_parallel_workers']
  572. check_param_type(nreq_param_int, param_dict, int)
  573. param_name = "input_columns"
  574. param = param_dict.get(param_name)
  575. if param is not None:
  576. check_columns(param, param_name)
  577. return method(*args, **kwargs)
  578. return new_method
  579. def check_repeat(method):
  580. """check the input arguments of repeat."""
  581. @wraps(method)
  582. def new_method(*args, **kwargs):
  583. param_dict = make_param_dict(method, args, kwargs)
  584. count = param_dict.get('count')
  585. if count is not None:
  586. check_count(count)
  587. return method(*args, **kwargs)
  588. return new_method
  589. def check_skip(method):
  590. """check the input arguments of skip."""
  591. @wraps(method)
  592. def new_method(*args, **kwargs):
  593. param_dict = make_param_dict(method, args, kwargs)
  594. count = param_dict.get('count')
  595. check_type(count, 'count', int)
  596. if count < 0:
  597. raise ValueError("Skip count must be positive integer or 0.")
  598. return method(*args, **kwargs)
  599. return new_method
  600. def check_take(method):
  601. """check the input arguments of take."""
  602. @wraps(method)
  603. def new_method(*args, **kwargs):
  604. param_dict = make_param_dict(method, args, kwargs)
  605. count = param_dict.get('count')
  606. check_count(count)
  607. return method(*args, **kwargs)
  608. return new_method
  609. def check_zip(method):
  610. """check the input arguments of zip."""
  611. @wraps(method)
  612. def new_method(*args, **kwargs):
  613. param_dict = make_param_dict(method, args, kwargs)
  614. # check datasets; required argument
  615. ds = param_dict.get("datasets")
  616. if ds is None:
  617. raise ValueError("datasets is not provided.")
  618. check_type(ds, 'datasets', tuple)
  619. return method(*args, **kwargs)
  620. return new_method
  621. def check_zip_dataset(method):
  622. """check the input arguments of zip method in `Dataset`."""
  623. @wraps(method)
  624. def new_method(*args, **kwargs):
  625. param_dict = make_param_dict(method, args, kwargs)
  626. # check datasets; required argument
  627. ds = param_dict.get("datasets")
  628. if ds is None:
  629. raise ValueError("datasets is not provided.")
  630. if not isinstance(ds, (tuple, datasets.Dataset)):
  631. raise TypeError("datasets is not tuple or of type Dataset.")
  632. return method(*args, **kwargs)
  633. return new_method
  634. def check_concat(method):
  635. """check the input arguments of concat method in `Dataset`."""
  636. @wraps(method)
  637. def new_method(*args, **kwargs):
  638. param_dict = make_param_dict(method, args, kwargs)
  639. # check datasets; required argument
  640. ds = param_dict.get("datasets")
  641. if ds is None:
  642. raise ValueError("datasets is not provided.")
  643. if not isinstance(ds, (list, datasets.Dataset)):
  644. raise TypeError("datasets is not list or of type Dataset.")
  645. return method(*args, **kwargs)
  646. return new_method
  647. def check_rename(method):
  648. """check the input arguments of rename."""
  649. @wraps(method)
  650. def new_method(*args, **kwargs):
  651. param_dict = make_param_dict(method, args, kwargs)
  652. req_param_columns = ['input_columns', 'output_columns']
  653. # check req_param_list; required arguments
  654. for param_name in req_param_columns:
  655. param = param_dict.get(param_name)
  656. if param is None:
  657. raise ValueError("{} is not provided.".format(param_name))
  658. check_columns(param, param_name)
  659. input_size, output_size = 1, 1
  660. if isinstance(param_dict.get(req_param_columns[0]), list):
  661. input_size = len(param_dict.get(req_param_columns[0]))
  662. if isinstance(param_dict.get(req_param_columns[1]), list):
  663. output_size = len(param_dict.get(req_param_columns[1]))
  664. if input_size != output_size:
  665. raise ValueError("Number of column in input_columns and output_columns is not equal.")
  666. return method(*args, **kwargs)
  667. return new_method
  668. def check_project(method):
  669. """check the input arguments of project."""
  670. @wraps(method)
  671. def new_method(*args, **kwargs):
  672. param_dict = make_param_dict(method, args, kwargs)
  673. # check columns; required argument
  674. columns = param_dict.get("columns")
  675. if columns is None:
  676. raise ValueError("columns is not provided.")
  677. check_columns(columns, 'columns')
  678. return method(*args, **kwargs)
  679. return new_method
  680. def check_shape(shape, name):
  681. if isinstance(shape, list):
  682. for element in shape:
  683. if not isinstance(element, int):
  684. raise TypeError(
  685. "Each element in {0} should be of type int. Got {1}.".format(name, type(element)))
  686. else:
  687. raise TypeError("Expected int list.")
  688. def check_add_column(method):
  689. """check the input arguments of add_column."""
  690. @wraps(method)
  691. def new_method(*args, **kwargs):
  692. param_dict = make_param_dict(method, args, kwargs)
  693. # check name; required argument
  694. name = param_dict.get("name")
  695. if not isinstance(name, str) or not name:
  696. raise TypeError("Expected non-empty string.")
  697. # check type; required argument
  698. de_type = param_dict.get("de_type")
  699. if de_type is not None:
  700. if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
  701. raise TypeError("Unknown column type.")
  702. else:
  703. raise TypeError("Expected non-empty string.")
  704. # check shape
  705. shape = param_dict.get("shape")
  706. if shape is not None:
  707. check_shape(shape, "shape")
  708. return method(*args, **kwargs)
  709. return new_method
  710. def check_cluedataset(method):
  711. """A wrapper that wrap a parameter checker to the original Dataset(CLUEDataset)."""
  712. @wraps(method)
  713. def new_method(*args, **kwargs):
  714. param_dict = make_param_dict(method, args, kwargs)
  715. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  716. # check dataset_files; required argument
  717. dataset_files = param_dict.get('dataset_files')
  718. if dataset_files is None:
  719. raise ValueError("dataset_files is not provided.")
  720. if not isinstance(dataset_files, (str, list)):
  721. raise TypeError("dataset_files should be of type str or a list of strings.")
  722. # check task
  723. task_param = param_dict.get('task')
  724. if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
  725. raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")
  726. # check usage
  727. usage_param = param_dict.get('usage')
  728. if usage_param not in ['train', 'test', 'eval']:
  729. raise ValueError("usage should be train, test or eval")
  730. check_param_type(nreq_param_int, param_dict, int)
  731. check_sampler_shuffle_shard_options(param_dict)
  732. return method(*args, **kwargs)
  733. return new_method
  734. def check_textfiledataset(method):
  735. """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""
  736. @wraps(method)
  737. def new_method(*args, **kwargs):
  738. param_dict = make_param_dict(method, args, kwargs)
  739. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  740. # check dataset_files; required argument
  741. dataset_files = param_dict.get('dataset_files')
  742. if dataset_files is None:
  743. raise ValueError("dataset_files is not provided.")
  744. if not isinstance(dataset_files, (str, list)):
  745. raise TypeError("dataset_files should be of type str or a list of strings.")
  746. check_param_type(nreq_param_int, param_dict, int)
  747. check_sampler_shuffle_shard_options(param_dict)
  748. return method(*args, **kwargs)
  749. return new_method
  750. def check_split(method):
  751. """check the input arguments of split."""
  752. @wraps(method)
  753. def new_method(*args, **kwargs):
  754. param_dict = make_param_dict(method, args, kwargs)
  755. nreq_param_list = ['sizes']
  756. nreq_param_bool = ['randomize']
  757. check_param_type(nreq_param_list, param_dict, list)
  758. check_param_type(nreq_param_bool, param_dict, bool)
  759. # check sizes: must be list of float or list of int
  760. sizes = param_dict.get('sizes')
  761. if not sizes:
  762. raise ValueError("sizes cannot be empty.")
  763. all_int = all(isinstance(item, int) for item in sizes)
  764. all_float = all(isinstance(item, float) for item in sizes)
  765. if not (all_int or all_float):
  766. raise ValueError("sizes should be list of int or list of float.")
  767. if all_int:
  768. all_positive = all(item > 0 for item in sizes)
  769. if not all_positive:
  770. raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")
  771. if all_float:
  772. all_valid_percentages = all(0 < item <= 1 for item in sizes)
  773. if not all_valid_percentages:
  774. raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")
  775. epsilon = 0.00001
  776. if not abs(sum(sizes) - 1) < epsilon:
  777. raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")
  778. return method(*args, **kwargs)
  779. return new_method
  780. def check_gnn_graphdata(method):
  781. """check the input arguments of graphdata."""
  782. @wraps(method)
  783. def new_method(*args, **kwargs):
  784. param_dict = make_param_dict(method, args, kwargs)
  785. # check dataset_file; required argument
  786. dataset_file = param_dict.get('dataset_file')
  787. if dataset_file is None:
  788. raise ValueError("dataset_file is not provided.")
  789. check_dataset_file(dataset_file)
  790. nreq_param_int = ['num_parallel_workers']
  791. check_param_type(nreq_param_int, param_dict, int)
  792. return method(*args, **kwargs)
  793. return new_method
  794. def check_gnn_list_or_ndarray(param, param_name):
  795. """Check if the input parameter is list or numpy.ndarray."""
  796. if isinstance(param, list):
  797. for m in param:
  798. if not isinstance(m, int):
  799. raise TypeError(
  800. "Each member in {0} should be of type int. Got {1}.".format(param_name, type(m)))
  801. elif isinstance(param, np.ndarray):
  802. if not param.dtype == np.int32:
  803. raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
  804. param_name, param.dtype))
  805. else:
  806. raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
  807. param_name, type(param)))
  808. def check_gnn_get_all_nodes(method):
  809. """A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function."""
  810. @wraps(method)
  811. def new_method(*args, **kwargs):
  812. param_dict = make_param_dict(method, args, kwargs)
  813. # check node_type; required argument
  814. check_type(param_dict.get("node_type"), 'node_type', int)
  815. return method(*args, **kwargs)
  816. return new_method
  817. def check_gnn_get_all_edges(method):
  818. """A wrapper that wrap a parameter checker to the GNN `get_all_edges` function."""
  819. @wraps(method)
  820. def new_method(*args, **kwargs):
  821. param_dict = make_param_dict(method, args, kwargs)
  822. # check node_type; required argument
  823. check_type(param_dict.get("edge_type"), 'edge_type', int)
  824. return method(*args, **kwargs)
  825. return new_method
  826. def check_gnn_get_nodes_from_edges(method):
  827. """A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function."""
  828. @wraps(method)
  829. def new_method(*args, **kwargs):
  830. param_dict = make_param_dict(method, args, kwargs)
  831. # check edge_list; required argument
  832. check_gnn_list_or_ndarray(param_dict.get("edge_list"), 'edge_list')
  833. return method(*args, **kwargs)
  834. return new_method
  835. def check_gnn_get_all_neighbors(method):
  836. """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function."""
  837. @wraps(method)
  838. def new_method(*args, **kwargs):
  839. param_dict = make_param_dict(method, args, kwargs)
  840. # check node_list; required argument
  841. check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
  842. # check neighbor_type; required argument
  843. check_type(param_dict.get("neighbor_type"), 'neighbor_type', int)
  844. return method(*args, **kwargs)
  845. return new_method
  846. def check_gnn_get_sampled_neighbors(method):
  847. """A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function."""
  848. @wraps(method)
  849. def new_method(*args, **kwargs):
  850. param_dict = make_param_dict(method, args, kwargs)
  851. # check node_list; required argument
  852. check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
  853. # check neighbor_nums; required argument
  854. neighbor_nums = param_dict.get("neighbor_nums")
  855. check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
  856. if not neighbor_nums or len(neighbor_nums) > 6:
  857. raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
  858. 'neighbor_nums', len(neighbor_nums)))
  859. # check neighbor_types; required argument
  860. neighbor_types = param_dict.get("neighbor_types")
  861. check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
  862. if not neighbor_types or len(neighbor_types) > 6:
  863. raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
  864. 'neighbor_types', len(neighbor_types)))
  865. if len(neighbor_nums) != len(neighbor_types):
  866. raise ValueError(
  867. "The number of members of neighbor_nums and neighbor_types is inconsistent")
  868. return method(*args, **kwargs)
  869. return new_method
  870. def check_gnn_get_neg_sampled_neighbors(method):
  871. """A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
  872. @wraps(method)
  873. def new_method(*args, **kwargs):
  874. param_dict = make_param_dict(method, args, kwargs)
  875. # check node_list; required argument
  876. check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
  877. # check neg_neighbor_num; required argument
  878. check_type(param_dict.get("neg_neighbor_num"), 'neg_neighbor_num', int)
  879. # check neg_neighbor_type; required argument
  880. check_type(param_dict.get("neg_neighbor_type"),
  881. 'neg_neighbor_type', int)
  882. return method(*args, **kwargs)
  883. return new_method
  884. def check_gnn_random_walk(method):
  885. """A wrapper that wrap a parameter checker to the GNN `random_walk` function."""
  886. @wraps(method)
  887. def new_method(*args, **kwargs):
  888. param_dict = make_param_dict(method, args, kwargs)
  889. # check node_list; required argument
  890. check_gnn_list_or_ndarray(param_dict.get("target_nodes"), 'target_nodes')
  891. # check meta_path; required argument
  892. check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path')
  893. return method(*args, **kwargs)
  894. return new_method
  895. def check_aligned_list(param, param_name, member_type):
  896. """Check whether the structure of each member of the list is the same."""
  897. if not isinstance(param, list):
  898. raise TypeError("Parameter {0} is not a list".format(param_name))
  899. if not param:
  900. raise TypeError(
  901. "Parameter {0} or its members are empty".format(param_name))
  902. member_have_list = None
  903. list_len = None
  904. for member in param:
  905. if isinstance(member, list):
  906. check_aligned_list(member, param_name, member_type)
  907. if member_have_list not in (None, True):
  908. raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
  909. param_name))
  910. if list_len is not None and len(member) != list_len:
  911. raise TypeError("The size of each member of parameter {0} is inconsistent".format(
  912. param_name))
  913. member_have_list = True
  914. list_len = len(member)
  915. else:
  916. if not isinstance(member, member_type):
  917. raise TypeError("Each member in {0} should be of type int. Got {1}.".format(
  918. param_name, type(member)))
  919. if member_have_list not in (None, False):
  920. raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
  921. param_name))
  922. member_have_list = False
  923. def check_gnn_get_node_feature(method):
  924. """A wrapper that wrap a parameter checker to the GNN `get_node_feature` function."""
  925. @wraps(method)
  926. def new_method(*args, **kwargs):
  927. param_dict = make_param_dict(method, args, kwargs)
  928. # check node_list; required argument
  929. node_list = param_dict.get("node_list")
  930. if isinstance(node_list, list):
  931. check_aligned_list(node_list, 'node_list', int)
  932. elif isinstance(node_list, np.ndarray):
  933. if not node_list.dtype == np.int32:
  934. raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
  935. node_list, node_list.dtype))
  936. else:
  937. raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
  938. 'node_list', type(node_list)))
  939. # check feature_types; required argument
  940. check_gnn_list_or_ndarray(param_dict.get(
  941. "feature_types"), 'feature_types')
  942. return method(*args, **kwargs)
  943. return new_method
  944. def check_numpyslicesdataset(method):
  945. """A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset)."""
  946. @wraps(method)
  947. def new_method(*args, **kwargs):
  948. param_dict = make_param_dict(method, args, kwargs)
  949. # check data; required argument
  950. data = param_dict.get('data')
  951. if not isinstance(data, (list, tuple, dict, np.ndarray)):
  952. raise TypeError("Unsupported data type: {}, only support some common python data type, "
  953. "like list, tuple, dict, and numpy array.".format(type(data)))
  954. if isinstance(data, tuple) and not isinstance(data[0], (list, np.ndarray)):
  955. raise TypeError("Unsupported data type: when input is tuple, only support some common python "
  956. "data type, like tuple of lists and tuple of numpy arrays.")
  957. if not data:
  958. raise ValueError("Input data is empty.")
  959. # check column_names
  960. column_names = param_dict.get('column_names')
  961. if column_names is not None:
  962. check_columns(column_names, "column_names")
  963. # check num of input column in column_names
  964. column_num = 1 if isinstance(column_names, str) else len(column_names)
  965. if isinstance(data, dict):
  966. data_column = len(list(data.keys()))
  967. if column_num != data_column:
  968. raise ValueError("Num of input column names is {0}, but required is {1}."
  969. .format(column_num, data_column))
  970. elif isinstance(data, tuple):
  971. if column_num != len(data):
  972. raise ValueError("Num of input column names is {0}, but required is {1}."
  973. .format(column_num, len(data)))
  974. else:
  975. if column_num != 1:
  976. raise ValueError("Num of input column names is {0}, but required is {1} as data is list."
  977. .format(column_num, 1))
  978. return method(*args, **kwargs)
  979. return new_method