You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

validators.py 45 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Built-in validators.
  17. """
  18. import inspect as ins
  19. import os
  20. import re
  21. from functools import wraps
  22. import numpy as np
  23. from mindspore._c_expression import typing
  24. from mindspore.dataset.callback import DSCallback
  25. from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
  26. INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
  27. validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
  28. check_columns, check_pos_int32
  29. from . import datasets
  30. from . import samplers
  31. from . import cache_client
  32. from .. import callback
  33. def check_imagefolderdatasetv2(method):
  34. """A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
  35. @wraps(method)
  36. def new_method(self, *args, **kwargs):
  37. _, param_dict = parse_user_args(method, *args, **kwargs)
  38. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  39. nreq_param_bool = ['shuffle', 'decode']
  40. nreq_param_list = ['extensions']
  41. nreq_param_dict = ['class_indexing']
  42. dataset_dir = param_dict.get('dataset_dir')
  43. check_dir(dataset_dir)
  44. validate_dataset_param_value(nreq_param_int, param_dict, int)
  45. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  46. validate_dataset_param_value(nreq_param_list, param_dict, list)
  47. validate_dataset_param_value(nreq_param_dict, param_dict, dict)
  48. check_sampler_shuffle_shard_options(param_dict)
  49. return method(self, *args, **kwargs)
  50. return new_method
  51. def check_mnist_cifar_dataset(method):
  52. """A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
  53. @wraps(method)
  54. def new_method(self, *args, **kwargs):
  55. _, param_dict = parse_user_args(method, *args, **kwargs)
  56. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  57. nreq_param_bool = ['shuffle']
  58. dataset_dir = param_dict.get('dataset_dir')
  59. check_dir(dataset_dir)
  60. validate_dataset_param_value(nreq_param_int, param_dict, int)
  61. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  62. check_sampler_shuffle_shard_options(param_dict)
  63. return method(self, *args, **kwargs)
  64. return new_method
  65. def check_manifestdataset(method):
  66. """A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset)."""
  67. @wraps(method)
  68. def new_method(self, *args, **kwargs):
  69. _, param_dict = parse_user_args(method, *args, **kwargs)
  70. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  71. nreq_param_bool = ['shuffle', 'decode']
  72. nreq_param_str = ['usage']
  73. nreq_param_dict = ['class_indexing']
  74. dataset_file = param_dict.get('dataset_file')
  75. check_file(dataset_file)
  76. validate_dataset_param_value(nreq_param_int, param_dict, int)
  77. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  78. validate_dataset_param_value(nreq_param_str, param_dict, str)
  79. validate_dataset_param_value(nreq_param_dict, param_dict, dict)
  80. check_sampler_shuffle_shard_options(param_dict)
  81. return method(self, *args, **kwargs)
  82. return new_method
  83. def check_tfrecorddataset(method):
  84. """A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset)."""
  85. @wraps(method)
  86. def new_method(self, *args, **kwargs):
  87. _, param_dict = parse_user_args(method, *args, **kwargs)
  88. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  89. nreq_param_list = ['columns_list']
  90. nreq_param_bool = ['shard_equal_rows']
  91. dataset_files = param_dict.get('dataset_files')
  92. if not isinstance(dataset_files, (str, list)):
  93. raise TypeError("dataset_files should be of type str or a list of strings.")
  94. validate_dataset_param_value(nreq_param_int, param_dict, int)
  95. validate_dataset_param_value(nreq_param_list, param_dict, list)
  96. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  97. check_sampler_shuffle_shard_options(param_dict)
  98. return method(self, *args, **kwargs)
  99. return new_method
  100. def check_vocdataset(method):
  101. """A wrapper that wraps a parameter checker to the original Dataset(VOCDataset)."""
  102. @wraps(method)
  103. def new_method(self, *args, **kwargs):
  104. _, param_dict = parse_user_args(method, *args, **kwargs)
  105. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  106. nreq_param_bool = ['shuffle', 'decode']
  107. nreq_param_dict = ['class_indexing']
  108. dataset_dir = param_dict.get('dataset_dir')
  109. check_dir(dataset_dir)
  110. task = param_dict.get('task')
  111. type_check(task, (str,), "task")
  112. mode = param_dict.get('mode')
  113. type_check(mode, (str,), "mode")
  114. if task == "Segmentation":
  115. imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", mode + ".txt")
  116. if param_dict.get('class_indexing') is not None:
  117. raise ValueError("class_indexing is invalid in Segmentation task")
  118. elif task == "Detection":
  119. imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", mode + ".txt")
  120. else:
  121. raise ValueError("Invalid task : " + task)
  122. check_file(imagesets_file)
  123. validate_dataset_param_value(nreq_param_int, param_dict, int)
  124. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  125. validate_dataset_param_value(nreq_param_dict, param_dict, dict)
  126. check_sampler_shuffle_shard_options(param_dict)
  127. return method(self, *args, **kwargs)
  128. return new_method
  129. def check_cocodataset(method):
  130. """A wrapper that wraps a parameter checker to the original Dataset(CocoDataset)."""
  131. @wraps(method)
  132. def new_method(self, *args, **kwargs):
  133. _, param_dict = parse_user_args(method, *args, **kwargs)
  134. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  135. nreq_param_bool = ['shuffle', 'decode']
  136. dataset_dir = param_dict.get('dataset_dir')
  137. check_dir(dataset_dir)
  138. annotation_file = param_dict.get('annotation_file')
  139. check_file(annotation_file)
  140. task = param_dict.get('task')
  141. type_check(task, (str,), "task")
  142. if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
  143. raise ValueError("Invalid task type")
  144. validate_dataset_param_value(nreq_param_int, param_dict, int)
  145. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  146. sampler = param_dict.get('sampler')
  147. if sampler is not None and isinstance(sampler, samplers.PKSampler):
  148. raise ValueError("CocoDataset doesn't support PKSampler")
  149. check_sampler_shuffle_shard_options(param_dict)
  150. return method(self, *args, **kwargs)
  151. return new_method
  152. def check_celebadataset(method):
  153. """A wrapper that wraps a parameter checker to the original Dataset(CelebADataset)."""
  154. @wraps(method)
  155. def new_method(self, *args, **kwargs):
  156. _, param_dict = parse_user_args(method, *args, **kwargs)
  157. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  158. nreq_param_bool = ['shuffle', 'decode']
  159. nreq_param_list = ['extensions']
  160. nreq_param_str = ['dataset_type']
  161. dataset_dir = param_dict.get('dataset_dir')
  162. check_dir(dataset_dir)
  163. validate_dataset_param_value(nreq_param_int, param_dict, int)
  164. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  165. validate_dataset_param_value(nreq_param_list, param_dict, list)
  166. validate_dataset_param_value(nreq_param_str, param_dict, str)
  167. dataset_type = param_dict.get('dataset_type')
  168. if dataset_type is not None and dataset_type not in ('all', 'train', 'valid', 'test'):
  169. raise ValueError("dataset_type should be one of 'all', 'train', 'valid' or 'test'.")
  170. check_sampler_shuffle_shard_options(param_dict)
  171. sampler = param_dict.get('sampler')
  172. if sampler is not None and isinstance(sampler, samplers.PKSampler):
  173. raise ValueError("CelebADataset does not support PKSampler.")
  174. return method(self, *args, **kwargs)
  175. return new_method
  176. def check_save(method):
  177. """A wrapper that wrap a parameter checker to the save op."""
  178. @wraps(method)
  179. def new_method(self, *args, **kwargs):
  180. _, param_dict = parse_user_args(method, *args, **kwargs)
  181. nreq_param_int = ['num_files']
  182. nreq_param_str = ['file_name', 'file_type']
  183. validate_dataset_param_value(nreq_param_int, param_dict, int)
  184. if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
  185. raise ValueError("num_files should between {} and {}.".format(1, 1000))
  186. validate_dataset_param_value(nreq_param_str, param_dict, str)
  187. if param_dict.get('file_type') != 'mindrecord':
  188. raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
  189. return method(self, *args, **kwargs)
  190. return new_method
  191. def check_minddataset(method):
  192. """A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
  193. @wraps(method)
  194. def new_method(self, *args, **kwargs):
  195. _, param_dict = parse_user_args(method, *args, **kwargs)
  196. nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded']
  197. nreq_param_list = ['columns_list']
  198. nreq_param_dict = ['padded_sample']
  199. dataset_file = param_dict.get('dataset_file')
  200. if isinstance(dataset_file, list):
  201. if len(dataset_file) > 4096:
  202. raise ValueError("length of dataset_file should less than or equal to {}.".format(4096))
  203. for f in dataset_file:
  204. check_file(f)
  205. else:
  206. check_file(dataset_file)
  207. validate_dataset_param_value(nreq_param_int, param_dict, int)
  208. validate_dataset_param_value(nreq_param_list, param_dict, list)
  209. validate_dataset_param_value(nreq_param_dict, param_dict, dict)
  210. check_sampler_shuffle_shard_options(param_dict)
  211. check_padding_options(param_dict)
  212. return method(self, *args, **kwargs)
  213. return new_method
  214. def check_generatordataset(method):
  215. """A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset)."""
  216. @wraps(method)
  217. def new_method(self, *args, **kwargs):
  218. _, param_dict = parse_user_args(method, *args, **kwargs)
  219. source = param_dict.get('source')
  220. if not callable(source):
  221. try:
  222. iter(source)
  223. except TypeError:
  224. raise TypeError("source should be callable, iterable or random accessible")
  225. column_names = param_dict.get('column_names')
  226. if column_names is not None:
  227. check_columns(column_names, "column_names")
  228. schema = param_dict.get('schema')
  229. if column_names is None and schema is None:
  230. raise ValueError("Neither columns_names not schema are provided.")
  231. if schema is not None:
  232. if not isinstance(schema, datasets.Schema) and not isinstance(schema, str):
  233. raise ValueError("schema should be a path to schema file or a schema object.")
  234. # check optional argument
  235. nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
  236. validate_dataset_param_value(nreq_param_int, param_dict, int)
  237. nreq_param_list = ["column_types"]
  238. validate_dataset_param_value(nreq_param_list, param_dict, list)
  239. nreq_param_bool = ["shuffle"]
  240. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  241. num_shards = param_dict.get("num_shards")
  242. shard_id = param_dict.get("shard_id")
  243. if (num_shards is None) != (shard_id is None):
  244. # These two parameters appear together.
  245. raise ValueError("num_shards and shard_id need to be passed in together")
  246. if num_shards is not None:
  247. check_pos_int32(num_shards, "num_shards")
  248. if shard_id >= num_shards:
  249. raise ValueError("shard_id should be less than num_shards.")
  250. sampler = param_dict.get("sampler")
  251. if sampler is not None:
  252. if isinstance(sampler, samplers.PKSampler):
  253. raise ValueError("PKSampler is not supported by GeneratorDataset")
  254. if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  255. samplers.RandomSampler, samplers.SubsetRandomSampler,
  256. samplers.WeightedRandomSampler, samplers.Sampler)):
  257. try:
  258. iter(sampler)
  259. except TypeError:
  260. raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")
  261. if sampler is not None and not hasattr(source, "__getitem__"):
  262. raise ValueError("sampler is not supported if source does not have attribute '__getitem__'")
  263. if num_shards is not None and not hasattr(source, "__getitem__"):
  264. raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'")
  265. return method(self, *args, **kwargs)
  266. return new_method
  267. def check_random_dataset(method):
  268. """A wrapper that wraps a parameter checker to the original Dataset(RandomDataset)."""
  269. @wraps(method)
  270. def new_method(self, *args, **kwargs):
  271. _, param_dict = parse_user_args(method, *args, **kwargs)
  272. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows']
  273. nreq_param_bool = ['shuffle']
  274. nreq_param_list = ['columns_list']
  275. validate_dataset_param_value(nreq_param_int, param_dict, int)
  276. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  277. validate_dataset_param_value(nreq_param_list, param_dict, list)
  278. check_sampler_shuffle_shard_options(param_dict)
  279. return method(self, *args, **kwargs)
  280. return new_method
  281. def check_pad_info(key, val):
  282. """check the key and value pair of pad_info in batch"""
  283. type_check(key, (str,), "key in pad_info")
  284. if val is not None:
  285. assert len(val) == 2, "value of pad_info should be a tuple of size 2"
  286. type_check(val, (tuple,), "value in pad_info")
  287. if val[0] is not None:
  288. type_check(val[0], (list,), "pad_shape")
  289. for dim in val[0]:
  290. if dim is not None:
  291. type_check(dim, (int,), "dim in pad_shape")
  292. assert dim > 0, "pad shape should be positive integers"
  293. if val[1] is not None:
  294. type_check(val[1], (int, float, str, bytes), "pad_value")
  295. def check_bucket_batch_by_length(method):
  296. """check the input arguments of bucket_batch_by_length."""
  297. @wraps(method)
  298. def new_method(self, *args, **kwargs):
  299. [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info,
  300. pad_to_bucket_boundary, drop_remainder], _ = parse_user_args(method, *args, **kwargs)
  301. nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
  302. type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,), nreq_param_list)
  303. nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
  304. type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)
  305. # check column_names: must be list of string.
  306. check_columns(column_names, "column_names")
  307. if element_length_function is None and len(column_names) != 1:
  308. raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")
  309. # check bucket_boundaries: must be list of int, positive and strictly increasing
  310. if not bucket_boundaries:
  311. raise ValueError("bucket_boundaries cannot be empty.")
  312. all_int = all(isinstance(item, int) for item in bucket_boundaries)
  313. if not all_int:
  314. raise TypeError("bucket_boundaries should be a list of int.")
  315. all_non_negative = all(item > 0 for item in bucket_boundaries)
  316. if not all_non_negative:
  317. raise ValueError("bucket_boundaries must only contain positive numbers.")
  318. for i in range(len(bucket_boundaries) - 1):
  319. if not bucket_boundaries[i + 1] > bucket_boundaries[i]:
  320. raise ValueError("bucket_boundaries should be strictly increasing.")
  321. # check bucket_batch_sizes: must be list of int and positive
  322. if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
  323. raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")
  324. all_int = all(isinstance(item, int) for item in bucket_batch_sizes)
  325. if not all_int:
  326. raise TypeError("bucket_batch_sizes should be a list of int.")
  327. all_non_negative = all(item > 0 for item in bucket_batch_sizes)
  328. if not all_non_negative:
  329. raise ValueError("bucket_batch_sizes should be a list of positive numbers.")
  330. if pad_info is not None:
  331. type_check(pad_info, (dict,), "pad_info")
  332. for k, v in pad_info.items():
  333. check_pad_info(k, v)
  334. return method(self, *args, **kwargs)
  335. return new_method
  336. def check_batch(method):
  337. """check the input arguments of batch."""
  338. @wraps(method)
  339. def new_method(self, *args, **kwargs):
  340. [batch_size, drop_remainder, num_parallel_workers, per_batch_map,
  341. input_columns, pad_info], param_dict = parse_user_args(method, *args, **kwargs)
  342. if not (isinstance(batch_size, int) or (callable(batch_size))):
  343. raise TypeError("batch_size should either be an int or a callable.")
  344. if callable(batch_size):
  345. sig = ins.signature(batch_size)
  346. if len(sig.parameters) != 1:
  347. raise ValueError("batch_size callable should take one parameter (BatchInfo).")
  348. if num_parallel_workers is not None:
  349. check_num_parallel_workers(num_parallel_workers)
  350. type_check(drop_remainder, (bool,), "drop_remainder")
  351. if (pad_info is not None) and (per_batch_map is not None):
  352. raise ValueError("pad_info and per_batch_map can't both be set")
  353. if pad_info is not None:
  354. type_check(param_dict["pad_info"], (dict,), "pad_info")
  355. for k, v in param_dict.get('pad_info').items():
  356. check_pad_info(k, v)
  357. if input_columns is not None:
  358. check_columns(input_columns, "input_columns")
  359. if (per_batch_map is None) != (input_columns is None):
  360. # These two parameters appear together.
  361. raise ValueError("per_batch_map and input_columns need to be passed in together.")
  362. if input_columns is not None:
  363. if not input_columns: # Check whether input_columns is empty.
  364. raise ValueError("input_columns can not be empty")
  365. if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
  366. raise ValueError("the signature of per_batch_map should match with input columns")
  367. return method(self, *args, **kwargs)
  368. return new_method
  369. def check_sync_wait(method):
  370. """check the input arguments of sync_wait."""
  371. @wraps(method)
  372. def new_method(self, *args, **kwargs):
  373. [condition_name, num_batch, _], _ = parse_user_args(method, *args, **kwargs)
  374. type_check(condition_name, (str,), "condition_name")
  375. type_check(num_batch, (int,), "num_batch")
  376. return method(self, *args, **kwargs)
  377. return new_method
  378. def check_shuffle(method):
  379. """check the input arguments of shuffle."""
  380. @wraps(method)
  381. def new_method(self, *args, **kwargs):
  382. [buffer_size], _ = parse_user_args(method, *args, **kwargs)
  383. type_check(buffer_size, (int,), "buffer_size")
  384. check_value(buffer_size, [2, INT32_MAX], "buffer_size")
  385. return method(self, *args, **kwargs)
  386. return new_method
  387. def check_map(method):
  388. """check the input arguments of map."""
  389. @wraps(method)
  390. def new_method(self, *args, **kwargs):
  391. [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache,
  392. callbacks], _ = \
  393. parse_user_args(method, *args, **kwargs)
  394. nreq_param_columns = ['input_columns', 'output_columns', 'columns_order']
  395. if columns_order is not None:
  396. type_check(columns_order, (list,), "columns_order")
  397. if num_parallel_workers is not None:
  398. check_num_parallel_workers(num_parallel_workers)
  399. type_check(python_multiprocessing, (bool,), "python_multiprocessing")
  400. if cache is not None:
  401. type_check(cache, (cache_client.DatasetCache,), "cache")
  402. if callbacks is not None:
  403. if isinstance(callbacks, (list, tuple)):
  404. type_check_list(callbacks, (callback.DSCallback,), "callbacks")
  405. else:
  406. type_check(callbacks, (callback.DSCallback,), "callbacks")
  407. for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, columns_order]):
  408. if param is not None:
  409. check_columns(param, param_name)
  410. if callbacks is not None:
  411. type_check(callbacks, (list, DSCallback), "callbacks")
  412. return method(self, *args, **kwargs)
  413. return new_method
  414. def check_filter(method):
  415. """"check the input arguments of filter."""
  416. @wraps(method)
  417. def new_method(self, *args, **kwargs):
  418. [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)
  419. if not callable(predicate):
  420. raise TypeError("Predicate should be a python function or a callable python object.")
  421. check_num_parallel_workers(num_parallel_workers)
  422. if num_parallel_workers is not None:
  423. check_num_parallel_workers(num_parallel_workers)
  424. if input_columns is not None:
  425. check_columns(input_columns, "input_columns")
  426. return method(self, *args, **kwargs)
  427. return new_method
  428. def check_repeat(method):
  429. """check the input arguments of repeat."""
  430. @wraps(method)
  431. def new_method(self, *args, **kwargs):
  432. [count], _ = parse_user_args(method, *args, **kwargs)
  433. type_check(count, (int, type(None)), "repeat")
  434. if isinstance(count, int):
  435. if (count <= 0 and count != -1) or count > INT32_MAX:
  436. raise ValueError("count should be either -1 or positive integer.")
  437. return method(self, *args, **kwargs)
  438. return new_method
  439. def check_skip(method):
  440. """check the input arguments of skip."""
  441. @wraps(method)
  442. def new_method(self, *args, **kwargs):
  443. [count], _ = parse_user_args(method, *args, **kwargs)
  444. type_check(count, (int,), "count")
  445. check_value(count, (-1, INT32_MAX), "count")
  446. return method(self, *args, **kwargs)
  447. return new_method
  448. def check_take(method):
  449. """check the input arguments of take."""
  450. @wraps(method)
  451. def new_method(self, *args, **kwargs):
  452. [count], _ = parse_user_args(method, *args, **kwargs)
  453. type_check(count, (int,), "count")
  454. if (count <= 0 and count != -1) or count > INT32_MAX:
  455. raise ValueError("count should be either -1 or positive integer.")
  456. return method(self, *args, **kwargs)
  457. return new_method
  458. def check_positive_int32(method):
  459. """check whether the input argument is positive and int, only works for functions with one input."""
  460. @wraps(method)
  461. def new_method(self, *args, **kwargs):
  462. [count], param_dict = parse_user_args(method, *args, **kwargs)
  463. para_name = None
  464. for key in list(param_dict.keys()):
  465. if key not in ['self', 'cls']:
  466. para_name = key
  467. # Need to get default value of param
  468. if count is not None:
  469. check_pos_int32(count, para_name)
  470. return method(self, *args, **kwargs)
  471. return new_method
  472. def check_device_send(method):
  473. """check the input argument for to_device and device_que."""
  474. @wraps(method)
  475. def new_method(self, *args, **kwargs):
  476. param, param_dict = parse_user_args(method, *args, **kwargs)
  477. para_list = list(param_dict.keys())
  478. if "prefetch_size" in para_list:
  479. if param[0] is not None:
  480. check_pos_int32(param[0], "prefetch_size")
  481. type_check(param[1], (bool,), "send_epoch_end")
  482. else:
  483. type_check(param[0], (bool,), "send_epoch_end")
  484. return method(self, *args, **kwargs)
  485. return new_method
  486. def check_zip(method):
  487. """check the input arguments of zip."""
  488. @wraps(method)
  489. def new_method(*args, **kwargs):
  490. [ds], _ = parse_user_args(method, *args, **kwargs)
  491. type_check(ds, (tuple,), "datasets")
  492. return method(*args, **kwargs)
  493. return new_method
  494. def check_zip_dataset(method):
  495. """check the input arguments of zip method in `Dataset`."""
  496. @wraps(method)
  497. def new_method(self, *args, **kwargs):
  498. [ds], _ = parse_user_args(method, *args, **kwargs)
  499. type_check(ds, (tuple, datasets.Dataset), "datasets")
  500. return method(self, *args, **kwargs)
  501. return new_method
  502. def check_concat(method):
  503. """check the input arguments of concat method in `Dataset`."""
  504. @wraps(method)
  505. def new_method(self, *args, **kwargs):
  506. [ds], _ = parse_user_args(method, *args, **kwargs)
  507. type_check(ds, (list, datasets.Dataset), "datasets")
  508. if isinstance(ds, list):
  509. type_check_list(ds, (datasets.Dataset,), "dataset")
  510. return method(self, *args, **kwargs)
  511. return new_method
  512. def check_rename(method):
  513. """check the input arguments of rename."""
  514. @wraps(method)
  515. def new_method(self, *args, **kwargs):
  516. values, _ = parse_user_args(method, *args, **kwargs)
  517. req_param_columns = ['input_columns', 'output_columns']
  518. for param_name, param in zip(req_param_columns, values):
  519. check_columns(param, param_name)
  520. input_size, output_size = 1, 1
  521. input_columns, output_columns = values
  522. if isinstance(input_columns, list):
  523. input_size = len(input_columns)
  524. if isinstance(output_columns, list):
  525. output_size = len(output_columns)
  526. if input_size != output_size:
  527. raise ValueError("Number of column in input_columns and output_columns is not equal.")
  528. return method(self, *args, **kwargs)
  529. return new_method
  530. def check_project(method):
  531. """check the input arguments of project."""
  532. @wraps(method)
  533. def new_method(self, *args, **kwargs):
  534. [columns], _ = parse_user_args(method, *args, **kwargs)
  535. check_columns(columns, 'columns')
  536. return method(self, *args, **kwargs)
  537. return new_method
  538. def check_add_column(method):
  539. """check the input arguments of add_column."""
  540. @wraps(method)
  541. def new_method(self, *args, **kwargs):
  542. [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs)
  543. type_check(name, (str,), "name")
  544. if not name:
  545. raise TypeError("Expected non-empty string.")
  546. if de_type is not None:
  547. if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
  548. raise TypeError("Unknown column type.")
  549. else:
  550. raise TypeError("Expected non-empty string.")
  551. if shape is not None:
  552. type_check(shape, (list,), "shape")
  553. type_check_list(shape, (int,), "shape")
  554. return method(self, *args, **kwargs)
  555. return new_method
  556. def check_cluedataset(method):
  557. """A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset)."""
  558. @wraps(method)
  559. def new_method(self, *args, **kwargs):
  560. _, param_dict = parse_user_args(method, *args, **kwargs)
  561. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  562. dataset_files = param_dict.get('dataset_files')
  563. type_check(dataset_files, (str, list), "dataset files")
  564. # check task
  565. task_param = param_dict.get('task')
  566. if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
  567. raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")
  568. # check usage
  569. usage_param = param_dict.get('usage')
  570. if usage_param not in ['train', 'test', 'eval']:
  571. raise ValueError("usage should be train, test or eval")
  572. validate_dataset_param_value(nreq_param_int, param_dict, int)
  573. check_sampler_shuffle_shard_options(param_dict)
  574. return method(self, *args, **kwargs)
  575. return new_method
  576. def check_csvdataset(method):
  577. """A wrapper that wrap a parameter checker to the original Dataset(CSVDataset)."""
  578. @wraps(method)
  579. def new_method(self, *args, **kwargs):
  580. _, param_dict = parse_user_args(method, *args, **kwargs)
  581. nreq_param_int = ['num_parallel_workers', 'num_shards', 'shard_id']
  582. # check dataset_files; required argument
  583. dataset_files = param_dict.get('dataset_files')
  584. type_check(dataset_files, (str, list), "dataset files")
  585. # check num_samples
  586. num_samples = param_dict.get('num_samples')
  587. check_value(num_samples, [-1, INT32_MAX], "num_samples")
  588. # check field_delim
  589. field_delim = param_dict.get('field_delim')
  590. type_check(field_delim, (str,), 'field delim')
  591. if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
  592. raise ValueError("field_delim is not legal.")
  593. # check column_defaults
  594. column_defaults = param_dict.get('column_defaults')
  595. if column_defaults is not None:
  596. if not isinstance(column_defaults, list):
  597. raise TypeError("column_defaults should be type of list.")
  598. for item in column_defaults:
  599. if not isinstance(item, (str, int, float)):
  600. raise TypeError("column type is not legal.")
  601. # check column_names: must be list of string.
  602. column_names = param_dict.get("column_names")
  603. if column_names is not None:
  604. all_string = all(isinstance(item, str) for item in column_names)
  605. if not all_string:
  606. raise TypeError("column_names should be a list of str.")
  607. validate_dataset_param_value(nreq_param_int, param_dict, int)
  608. check_sampler_shuffle_shard_options(param_dict)
  609. return method(self, *args, **kwargs)
  610. return new_method
  611. def check_textfiledataset(method):
  612. """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
  613. @wraps(method)
  614. def new_method(self, *args, **kwargs):
  615. _, param_dict = parse_user_args(method, *args, **kwargs)
  616. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  617. dataset_files = param_dict.get('dataset_files')
  618. type_check(dataset_files, (str, list), "dataset files")
  619. validate_dataset_param_value(nreq_param_int, param_dict, int)
  620. check_sampler_shuffle_shard_options(param_dict)
  621. return method(self, *args, **kwargs)
  622. return new_method
  623. def check_split(method):
  624. """check the input arguments of split."""
  625. @wraps(method)
  626. def new_method(self, *args, **kwargs):
  627. [sizes, randomize], _ = parse_user_args(method, *args, **kwargs)
  628. type_check(sizes, (list,), "sizes")
  629. type_check(randomize, (bool,), "randomize")
  630. # check sizes: must be list of float or list of int
  631. if not sizes:
  632. raise ValueError("sizes cannot be empty.")
  633. all_int = all(isinstance(item, int) for item in sizes)
  634. all_float = all(isinstance(item, float) for item in sizes)
  635. if not (all_int or all_float):
  636. raise ValueError("sizes should be list of int or list of float.")
  637. if all_int:
  638. all_positive = all(item > 0 for item in sizes)
  639. if not all_positive:
  640. raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")
  641. if all_float:
  642. all_valid_percentages = all(0 < item <= 1 for item in sizes)
  643. if not all_valid_percentages:
  644. raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")
  645. epsilon = 0.00001
  646. if not abs(sum(sizes) - 1) < epsilon:
  647. raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")
  648. return method(self, *args, **kwargs)
  649. return new_method
  650. def check_hostname(hostname):
  651. if not hostname or len(hostname) > 255:
  652. return False
  653. if hostname[-1] == ".":
  654. hostname = hostname[:-1] # strip exactly one dot from the right, if present
  655. allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE)
  656. return all(allowed.match(x) for x in hostname.split("."))
  657. def check_gnn_graphdata(method):
  658. """check the input arguments of graphdata."""
  659. @wraps(method)
  660. def new_method(self, *args, **kwargs):
  661. [dataset_file, num_parallel_workers, working_mode, hostname,
  662. port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs)
  663. check_file(dataset_file)
  664. if num_parallel_workers is not None:
  665. check_num_parallel_workers(num_parallel_workers)
  666. type_check(hostname, (str,), "hostname")
  667. if check_hostname(hostname) is False:
  668. raise ValueError("The hostname is illegal")
  669. type_check(working_mode, (str,), "working_mode")
  670. if working_mode not in {'local', 'client', 'server'}:
  671. raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'")
  672. type_check(port, (int,), "port")
  673. check_value(port, (1024, 65535), "port")
  674. type_check(num_client, (int,), "num_client")
  675. check_value(num_client, (1, 255), "num_client")
  676. type_check(auto_shutdown, (bool,), "auto_shutdown")
  677. return method(self, *args, **kwargs)
  678. return new_method
  679. def check_gnn_get_all_nodes(method):
  680. """A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function."""
  681. @wraps(method)
  682. def new_method(self, *args, **kwargs):
  683. [node_type], _ = parse_user_args(method, *args, **kwargs)
  684. type_check(node_type, (int,), "node_type")
  685. return method(self, *args, **kwargs)
  686. return new_method
  687. def check_gnn_get_all_edges(method):
  688. """A wrapper that wraps a parameter checker to the GNN `get_all_edges` function."""
  689. @wraps(method)
  690. def new_method(self, *args, **kwargs):
  691. [edge_type], _ = parse_user_args(method, *args, **kwargs)
  692. type_check(edge_type, (int,), "edge_type")
  693. return method(self, *args, **kwargs)
  694. return new_method
  695. def check_gnn_get_nodes_from_edges(method):
  696. """A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function."""
  697. @wraps(method)
  698. def new_method(self, *args, **kwargs):
  699. [edge_list], _ = parse_user_args(method, *args, **kwargs)
  700. check_gnn_list_or_ndarray(edge_list, "edge_list")
  701. return method(self, *args, **kwargs)
  702. return new_method
  703. def check_gnn_get_all_neighbors(method):
  704. """A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function."""
  705. @wraps(method)
  706. def new_method(self, *args, **kwargs):
  707. [node_list, neighbour_type], _ = parse_user_args(method, *args, **kwargs)
  708. check_gnn_list_or_ndarray(node_list, 'node_list')
  709. type_check(neighbour_type, (int,), "neighbour_type")
  710. return method(self, *args, **kwargs)
  711. return new_method
  712. def check_gnn_get_sampled_neighbors(method):
  713. """A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function."""
  714. @wraps(method)
  715. def new_method(self, *args, **kwargs):
  716. [node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs)
  717. check_gnn_list_or_ndarray(node_list, 'node_list')
  718. check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
  719. if not neighbor_nums or len(neighbor_nums) > 6:
  720. raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
  721. 'neighbor_nums', len(neighbor_nums)))
  722. check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
  723. if not neighbor_types or len(neighbor_types) > 6:
  724. raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
  725. 'neighbor_types', len(neighbor_types)))
  726. if len(neighbor_nums) != len(neighbor_types):
  727. raise ValueError(
  728. "The number of members of neighbor_nums and neighbor_types is inconsistent")
  729. return method(self, *args, **kwargs)
  730. return new_method
  731. def check_gnn_get_neg_sampled_neighbors(method):
  732. """A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
  733. @wraps(method)
  734. def new_method(self, *args, **kwargs):
  735. [node_list, neg_neighbor_num, neg_neighbor_type], _ = parse_user_args(method, *args, **kwargs)
  736. check_gnn_list_or_ndarray(node_list, 'node_list')
  737. type_check(neg_neighbor_num, (int,), "neg_neighbor_num")
  738. type_check(neg_neighbor_type, (int,), "neg_neighbor_type")
  739. return method(self, *args, **kwargs)
  740. return new_method
  741. def check_gnn_random_walk(method):
  742. """A wrapper that wraps a parameter checker to the GNN `random_walk` function."""
  743. @wraps(method)
  744. def new_method(self, *args, **kwargs):
  745. [target_nodes, meta_path, step_home_param, step_away_param, default_node], _ = parse_user_args(method, *args,
  746. **kwargs)
  747. check_gnn_list_or_ndarray(target_nodes, 'target_nodes')
  748. check_gnn_list_or_ndarray(meta_path, 'meta_path')
  749. type_check(step_home_param, (float,), "step_home_param")
  750. type_check(step_away_param, (float,), "step_away_param")
  751. type_check(default_node, (int,), "default_node")
  752. check_value(default_node, (-1, INT32_MAX), "default_node")
  753. return method(self, *args, **kwargs)
  754. return new_method
  755. def check_aligned_list(param, param_name, member_type):
  756. """Check whether the structure of each member of the list is the same."""
  757. type_check(param, (list,), "param")
  758. if not param:
  759. raise TypeError(
  760. "Parameter {0} or its members are empty".format(param_name))
  761. member_have_list = None
  762. list_len = None
  763. for member in param:
  764. if isinstance(member, list):
  765. check_aligned_list(member, param_name, member_type)
  766. if member_have_list not in (None, True):
  767. raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
  768. param_name))
  769. if list_len is not None and len(member) != list_len:
  770. raise TypeError("The size of each member of parameter {0} is inconsistent".format(
  771. param_name))
  772. member_have_list = True
  773. list_len = len(member)
  774. else:
  775. type_check(member, (member_type,), param_name)
  776. if member_have_list not in (None, False):
  777. raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
  778. param_name))
  779. member_have_list = False
  780. def check_gnn_get_node_feature(method):
  781. """A wrapper that wraps a parameter checker to the GNN `get_node_feature` function."""
  782. @wraps(method)
  783. def new_method(self, *args, **kwargs):
  784. [node_list, feature_types], _ = parse_user_args(method, *args, **kwargs)
  785. type_check(node_list, (list, np.ndarray), "node_list")
  786. if isinstance(node_list, list):
  787. check_aligned_list(node_list, 'node_list', int)
  788. elif isinstance(node_list, np.ndarray):
  789. if not node_list.dtype == np.int32:
  790. raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
  791. node_list, node_list.dtype))
  792. check_gnn_list_or_ndarray(feature_types, 'feature_types')
  793. return method(self, *args, **kwargs)
  794. return new_method
  795. def check_gnn_get_edge_feature(method):
  796. """A wrapper that wrap a parameter checker to the GNN `get_edge_feature` function."""
  797. @wraps(method)
  798. def new_method(self, *args, **kwargs):
  799. [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs)
  800. type_check(edge_list, (list, np.ndarray), "edge_list")
  801. if isinstance(edge_list, list):
  802. check_aligned_list(edge_list, 'edge_list', int)
  803. elif isinstance(edge_list, np.ndarray):
  804. if not edge_list.dtype == np.int32:
  805. raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
  806. edge_list, edge_list.dtype))
  807. check_gnn_list_or_ndarray(feature_types, 'feature_types')
  808. return method(self, *args, **kwargs)
  809. return new_method
  810. def check_numpyslicesdataset(method):
  811. """A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset)."""
  812. @wraps(method)
  813. def new_method(self, *args, **kwargs):
  814. _, param_dict = parse_user_args(method, *args, **kwargs)
  815. data = param_dict.get("data")
  816. column_names = param_dict.get("column_names")
  817. if not data:
  818. raise ValueError("Argument data cannot be empty")
  819. type_check(data, (list, tuple, dict, np.ndarray), "data")
  820. if isinstance(data, tuple):
  821. type_check(data[0], (list, np.ndarray), "data[0]")
  822. # check column_names
  823. if column_names is not None:
  824. check_columns(column_names, "column_names")
  825. # check num of input column in column_names
  826. column_num = 1 if isinstance(column_names, str) else len(column_names)
  827. if isinstance(data, dict):
  828. data_column = len(list(data.keys()))
  829. if column_num != data_column:
  830. raise ValueError("Num of input column names is {0}, but required is {1}."
  831. .format(column_num, data_column))
  832. elif isinstance(data, tuple):
  833. if column_num != len(data):
  834. raise ValueError("Num of input column names is {0}, but required is {1}."
  835. .format(column_num, len(data)))
  836. else:
  837. if column_num != 1:
  838. raise ValueError("Num of input column names is {0}, but required is {1} as data is list."
  839. .format(column_num, 1))
  840. return method(self, *args, **kwargs)
  841. return new_method
  842. def check_paddeddataset(method):
  843. """A wrapper that wraps a parameter checker to the original Dataset(PaddedDataset)."""
  844. @wraps(method)
  845. def new_method(self, *args, **kwargs):
  846. _, param_dict = parse_user_args(method, *args, **kwargs)
  847. paddedSamples = param_dict.get("padded_samples")
  848. if not paddedSamples:
  849. raise ValueError("Argument padded_samples cannot be empty")
  850. type_check(paddedSamples, (list,), "padded_samples")
  851. type_check(paddedSamples[0], (dict,), "padded_element")
  852. return method(self, *args, **kwargs)
  853. return new_method